Source code for caproto.trio.server

import functools
import threading

import trio
from trio import socket

import caproto as ca

from .._utils import safe_getsockname
from ..server import AsyncLibraryLayer
from ..server.common import Context as _Context
from ..server.common import DisconnectedCircuit, LoopExit
from ..server.common import VirtualCircuit as _VirtualCircuit
from .util import open_memory_channel


class ServerExit(Exception):
    """Exception class that the trio ``Context`` publishes as ``ServerExit``."""
class Event:
    """
    Resettable event built on top of `trio.Event`.

    The surrounding code relies on ``Event.clear()``, which is deprecated in
    `trio`.  This wrapper emulates it: clearing a *set* event swaps in a
    brand-new `trio.Event`, while clearing an event that is not set keeps the
    existing one untouched.

    The class is not intended to be thread safe.
    """

    def __init__(self):
        self._create_event()

    def _create_event(self):
        """Bind ``self._event`` to a fresh, unset `trio.Event`."""
        self._event = trio.Event()

    def set(self):
        """Set the wrapped event."""
        self._event.set()

    def clear(self):
        """
        Reset the event to the unset state.

        Implemented by replacing the wrapped event, and only when it is
        currently set.  It is assumed that a set event has no waiters left
        on it, so dropping the old object is harmless.
        """
        if not self._event.is_set():
            # Nothing to reset; keep the object that tasks may be waiting on.
            return
        self._create_event()

    def is_set(self):
        """Return ``True`` if the wrapped event is currently set."""
        return self._event.is_set()

    async def wait(self, timeout=None):
        """
        Wait until the event is set.

        Parameters
        ----------
        timeout : float, optional
            Maximum number of seconds to wait.  ``None`` waits indefinitely.

        Returns
        -------
        bool
            ``True`` if the event was set, ``False`` if ``timeout`` elapsed
            first.
        """
        if timeout is None:
            await self._event.wait()
            return True

        with trio.move_on_after(timeout):
            await self._event.wait()
            return True
        # Only reached when the cancel scope above timed out.
        return False
def _universal_queue(max_len=1000):
    """
    Build a queue class offering both blocking and async accessors.

    Parameters
    ----------
    max_len : int, optional
        Capacity handed to ``open_memory_channel`` for every queue instance
        created from the returned class.

    Returns
    -------
    type
        The ``UniversalQueue`` class (not an instance).
    """

    class UniversalQueue:
        """Queue backed by a memory channel created with ``max_len``."""

        def __init__(self):
            sender, receiver = open_memory_channel(max_len)
            self._send = sender
            self._recv = receiver

        # The synchronous accessors go through ``trio.from_thread.run``; see
        # the trio documentation for which threads may call it.

        def put(self, value):
            """Blocking put, executed on the trio loop from another thread."""
            trio.from_thread.run(self._send.send, value)

        async def async_put(self, value):
            """Put ``value`` from within a trio task."""
            await self._send.send(value)

        def get(self):
            """Blocking get, executed on the trio loop from another thread."""
            return trio.from_thread.run(self._recv.receive)

        async def async_get(self):
            """Get the next value from within a trio task."""
            item = await self._recv.receive()
            return item

    return UniversalQueue
class TrioAsyncLayer(AsyncLibraryLayer):
    """
    ``AsyncLibraryLayer`` implementation backed by trio.

    Exposes the trio module itself (``library``), the resettable ``Event``
    wrapper, ``trio.sleep`` and a per-instance ``ThreadsafeQueue`` class.
    """

    name = 'trio'
    library = trio
    Event = Event
    sleep = staticmethod(trio.sleep)
    # Class-level placeholder; every instance installs its own queue class
    # in ``__init__``.
    ThreadsafeQueue = None

    def __init__(self):
        self.ThreadsafeQueue = _universal_queue()
class TrioSocketStreamCompat:
    """
    Socket-like adapter around a `trio.SocketStream`.

    Provides the ``recv`` / ``send_all`` / ``close`` / ``getsockname``
    methods on top of a trio socket and the stream wrapping it.

    Parameters
    ----------
    sock
        The trio socket to wrap.
    token : trio.lowlevel.TrioToken
        Token of the trio run that owns the socket; used to re-enter that
        run when closing from a helper thread.
    """

    _sock: socket.socket  # trio._socket._SocketType
    _stream: trio.SocketStream
    _token: trio.lowlevel.TrioToken
    _send_lock: trio.StrictFIFOLock

    def __init__(self, sock: socket.socket, token: trio.lowlevel.TrioToken):
        self._sock = sock
        self._stream = trio.SocketStream(sock)
        self._token = token
        # Serializes senders so concurrent ``send_all`` calls do not overlap
        # on the underlying stream.
        self._send_lock = trio.StrictFIFOLock()

    async def recv(self, max_bytes=None):
        """Receive up to ``max_bytes`` bytes from the stream."""
        return await self._stream.receive_some(max_bytes)

    def close(self, *args, **kwargs):
        """
        Schedule closing of the stream and return immediately.

        ``aclose`` is a coroutine, so a daemon thread is started that runs it
        on the owning trio loop via the stored token.  Any positional or
        keyword arguments are accepted and ignored.
        """
        def _close():
            trio.from_thread.run(self._stream.aclose, trio_token=self._token)

        threading.Thread(target=_close, daemon=True).start()

    def getsockname(self, *args, **kwargs):
        """Return the wrapped socket's own address; arguments are ignored."""
        return self._sock.getsockname()

    async def send_all(self, data):
        """
        Send all of ``data`` over the stream.

        Raises
        ------
        DisconnectedCircuit
            If the stream reports `trio.BrokenResourceError`.  The original
            error is attached as ``__cause__``.
        """
        try:
            async with self._send_lock:
                return await self._stream.send_all(data)
        except trio.BrokenResourceError as exc:
            # Chain explicitly so the traceback reads as a translation of the
            # transport error rather than a failure inside the handler.
            raise DisconnectedCircuit(
                "Disconnected while sending to client"
            ) from exc
class VirtualCircuit(_VirtualCircuit):
    """Wraps a caproto.VirtualCircuit with a trio client."""

    TaskCancelled = trio.Cancelled

    def __init__(self, circuit, client: TrioSocketStreamCompat, context):
        super().__init__(circuit, client, context)
        self.nursery = context.nursery
        self.QueueFull = trio.WouldBlock

        self.command_chan = open_memory_channel(ca.MAX_COMMAND_BACKLOG)
        self.subscription_chan = open_memory_channel(
            ca.MAX_TOTAL_SUBSCRIPTION_BACKLOG
        )
        # For compatibility with server common: the "queues" are the send
        # ends of the channels created above.
        self.command_queue = self.command_chan.send
        self.subscription_queue = self.subscription_chan.send

        self.new_command_condition = trio.Condition()
        self.write_event = Event()
        self.events_on = trio.Event()

    async def run(self):
        """Start the command loop, then the subscription loop, in the nursery."""
        for loop in (self.command_queue_loop, self.subscription_queue_loop):
            await self.nursery.start(loop)

    async def _send_buffers(self, *buffers):
        """Send ``buffers`` over the wire."""
        payload = b"".join(buffers)
        await self.client.send_all(payload)

    async def command_queue_loop(self, task_status):
        """
        Process commands arriving on the command channel until it ends.

        ``task_status.started()`` is signalled once both channel ends are
        open.  A ``DisconnectedCircuit`` triggers disconnect handling;
        ``trio.Cancelled`` and ``LoopExit`` end the loop silently.
        """
        self.write_event.set()
        sender, receiver = self.command_chan
        async with sender, receiver:
            task_status.started()
            try:
                async for command in receiver:
                    reply = await self._command_queue_iteration(command)
                    if reply is not None:
                        await self.send(*reply)
                    await self._wake_new_command()
            except DisconnectedCircuit:
                await self._on_disconnect()
                self.circuit.disconnect()
                await self.context.circuit_disconnected(self)
            except (trio.Cancelled, LoopExit):
                pass

    async def subscription_queue_loop(self, task_status):
        """Signal startup, then run the inherited subscription loop."""
        task_status.started()
        await super().subscription_queue_loop()

    async def get_from_sub_queue(self, timeout=None):
        """
        Receive the next item from the subscription channel.

        Returns ``None`` when ``timeout`` (seconds) elapses before an item
        arrives; with ``timeout=None`` it waits indefinitely.
        """
        # Timeouts work very differently between our server implementations,
        # so we do this little stub in its own method.
        _sender, receiver = self.subscription_chan
        if timeout is not None:
            with trio.move_on_after(timeout):
                return await receiver.receive()
            return None
        return await receiver.receive()

    async def _on_disconnect(self):
        """Executed when disconnection detected"""
        await super()._on_disconnect()
        self.client.close()

    async def _start_write_task(self, handle_write):
        """Schedule ``handle_write`` as a background task in the nursery."""
        self.nursery.start_soon(handle_write)

    async def _wake_new_command(self):
        """Notify every task waiting on ``new_command_condition``."""
        condition = self.new_command_condition
        async with condition:
            condition.notify_all()
class Context(_Context):
    """
    Trio implementation of the server context.

    Owns the nursery in which all server loops run, the memory channels used
    as command and subscription queues, and the UDP/TCP sockets opened while
    the server is running.
    """

    CircuitClass = VirtualCircuit
    async_layer = TrioAsyncLayer
    ServerExit = ServerExit
    TaskCancelled = trio.Cancelled

    def __init__(self, pvdb, interfaces=None):
        super().__init__(pvdb, interfaces)
        # Set by ``run`` for the lifetime of the server; ``None`` before that.
        self.nursery = None
        self.command_chan = open_memory_channel(ca.MAX_COMMAND_BACKLOG)
        self.command_bundle_queue = self.command_chan.send

        self.subscription_chan = open_memory_channel(
            ca.MAX_TOTAL_SUBSCRIPTION_BACKLOG)
        self.subscription_queue = self.subscription_chan.send
        self.beacon_sock = ca.bcast_socket(socket)

    async def broadcaster_udp_server_loop(self, task_status):
        """
        Bind one UDP socket per interface and start a broadcaster loop on each.

        ``task_status.started()`` is signalled after every loop has been
        scheduled.  A bind failure is logged and re-raised.
        """
        for interface in self.interfaces:
            udp_sock = ca.bcast_socket(socket)
            try:
                await udp_sock.bind((interface, self.ca_server_port))
            except Exception:
                self.log.exception('UDP bind failure on interface %r',
                                   interface)
                # The socket is not yet registered in ``self.udp_socks``, so
                # the cleanup in ``run`` would never close it.
                udp_sock.close()
                raise
            self.broadcaster.server_addresses.append(
                safe_getsockname(udp_sock)
            )
            self.log.debug('UDP socket bound on %s:%d', interface,
                           self.ca_server_port)
            self.udp_socks[interface] = udp_sock

        for interface, udp_sock in self.udp_socks.items():
            self.log.debug('Broadcasting on %s:%d', interface,
                           self.ca_server_port)
            self.nursery.start_soon(self._core_broadcaster_loop, udp_sock)

        task_status.started()

    async def broadcaster_queue_loop(self, task_status):
        """
        Forward ``(addr, commands)`` items from the command channel to the
        inherited ``_broadcaster_queue_iteration``.

        ``trio.Cancelled`` and ``LoopExit`` end the loop silently.
        """
        send, recv = self.command_chan
        async with send:
            async with recv:
                try:
                    task_status.started()
                    async for addr, commands in recv:
                        await super()._broadcaster_queue_iteration(
                            addr, commands)
                except trio.Cancelled:
                    ...
                except LoopExit:
                    ...

    async def subscription_queue_loop(self, task_status):
        """
        Forward items from the subscription channel to the inherited
        ``_subscription_queue_iteration``.
        """
        send, recv = self.subscription_chan
        async with send:
            async with recv:
                task_status.started()
                async for item in recv:
                    sub_specs, metadata, values, flags, sub = item
                    await self._subscription_queue_iteration(
                        sub_specs, metadata, values, flags, sub)

    async def broadcast_beacon_loop(self, task_status):
        """Signal startup, then run the inherited beacon loop."""
        task_status.started()
        await super().broadcast_beacon_loop()

    async def server_accept_loop(self, listen_sock, *, task_status):
        """
        Accept TCP connections on ``listen_sock`` forever.

        Each accepted socket is wrapped in a ``TrioSocketStreamCompat`` and
        handed to ``tcp_handler`` in the nursery.  ``listen_sock`` is closed
        when the loop exits for any reason.
        """
        try:
            listen_sock.listen()
            task_status.started()
            while True:
                trio_sock, addr = await listen_sock.accept()
                stream = TrioSocketStreamCompat(
                    sock=trio_sock,
                    token=trio.lowlevel.current_trio_token()
                )
                self.nursery.start_soon(
                    self.tcp_handler, stream, addr
                )
        finally:
            listen_sock.close()

    async def run(self, *, log_pv_names=False, startup_hook=None):
        """
        Start the server and run it until its nursery finishes.

        Parameters
        ----------
        log_pv_names : bool, optional
            Log the names in ``self.pvdb`` once startup has completed.
        startup_hook : coroutine function, optional
            Called with a ``TrioAsyncLayer`` instance after the accept loops
            have been started.

        On exit (normal, error or cancellation) the shutdown methods are run
        and all TCP, UDP and beacon sockets registered on the context are
        closed.
        """
        self.log.info('Trio server starting up...')
        try:
            async with trio.open_nursery() as self.nursery:
                for address in ca.get_beacon_address_list():
                    sock = ca.bcast_socket(socket)
                    try:
                        await sock.connect(address)
                        interface, _ = safe_getsockname(sock)
                    except BaseException:
                        # Not registered in ``self.beacon_socks`` yet, so
                        # close it here instead of leaking it.
                        sock.close()
                        raise
                    self.beacon_socks[address] = (interface, sock)

                async def make_socket(interface, port):
                    s = trio.socket.socket()
                    try:
                        await s.bind((interface, port))
                    except BaseException:
                        # The caller never sees a socket whose bind failed,
                        # so it has to be released here.
                        s.close()
                        raise
                    return s

                res = await self._bind_tcp_sockets_with_consistent_port_number(
                    make_socket)
                self.port, self.tcp_sockets = res

                await self.nursery.start(self.broadcaster_udp_server_loop)
                await self.nursery.start(self.broadcaster_queue_loop)
                await self.nursery.start(self.subscription_queue_loop)
                await self.nursery.start(self.broadcast_beacon_loop)

                # Only after all loops have been started, begin listening:
                for interface, listen_sock in self.tcp_sockets.items():
                    self.log.info("Listening on %s:%d", interface, self.port)
                    await self.nursery.start(self.server_accept_loop,
                                             listen_sock)

                async_lib = TrioAsyncLayer()

                if startup_hook is not None:
                    self.log.debug('Calling startup hook %r',
                                   startup_hook.__name__)

                    async def _startup(task_status):
                        task_status.started()
                        await startup_hook(async_lib)

                    await self.nursery.start(_startup)

                for name, method in self.startup_methods.items():
                    self.log.debug('Calling startup method %r', name)

                    # ``method`` is captured as a default so the task keeps
                    # its own method even after the loop variable is rebound.
                    async def _startup(task_status, method=method):
                        task_status.started()
                        await method(async_lib)

                    await self.nursery.start(_startup)
                self.log.info('Server startup complete.')
                if log_pv_names:
                    self.log.info('PVs available:\n%s', '\n'.join(self.pvdb))
        except trio.Cancelled:
            # NOTE(review): trio recommends letting Cancelled propagate; the
            # existing swallow-and-log behavior is kept for compatibility.
            self.log.info('Server task cancelled. Will shut down.')
        finally:
            self.log.info('Server exiting....')
            async_lib = TrioAsyncLayer()
            async with trio.open_nursery() as nursery:
                for name, method in self.shutdown_methods.items():
                    self.log.debug('Calling shutdown method %r', name)

                    # Same default-argument capture as for startup methods.
                    async def shutdown(task_status, method=method):
                        task_status.started()
                        await method(async_lib)

                    await nursery.start(shutdown)
            for sock in self.tcp_sockets.values():
                sock.close()
            for sock in self.udp_socks.values():
                sock.close()
            for _interface, sock in self.beacon_socks.values():
                sock.close()

    def stop(self):
        """Stop the server by cancelling its nursery; no-op if not running."""
        nursery = self.nursery
        if nursery is None:
            return
        nursery.cancel_scope.cancel()
async def start_server(pvdb, *, interfaces=None, log_pv_names=False,
                       startup_hook=None):
    """
    Start a trio server with a given PV database.

    Creates a ``Context`` for ``pvdb`` / ``interfaces`` and awaits its
    ``run`` method, returning whatever that returns.
    """
    context = Context(pvdb, interfaces=interfaces)
    outcome = await context.run(
        log_pv_names=log_pv_names,
        startup_hook=startup_hook,
    )
    return outcome
def run(pvdb, *, interfaces=None, log_pv_names=False, startup_hook=None):
    """
    Run an IOC, given its PV database dictionary.

    Synchronous entry point: runs ``start_server`` under ``trio.run`` and
    returns ``None`` if interrupted by ``KeyboardInterrupt``.

    Parameters
    ----------
    pvdb : dict
        The PV database.
    interfaces : list, optional
        List of interfaces to listen on.
    log_pv_names : bool, optional
        Log PV names at startup.
    startup_hook : coroutine, optional
        Hook to call at startup with the ``async_lib`` shim.
    """
    try:
        entry_point = functools.partial(
            start_server,
            pvdb,
            interfaces=interfaces,
            log_pv_names=log_pv_names,
            startup_hook=startup_hook,
        )
        return trio.run(entry_point)
    except KeyboardInterrupt:
        return None