# Source code for caproto.asyncio.server

import asyncio
import functools
import socket
import sys

import caproto as ca

from ..server import AsyncLibraryLayer
from ..server.common import Context as _Context
from ..server.common import VirtualCircuit as _VirtualCircuit
from .utils import (AsyncioQueue, _DatagramProtocol, _TaskHandler,
                    _TransportWrapper)


class ServerExit(Exception):
    """Exception type that the asyncio ``Context`` exposes as ``ServerExit``."""


class Event(asyncio.Event):
    """``asyncio.Event`` whose ``wait()`` takes a ``timeout``, as in threading.Event."""

    async def wait(self, timeout=None):
        """
        Wait until the event is set or ``timeout`` seconds have elapsed.

        Parameters
        ----------
        timeout : float, optional
            Maximum time to wait in seconds; ``None`` waits indefinitely.

        Returns
        -------
        bool
            The state of the event flag at the time of return.
        """
        pending = super().wait()
        try:
            await asyncio.wait_for(pending, timeout)
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is not the builtin TimeoutError on every
            # Python version.  Timing out is not an error here: the caller
            # finds out through the returned flag.
            pass
        return self.is_set()


class AsyncioAsyncLayer(AsyncLibraryLayer):
    """
    asyncio-specific objects bundled under the ``AsyncLibraryLayer`` interface.

    Attributes
    ----------
    name : str
        Identifier of the backing library (``'asyncio'``).
    Event : type
        Event class to instantiate; supports ``wait(timeout=...)``.
    library : module
        The ``asyncio`` module itself.
    ThreadsafeQueue : type
        Queue class taken from this package's asyncio utilities.
    """
    name = 'asyncio'
    # Use this module's Event subclass rather than the bare asyncio.Event so
    # that ``wait(timeout=...)`` is accepted, as with threading.Event.  It is
    # a subclass of asyncio.Event, so existing ``wait()`` callers still work.
    Event = Event
    library = asyncio
    ThreadsafeQueue = AsyncioQueue


class VirtualCircuit(_VirtualCircuit):
    """
    Wraps a caproto.VirtualCircuit with an asyncio client.

    Parameters
    ----------
    circuit
        Passed through unchanged to the base class; also used here via
        ``self.circuit.send`` to serialize outgoing commands.
    client
        The client socket.  It is read with ``loop.sock_recv``, written with
        ``loop.sock_sendall`` and closed on disconnect.
    context
        Passed through unchanged to the base class.
    loop : asyncio.AbstractEventLoop, optional
        Event loop used for the socket operations.  Defaults to
        ``asyncio.get_event_loop()``.
    """
    TaskCancelled = asyncio.CancelledError

    def __init__(self, circuit, client, context, *, loop=None):
        if loop is None:
            loop = asyncio.get_event_loop()
        self.loop = loop

        class SockWrapper:
            """Give the raw socket an awaitable ``recv`` for the base class."""

            def __init__(self, loop, client):
                self.loop = loop
                self.client = client

            def getsockname(self):
                return self.client.getsockname()

            async def recv(self, nbytes):
                return (await self.loop.sock_recv(self.client, nbytes))

        self._raw_lock = asyncio.Lock()
        self._raw_client = client
        super().__init__(circuit, SockWrapper(loop, client), context)
        self.QueueFull = asyncio.QueueFull
        # The ``loop=`` argument of asyncio.Queue / Condition / Event was
        # deprecated in Python 3.8 and removed in 3.10 (passing it raises
        # TypeError there), so it is deliberately not passed.
        self.command_queue = asyncio.Queue(ca.MAX_COMMAND_BACKLOG)
        self.new_command_condition = asyncio.Condition()
        self.events_on = asyncio.Event()
        self.subscription_queue = asyncio.Queue(
            ca.MAX_TOTAL_SUBSCRIPTION_BACKLOG)
        # The module-level Event subclass: its wait() accepts a timeout.
        self.write_event = Event()
        self.tasks = _TaskHandler()
        self._sub_task = None

    async def get_from_sub_queue(self, timeout=None):
        """
        Get the next item from the subscription queue.

        Returns the item, or ``None`` if nothing arrived within ``timeout``
        seconds (``timeout=None`` waits indefinitely).
        """
        # Timeouts work very differently between our server implementations,
        # so we do this little stub in its own method.
        fut = asyncio.ensure_future(self.subscription_queue.get())
        try:
            # ``loop=`` is not passed: wait_for() no longer accepts it on
            # Python 3.10+.
            return await asyncio.wait_for(fut, timeout)
        except asyncio.TimeoutError:
            return None

    async def send(self, *commands):
        """Serialize ``commands`` and write them to the client, if connected."""
        if self.connected:
            buffers_to_send = self.circuit.send(*commands)
            # lock to make sure a AddEvent does not write bytes
            # to the socket while we are sending
            async with self._raw_lock:
                await self.loop.sock_sendall(self._raw_client,
                                             b''.join(buffers_to_send))

    async def run(self):
        """Start the command-queue and subscription-queue loops as tasks."""
        self.tasks.create(self.command_queue_loop())
        # Keep a handle on the subscription task so that _on_disconnect can
        # cancel it.
        self._sub_task = self.tasks.create(self.subscription_queue_loop())

    async def _start_write_task(self, handle_write):
        """Schedule the coroutine returned by ``handle_write()`` as a task."""
        self.tasks.create(handle_write())

    async def _wake_new_command(self):
        """Notify everything waiting on ``new_command_condition``."""
        async with self.new_command_condition:
            self.new_command_condition.notify_all()

    async def _on_disconnect(self):
        """Run the base-class handling, close the socket, stop the sub task."""
        await super()._on_disconnect()
        self._raw_client.close()
        if self._sub_task is not None:
            await self.tasks.cancel(self._sub_task)
            self._sub_task = None


class Context(_Context):
    """
    asyncio implementation of the server context.

    Parameters
    ----------
    pvdb
        PV database; passed to the base class.  ``run`` iterates it to log
        the available PV names.
    interfaces : optional
        Passed to the base class; ``run`` binds one UDP socket per entry of
        ``self.interfaces``.
    loop : asyncio.AbstractEventLoop, optional
        Event loop used for socket operations and task creation.  Defaults
        to ``asyncio.get_event_loop()``.
    """
    CircuitClass = VirtualCircuit
    async_layer = None
    ServerExit = ServerExit
    TaskCancelled = asyncio.CancelledError

    def __init__(self, pvdb, interfaces=None, *, loop=None):
        super().__init__(pvdb, interfaces)
        self.command_bundle_queue = asyncio.Queue()
        self.subscription_queue = asyncio.Queue()
        if loop is None:
            loop = asyncio.get_event_loop()
        self.loop = loop
        self.async_layer = AsyncioAsyncLayer()
        # Tasks serving individual TCP clients (see server_accept_loop).
        self.server_tasks = _TaskHandler()
        # Replaced in run() with the sockets that were actually bound.
        self.tcp_sockets = dict()

    async def server_accept_loop(self, sock):
        """Listen on ``sock`` forever, spawning a handler task per client."""
        sock.listen()

        while True:
            client_sock, addr = await self.loop.sock_accept(sock)
            self.server_tasks.create(self.tcp_handler(client_sock, addr))

    async def run(self, *, log_pv_names=False):
        """
        Start the server and run until its tasks finish or are cancelled.

        Binds the TCP and UDP sockets, starts the accept, broadcaster,
        subscription and beacon loops plus all startup methods, and then
        waits on them.  On the way out (for any reason) the shutdown methods
        are run and every socket opened here is closed.

        Parameters
        ----------
        log_pv_names : bool, optional
            If true, log the names in the PV database at INFO level.

        Raises
        ------
        Exception
            Anything other than cancellation raised by a server task is
            logged and re-raised.  Cancellation itself is handled and makes
            this coroutine return normally.
        """
        self.log.info('Asyncio server starting up...')

        async def make_socket(interface, port):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.setblocking(False)
            s.bind((interface, port))
            return s

        self.port, self.tcp_sockets = (
            await self._bind_tcp_sockets_with_consistent_port_number(
                make_socket))
        tasks = _TaskHandler()
        for interface, sock in self.tcp_sockets.items():
            self.log.info("Listening on %s:%d", interface, self.port)
            self.broadcaster.server_addresses.append((interface, self.port))
            tasks.create(self.server_accept_loop(sock))

        class ConnectedTransportWrapper:
            """Make an asyncio transport something you can call send on."""
            def __init__(self, transport, address):
                self.transport = transport
                self.address = address

            async def send(self, bytes_to_send):
                try:
                    self.transport.sendto(bytes_to_send, self.address)
                except OSError as exc:
                    host, port = self.address
                    raise ca.CaprotoNetworkError(
                        f"Failed to send to {host}:{port}") from exc

            def close(self):
                return self.transport.close()

        reuse_port = sys.platform not in ('win32', ) and hasattr(socket, 'SO_REUSEPORT')
        for address in ca.get_beacon_address_list():
            transport, _ = await self.loop.create_datagram_endpoint(
                functools.partial(_DatagramProtocol, parent=self,
                                  recv_func=self._datagram_received),
                remote_addr=address, allow_broadcast=True,
                reuse_port=reuse_port)
            wrapped_transport = ConnectedTransportWrapper(transport, address)
            # NOTE(review): ``interface`` here is whatever the TCP-socket
            # loop above left behind (its last iteration), not an interface
            # associated with ``address``.  It is also undefined if that loop
            # ran zero times.  The correct value cannot be determined from
            # this file, so the behavior is left as-is.
            self.beacon_socks[address] = (interface,   # TODO; this is incorrect
                                          wrapped_transport)

        for interface in self.interfaces:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
            # Python says this is unsafe, but we need it to have
            # multiple servers live on the same host.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if reuse_port:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
            sock.setblocking(False)
            sock.bind((interface, self.ca_server_port))

            transport, _ = await self.loop.create_datagram_endpoint(
                functools.partial(_DatagramProtocol, parent=self,
                                  recv_func=self._datagram_received),
                sock=sock)
            self.udp_socks[interface] = _TransportWrapper(transport)
            self.log.debug('UDP socket bound on %s:%d', interface,
                           self.ca_server_port)

        tasks.create(self.broadcaster_queue_loop())
        tasks.create(self.subscription_queue_loop())
        tasks.create(self.broadcast_beacon_loop())

        async_lib = AsyncioAsyncLayer()
        for name, method in self.startup_methods.items():
            self.log.debug('Calling startup method %r', name)
            tasks.create(method(async_lib))
        self.log.info('Server startup complete.')
        if log_pv_names:
            self.log.info('PVs available:\n%s', '\n'.join(self.pvdb))

        try:
            await asyncio.gather(*tasks.tasks)
        except asyncio.CancelledError:
            self.log.info('Server task cancelled. Will shut down.')
            await tasks.cancel_all()
            await self.server_tasks.cancel_all()
            for circuit in self.circuits:
                await circuit.tasks.cancel_all()
            return
        except Exception:
            self.log.exception('Server error. Will shut down')
            raise
        finally:
            self.log.info('Server exiting....')
            shutdown_tasks = []
            async_lib = AsyncioAsyncLayer()
            try:
                for name, method in self.shutdown_methods.items():
                    self.log.debug('Calling shutdown method %r', name)
                    task = self.loop.create_task(method(async_lib))
                    shutdown_tasks.append(task)
                await asyncio.gather(*shutdown_tasks)
            finally:
                # Close the sockets even when a shutdown method raises (or
                # the wait above is cancelled); otherwise they would leak.
                # The original exception still propagates afterwards.
                for sock in self.tcp_sockets.values():
                    sock.close()
                for sock in self.udp_socks.values():
                    sock.close()
                for _interface, sock in self.beacon_socks.values():
                    sock.close()

    def _datagram_received(self, pair):
        """
        Handle one received UDP datagram.

        ``pair`` is ``(bytes_received, address)``.  The bytes are parsed by
        the broadcaster; a bad packet is logged and dropped, otherwise
        ``(address, commands)`` is put on ``command_bundle_queue``.
        """
        bytes_received, address = pair
        try:
            commands = self.broadcaster.recv(bytes_received, address)
        except ca.RemoteProtocolError:
            self.log.exception('Broadcaster received bad packet')
        else:
            self.command_bundle_queue.put_nowait((address, commands))


async def start_server(pvdb, *, interfaces=None, log_pv_names=False):
    """
    Build a ``Context`` for ``pvdb`` and run it.

    Parameters
    ----------
    pvdb
        PV database handed to ``Context``.
    interfaces : optional
        Handed to ``Context`` unchanged.
    log_pv_names : bool, optional
        Forwarded to ``Context.run``.

    Returns
    -------
    Whatever ``Context.run`` returns.
    """
    context = Context(pvdb, interfaces)
    return await context.run(log_pv_names=log_pv_names)


def run(pvdb, *, interfaces=None, log_pv_names=False):
    """
    Synchronous wrapper around ``start_server``.

    Schedules ``start_server`` on the loop returned by
    ``asyncio.get_event_loop()`` and blocks until it completes.  A
    ``KeyboardInterrupt`` is swallowed.  In every case the server task is
    then cancelled and run to completion once more.
    """
    event_loop = asyncio.get_event_loop()
    server_coro = start_server(pvdb, interfaces=interfaces,
                               log_pv_names=log_pv_names)
    server_task = event_loop.create_task(server_coro)
    try:
        event_loop.run_until_complete(server_task)
    except KeyboardInterrupt:
        # Ctrl-C: fall through to the cancellation below.
        pass
    finally:
        # cancel() is a no-op if the task has already finished.
        server_task.cancel()
        event_loop.run_until_complete(server_task)