Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Asyncio how to reuse a socket

How would I reuse a socket to a server in asyncio, instead of creating a new connection for each query?

Here is my code;

async def lookup(server, port, query, sema):
    """Send *query* to ``server:port`` and return the complete response text.

    A fresh connection is opened for every call. *sema* is an async context
    manager (typically an ``asyncio.Semaphore``) that is held for the whole
    exchange to bound the number of concurrent lookups.

    The query is encoded, and the response decoded, as ISO-8859-1. The
    response is everything the server sends until it closes the connection.

    Returns the decoded response as a ``str``, or an empty ``dict`` when the
    connection could not be established (kept for backward compatibility).
    """
    async with sema:
        try:
            reader, writer = await asyncio.open_connection(server, port)
        except (OSError, ValueError):
            # Connection-level failures (refused, unreachable, DNS errors,
            # invalid host names). Deliberately not a bare ``except`` so that
            # task cancellation and KeyboardInterrupt still propagate.
            return {}

        try:
            writer.write(query.encode("ISO-8859-1"))
            await writer.drain()

            # Collect chunks and join once instead of repeated ``bytes +=``.
            chunks = []
            while True:
                chunk = await reader.read(4096)
                if not chunk:  # EOF: the server closed its side
                    break
                chunks.append(chunk)
        finally:
            # Always give the socket back, even if encoding/writing/reading failed.
            writer.close()

        return b"".join(chunks).decode("ISO-8859-1")
like image 287
Jonathan Avatar asked Oct 14 '25 18:10

Jonathan


1 Answers

You'd simply call the asyncio.open_connection(server, port) coroutine just once, and keep using the reader and writer (provided the server doesn't just close the connection on their end, of course).

I'd do so in a separate async context manager object for your connections, and use a connection pool to manage the connections, so you can create and re-use socket connections for many concurrent tasks. By using an (async) context manager, Python makes sure to notify the connection when your code is done with it, so the connection can be released back to the pool:

import asyncio
import contextlib

from collections import OrderedDict
from types import TracebackType
from typing import Any, List, Optional, Tuple, Type


# ``contextlib.AbstractAsyncContextManager`` was added in Python 3.7; on older
# interpreters the pool and connection classes derive from ``object`` instead.
base = getattr(contextlib, "AbstractAsyncContextManager", object)

# Type aliases describing the endpoint a connection talks to.
Server = str
Port = int
Host = Tuple[Server, Port]


class ConnectionPool(base):
    """Pool of reusable asyncio stream connections keyed by ``(server, port)``.

    ``connect()`` hands out an idle connection for the requested host when one
    exists and opens a new one otherwise. At most ``max_connections``
    connections are handed out at the same time; further ``connect()`` calls
    wait until a connection is released. When a new socket would push the
    number of open sockets past ``max_connections``, one idle connection of
    the least-recently-used host is closed first.

    The pool is an async context manager: leaving the ``async with`` block
    closes every pooled connection.
    """

    def __init__(
        self,
        max_connections: int = 1000,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        self.max_connections = max_connections
        self._loop = loop or asyncio.get_event_loop()

        # Hosts ordered most-recently-used first; each maps to its open connections.
        self._connections: OrderedDict[Host, List["Connection"]] = OrderedDict()
        # One permit per connection handed out (or being set up) by connect().
        self._semaphore = asyncio.Semaphore(max_connections)
        # Sockets currently being opened; they count against the limit too.
        self._opening = 0

    async def connect(self, server: Server, port: Port) -> "Connection":
        """Return a connection to ``server:port`` marked as in use.

        Waits while ``max_connections`` connections are already handed out.
        Errors raised while opening a new socket (e.g. ``OSError``) propagate
        to the caller; the reserved slot is given back in that case.
        """
        host = (server, port)

        # enforce the connection limit, releasing connections notifies
        # the semaphore to release here
        await self._semaphore.acquire()
        try:
            # find an un-used connection for this host
            candidates = self._connections.get(host, ())
            connection = next((c for c in candidates if not c.in_use), None)
            if connection is None:
                connection = await self._open(host)
        except BaseException:
            # Opening failed or the task was cancelled: without this the
            # permit would be lost and the pool would slowly starve.
            self._semaphore.release()
            raise

        connection.in_use = True
        # move current host to the front as most-recently used
        self._connections.move_to_end(host, False)

        return connection

    async def _open(self, host: Host) -> "Connection":
        """Open and register a new connection, evicting one idle connection if full."""
        self._opening += 1
        try:
            open_count = sum(len(conns) for conns in self._connections.values())
            if open_count + self._opening > self.max_connections:
                # The victim is picked synchronously so the mapping is never
                # iterated across an await (other tasks may mutate it).
                victim = self._least_recently_used_idle()
                if victim is not None:
                    await victim.close()
            reader, writer = await asyncio.open_connection(*host)
        finally:
            self._opening -= 1

        connection = Connection(self, host, reader, writer)
        # Look the host list up only now: close() may have replaced the
        # mapping while the socket was being opened.
        self._connections.setdefault(host, []).append(connection)
        return connection

    def _least_recently_used_idle(self) -> Optional["Connection"]:
        """Return one idle connection of the least-recently-used host, if any."""
        # connect() moves the host it serves to the front, so the tail is
        # the least recently used.
        for conns_per_host in reversed(self._connections.values()):
            for conn in conns_per_host:
                if not conn.in_use:
                    return conn
        return None

    async def close(self):
        """Close all connections"""
        connections = [c for cs in self._connections.values() for c in cs]
        self._connections = OrderedDict()
        for connection in connections:
            await connection.close()

    def _remove(self, connection):
        """Forget *connection*; called by the connection when it is closed."""
        conns_for_host = self._connections.get(connection._host)
        if not conns_for_host:
            return
        conns_for_host[:] = [c for c in conns_for_host if c is not connection]
        if not conns_for_host:
            # Drop empty entries so the mapping does not grow with every
            # host ever contacted.
            del self._connections[connection._host]

    def _notify_release(self):
        """Give a slot back; called by a connection when it is released."""
        self._semaphore.release()

    async def __aenter__(self) -> "ConnectionPool":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

    def __del__(self) -> None:
        # __init__ may have failed before the mapping was created.
        pooled = getattr(self, "_connections", None)
        if not pooled:
            return
        connections = [repr(c) for cs in pooled.values() for c in cs]
        if not connections:
            return

        # Report the leak through the loop's exception handler.
        context = {
            "pool": self,
            "connections": connections,
            "message": "Unclosed connection pool",
        }
        self._loop.call_exception_handler(context)


class Connection(base):
    """A pooled stream connection that proxies its reader and writer.

    Instances are created by ``ConnectionPool.connect()``. Use the connection
    as an async context manager; leaving the block releases it back to the
    pool (it stays open for re-use). Attributes that the connection does not
    define itself (``read``, ``write``, ``drain``, ...) are looked up on the
    underlying ``StreamReader`` first and the ``StreamWriter`` second.
    """

    # Names of the instance's own state. __getattr__ must never try to
    # delegate these: it reads them itself, so a missing one would otherwise
    # recurse forever.
    _OWN_ATTRS = frozenset(
        {"_host", "_pool", "_reader", "_writer", "_closed", "in_use"}
    )

    def __init__(
        self,
        pool: ConnectionPool,
        host: Host,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        self._host = host
        self._pool = pool
        self._reader = reader
        self._writer = writer
        self._closed = False
        # Set to True by the pool while the connection is handed out.
        self.in_use = False

    def __repr__(self):
        host = f"{self._host[0]}:{self._host[1]}"
        return f"Connection<{host}>"

    @property
    def closed(self):
        """True once ``close()`` has been called."""
        return self._closed

    def release(self) -> None:
        """Hand the connection back to the pool without closing it.

        Calling this on a connection that is not in use is a no-op, so the
        pool is never notified twice for the same acquisition.
        """
        if not self.in_use:
            return
        self.in_use = False
        self._pool._notify_release()

    async def close(self) -> None:
        """Close the socket and remove the connection from the pool.

        If the connection is still in use, its pool slot is released as well.
        Closing an already closed connection does nothing.
        """
        if self._closed:
            return
        self._closed = True
        self._writer.close()
        self._pool._remove(self)
        # A connection closed while handed out must still give its slot back.
        self.release()
        # wait_closed is new in 3.7; look it up explicitly so that unrelated
        # AttributeErrors raised while waiting are not swallowed.
        wait_closed = getattr(self._writer, "wait_closed", None)
        if wait_closed is not None:
            await wait_closed()

    def __getattr__(self, name: str) -> Any:
        """All unknown attributes are delegated to the reader and writer"""
        if name in self._OWN_ATTRS:
            # Only reached when the instance is not (fully) initialised.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        if hasattr(self._reader, name):
            return getattr(self._reader, name)
        return getattr(self._writer, name)

    async def __aenter__(self) -> "Connection":
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        self.release()

    def __del__(self) -> None:
        # Treat a half-constructed instance as closed: nothing to report.
        if getattr(self, "_closed", True):
            return
        context = {"connection": self, "message": "Unclosed connection"}
        self._pool._loop.call_exception_handler(context)

Then pass a pool object to your lookup coroutine; the connection object it produces acts as a proxy for both the reader and the writer:

async def lookup(pool, server, port, query):
    """Send *query* to ``server:port`` over a pooled connection.

    The query is encoded, and the response decoded, as ISO-8859-1. The
    response is everything the server sends until it closes its side.

    Returns the decoded response as a ``str``, or an empty ``dict`` when no
    connection could be obtained from *pool*.
    """
    try:
        conn = await pool.connect(server, port)
    except (ValueError, OSError):
        return {}

    async with conn:
        try:
            conn.write(query.encode("ISO-8859-1"))
            await conn.drain()

            # Collect chunks and join once instead of repeated ``bytes +=``.
            chunks = []
            while True:
                chunk = await conn.read(4096)
                if not chunk:  # EOF: the server closed its side
                    break
                chunks.append(chunk)
        finally:
            # The response is delimited by EOF, so after a successful read
            # the socket is spent, and after an error its state is unknown.
            # Either way it must not go back into the pool for re-use.
            await conn.close()

    return b"".join(chunks).decode("ISO-8859-1")

Note that the standard WHOIS protocol (RFC 3912 and its predecessors) states that the connection is closed after every query. If you are connecting to a standard WHOIS service on port 43, there is no point in re-using sockets.

What happens in this case is that the reader will have reached EOF (reader.at_eof() is true), and any further attempts at reading will simply return nothing (reader.read(...) will always return an empty b'' value). Writing to the writer is not going to be an error until the socket connection is terminated by the remote side after a time-out. You can write all you want to the connection, but the WHOIS server will just ignore the queries.

like image 129
Martijn Pieters Avatar answered Oct 17 '25 07:10

Martijn Pieters