Source code for caproto.asyncio.server

import asyncio
import functools
import sys

import caproto as ca

from .._utils import CaprotoNetworkError
from ..server import AsyncLibraryLayer
from ..server.common import Context as _Context
from ..server.common import DisconnectedCircuit
from ..server.common import VirtualCircuit as _VirtualCircuit
from .utils import (AsyncioQueue, _create_bound_tcp_socket, _create_udp_socket,
                    _DatagramProtocol, _TaskHandler, _TransportWrapper,
                    _UdpTransportWrapper)


class ServerExit(Exception):
    """Exception type that the asyncio ``Context`` exposes as ``ServerExit``."""
[docs]class Event(asyncio.Event): "Implement the ``timeout`` keyword to wait(), as in threading.Event." async def wait(self, timeout=None): try: await asyncio.wait_for(super().wait(), timeout) except asyncio.TimeoutError: # somehow not just a TimeoutError... pass return self.is_set()
class AsyncioAsyncLayer(AsyncLibraryLayer):
    """Async-library layer that maps the generic names onto asyncio objects."""

    # Identification of the backing library.
    name = 'asyncio'
    library = asyncio

    # Primitives handed to code that receives this layer.
    ThreadsafeQueue = AsyncioQueue
    Event = asyncio.Event
    sleep = staticmethod(asyncio.sleep)
class VirtualCircuit(_VirtualCircuit):
    """Wraps a caproto.VirtualCircuit with an asyncio client."""

    TaskCancelled = asyncio.CancelledError

    def __init__(self, circuit, client, context, *, loop=None):
        """
        Parameters
        ----------
        circuit : caproto.VirtualCircuit
            The sans-I/O circuit being wrapped.
        client :
            Transport used to talk to the client; must provide an awaitable
            ``send(bytes)`` and a ``close()`` method.
        context : Context
            The server context that owns this circuit.
        loop : asyncio event loop, optional
            Stored on ``self.loop``.  Defaults to ``asyncio.get_event_loop()``.
        """
        if loop is None:
            loop = asyncio.get_event_loop()
        self.loop = loop
        super().__init__(circuit, client, context)
        self.QueueFull = asyncio.QueueFull
        # NOTE: the ``loop=`` keyword of the asyncio primitives below was
        # deprecated in Python 3.8 and removed in 3.10, where passing it
        # raises TypeError.  Without it the primitives use the current /
        # running loop, which is the loop these objects are used from.
        self.command_queue = asyncio.Queue(ca.MAX_COMMAND_BACKLOG)
        self.new_command_condition = asyncio.Condition()
        self.events_on = asyncio.Event()
        self.subscription_queue = asyncio.Queue(
            ca.MAX_TOTAL_SUBSCRIPTION_BACKLOG)
        self.write_event = Event()
        self.tasks = _TaskHandler()
        self._sub_task = None

    async def get_from_sub_queue(self, timeout=None):
        """
        Get the next item from the subscription queue.

        Returns
        -------
        The queued item, or ``None`` if nothing arrived within ``timeout``
        seconds.
        """
        # Timeouts work very differently between our server implementations,
        # so we do this little stub in its own method.
        fut = asyncio.ensure_future(self.subscription_queue.get())
        try:
            # ``loop=`` intentionally omitted: removed from wait_for in 3.10.
            return await asyncio.wait_for(fut, timeout)
        except asyncio.TimeoutError:
            return None

    async def send(self, *commands):
        """
        Serialize ``commands`` through the circuit and send them to the client.

        Does nothing when the circuit is not connected.

        Raises
        ------
        DisconnectedCircuit
            If the underlying transport reports a CaprotoNetworkError.
        """
        if self.connected:
            buffers_to_send = self.circuit.send(*commands)
            try:
                await self.client.send(b''.join(buffers_to_send))
            except CaprotoNetworkError as ex:
                raise DisconnectedCircuit(
                    f"Circuit disconnected: {ex}"
                ) from ex

    async def run(self):
        """Start the command-queue and subscription-queue loops as tasks."""
        self.tasks.create(self.command_queue_loop())
        # Keep a handle so the subscription task can be cancelled on
        # disconnect.
        self._sub_task = self.tasks.create(self.subscription_queue_loop())

    async def _start_write_task(self, handle_write):
        """Schedule the ``handle_write`` coroutine function as a task."""
        self.tasks.create(handle_write())

    async def _wake_new_command(self):
        """Notify everything waiting on ``new_command_condition``."""
        async with self.new_command_condition:
            self.new_command_condition.notify_all()

    async def _on_disconnect(self):
        """Close the client and cancel the subscription task, if any."""
        await super()._on_disconnect()
        self.client.close()
        if self._sub_task is not None:
            await self.tasks.cancel(self._sub_task)
            self._sub_task = None
class Context(_Context):
    """asyncio-based server context built on the shared caproto ``Context``."""

    CircuitClass = VirtualCircuit
    async_layer = None
    ServerExit = ServerExit
    TaskCancelled = asyncio.CancelledError

    def __init__(self, pvdb, interfaces=None, *, loop=None):
        """
        Parameters
        ----------
        pvdb : dict
            The PV database, passed through to the base class.
        interfaces : list, optional
            Interfaces to listen on, passed through to the base class.
        loop : asyncio event loop, optional
            Stored on ``self.loop``.  Defaults to ``asyncio.get_event_loop()``.
        """
        super().__init__(pvdb, interfaces)
        self.broadcaster_datagram_queue = AsyncioQueue(
            ca.MAX_COMMAND_BACKLOG
        )
        self.command_bundle_queue = asyncio.Queue(
            ca.MAX_COMMAND_BACKLOG
        )
        self.subscription_queue = asyncio.Queue()
        if loop is None:
            loop = asyncio.get_event_loop()
        self.loop = loop
        self.async_layer = AsyncioAsyncLayer()
        self.server_tasks = _TaskHandler()
        self.tcp_sockets = dict()

    async def server_accept_loop(self, sock):
        """Start a TCP server on `sock` and listen for new connections."""
        def _new_client(reader, writer):
            # One handler task per accepted connection.
            transport = _TransportWrapper(reader, writer)
            self.server_tasks.create(
                self.tcp_handler(transport, transport.getpeername())
            )

        # TODO: when Python 3.7 is the minimum version, the following server
        # can be an async context manager:
        await asyncio.start_server(
            _new_client,
            sock=sock,
        )

    async def run(self, *, log_pv_names=False, startup_hook=None):
        """
        Start the server and run until its tasks finish or are cancelled.

        Parameters
        ----------
        log_pv_names : bool, optional
            Log the names of all PVs in the database after startup.
        startup_hook : coroutine function, optional
            Called with an ``AsyncioAsyncLayer`` instance; the resulting
            coroutine is run as a server task.
        """
        self.log.info('Asyncio server starting up...')

        self.port, self.tcp_sockets = await self._bind_tcp_sockets_with_consistent_port_number(
            _create_bound_tcp_socket
        )
        tasks = _TaskHandler()
        for interface, sock in self.tcp_sockets.items():
            self.log.info("Listening on %s:%d", interface, self.port)
            self.broadcaster.server_addresses.append((interface, self.port))
            tasks.create(self.server_accept_loop(sock))

        for address in ca.get_beacon_address_list():
            sock = _create_udp_socket()
            try:
                sock.connect(address)
            except Exception as ex:
                self.log.error(
                    'Beacon (%s:%d) socket setup failed: %s', *address, ex,
                )
                # Do not leak the descriptor of a socket we will never use.
                sock.close()
                continue

            wrapped_transport = _UdpTransportWrapper(
                sock, address, loop=self.loop
            )
            # NOTE(review): ``interface`` here is whatever the TCP loop above
            # left behind (and is unbound if there were no TCP sockets).
            self.beacon_socks[address] = (interface,   # TODO; this is incorrect
                                          wrapped_transport)

        for interface in self.interfaces:
            await self._create_broadcaster_transport(interface)

        tasks.create(self.broadcaster_receive_loop())
        tasks.create(self.broadcaster_queue_loop())
        tasks.create(self.subscription_queue_loop())
        tasks.create(self.broadcast_beacon_loop())

        async_lib = AsyncioAsyncLayer()

        if startup_hook is not None:
            self.log.debug('Calling startup hook %r', startup_hook.__name__)
            tasks.create(startup_hook(async_lib))

        for name, method in self.startup_methods.items():
            self.log.debug('Calling startup method %r', name)
            tasks.create(method(async_lib))
        self.log.info('Server startup complete.')
        if log_pv_names:
            self.log.info('PVs available:\n%s', '\n'.join(self.pvdb))

        try:
            await asyncio.gather(*tasks.tasks)
        except asyncio.CancelledError:
            self.log.info('Server task cancelled. Will shut down.')
            await tasks.cancel_all()
            await self.server_tasks.cancel_all()
            for circuit in self.circuits:
                await circuit.tasks.cancel_all()
            return
        except Exception:
            self.log.exception('Server error. Will shut down')
            raise
        finally:
            # Runs on every exit path: give shutdown methods a chance to
            # finish, then release all sockets.
            self.log.info('Server exiting....')
            shutdown_tasks = []
            async_lib = AsyncioAsyncLayer()
            for name, method in self.shutdown_methods.items():
                self.log.debug('Calling shutdown method %r', name)
                task = self.loop.create_task(method(async_lib))
                shutdown_tasks.append(task)
            await asyncio.gather(*shutdown_tasks)
            for sock in self.tcp_sockets.values():
                sock.close()
            for sock in self.udp_socks.values():
                sock.close()
            for _, sock in self.beacon_socks.values():
                sock.close()

    async def _create_broadcaster_transport(self, interface):
        """Create broadcaster transport on the given interface."""
        # Replace (and close) any transport already registered for this
        # interface, so the method can also be used to re-create one.
        old_transport = self.udp_socks.pop(interface, None)
        if old_transport is not None:
            try:
                old_transport.close()
            except OSError:
                self.log.warning(
                    "Failed to close old transport for interface %s",
                    interface
                )

        sock = _create_udp_socket()
        sock.bind((interface, self.ca_server_port))
        transport, _ = await self.loop.create_datagram_endpoint(
            functools.partial(_DatagramProtocol, parent=self,
                              identifier=interface,
                              queue=self.broadcaster_datagram_queue),
            sock=sock,
        )
        self.udp_socks[interface] = _UdpTransportWrapper(
            transport, loop=self.loop
        )
        self.log.debug('UDP socket bound on %s:%d', interface,
                       self.ca_server_port)

    async def broadcaster_receive_loop(self):
        """Forward received broadcaster datagrams; recover from errors."""
        # UdpTransport -> broadcaster_datagram_queue -> command_bundle_queue
        queue = self.broadcaster_datagram_queue
        while True:
            interface, data, address = await queue.async_get()
            if isinstance(data, Exception):
                self.log.exception('Broadcaster failed to receive on %s',
                                   interface, exc_info=data)
                if sys.platform == 'win32':
                    self.log.warning(
                        'Re-initializing socket on interface %s', interface
                    )
                    # Actually perform the re-initialization announced
                    # above; this closes the old transport and binds a new
                    # one for the interface.
                    await self._create_broadcaster_transport(interface)
            else:
                await self._broadcaster_recv_datagram(data, address)
async def start_server(pvdb, *, interfaces=None, log_pv_names=False,
                       startup_hook=None):
    """
    Build a ``Context`` for ``pvdb`` and run it.

    Parameters
    ----------
    pvdb : dict
        The PV database.
    interfaces : list, optional
        Interfaces handed to the ``Context``.
    log_pv_names : bool, optional
        Forwarded to ``Context.run``.
    startup_hook : coroutine function, optional
        Forwarded to ``Context.run``.

    Returns
    -------
    Whatever ``Context.run`` returns.
    """
    context = Context(pvdb, interfaces)
    return await context.run(
        log_pv_names=log_pv_names,
        startup_hook=startup_hook,
    )
def run(pvdb, *, interfaces=None, log_pv_names=False, startup_hook=None):
    """
    Run an IOC, given its PV database dictionary.

    A synchronous function that wraps start_server and exits cleanly.

    Parameters
    ----------
    pvdb : dict
        The PV database.

    interfaces : list, optional
        List of interfaces to listen on.

    log_pv_names : bool, optional
        Log PV names at startup.

    startup_hook : coroutine, optional
        Hook to call at startup with the ``async_lib`` shim.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current event loop (e.g. called from a non-main thread, or on a
        # Python version that no longer creates one implicitly): make one.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    task = loop.create_task(
        start_server(pvdb, interfaces=interfaces, log_pv_names=log_pv_names,
                     startup_hook=startup_hook))
    try:
        loop.run_until_complete(task)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server; fall through to the
        # cleanup below.
        pass
    finally:
        # Cancel the server task and let it run its shutdown path.
        task.cancel()
        loop.run_until_complete(task)