Source code for caproto.pva.asyncio.server

import asyncio
import functools
import socket
import sys

import caproto as ca

from ...asyncio.server import AsyncioAsyncLayer, ServerExit
from ...asyncio.utils import _DatagramProtocol, _TaskHandler, _TransportWrapper
from ..server.common import Context as _Context
from ..server.common import VirtualCircuit as _VirtualCircuit


class VirtualCircuit(_VirtualCircuit):
    """
    Wraps a caproto.pva.VirtualCircuit with an asyncio server.

    Parameters
    ----------
    circuit
        The protocol-level circuit; handed to the superclass and used in
        `send` to serialize outgoing messages.
    client
        The connected client socket.  It is used with the event loop's
        ``sock_recv`` / ``sock_sendall`` methods and is closed on disconnect.
    context
        The owning server context; handed to the superclass.
    loop : asyncio event loop, optional
        Loop used for the raw socket operations.  Defaults to
        ``asyncio.get_event_loop()``.
    """

    TaskCancelled = asyncio.CancelledError

    def __init__(self, circuit, client, context, *, loop=None):
        if loop is None:
            loop = asyncio.get_event_loop()
        self.loop = loop

        class SockWrapper:
            """Give the raw socket an awaitable ``recv`` for the superclass."""

            def __init__(self, loop, client):
                self.loop = loop
                self.client = client

            def getsockname(self):
                return self.client.getsockname()

            async def recv(self, nbytes):
                return (await self.loop.sock_recv(self.client, nbytes))

        self._raw_lock = asyncio.Lock()
        self._raw_client = client
        super().__init__(circuit, SockWrapper(loop, client), context)
        self.QueueFull = asyncio.QueueFull
        # NOTE: the ``loop=`` keyword of asyncio.Queue / asyncio.Event was
        # deprecated in Python 3.8 and removed in 3.10 (passing it raises
        # TypeError), so these primitives are created without it.
        self.message_queue = asyncio.Queue(ca.MAX_COMMAND_BACKLOG)
        self.events_on = asyncio.Event()
        self.subscription_queue = asyncio.Queue(
            ca.MAX_TOTAL_SUBSCRIPTION_BACKLOG)
        self.tasks = _TaskHandler()
        self._sub_task = None

    async def get_from_sub_queue(self, timeout=None):
        """
        Get one item from the subscription queue.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for an item; ``None`` waits indefinitely.

        Returns
        -------
        The next queued item, or ``None`` if the timeout elapsed first.

        Notes
        -----
        The superclass expects us to implement this in our own way due to
        timeouts.
        """
        future = asyncio.ensure_future(self.subscription_queue.get())
        try:
            # ``loop=`` is intentionally not passed: asyncio.wait_for no
            # longer accepts it on Python 3.10+.
            return await asyncio.wait_for(future, timeout)
        except asyncio.TimeoutError:
            return None

    async def send(self, *messages):
        """Serialize ``messages`` and write them to the client socket.

        Nothing is sent when ``self.connected`` is false.
        """
        if self.connected:
            buffers_to_send = self.circuit.send(*messages, extra=self._tags)
            # lock to make sure a AddEvent does not write bytes
            # to the socket while we are sending
            async with self._raw_lock:
                await self.loop.sock_sendall(self._raw_client,
                                             b''.join(buffers_to_send))

    async def run(self):
        """Start the message-queue and subscription-queue loops as tasks."""
        self.tasks.create(self.message_queue_loop())
        # Keep a handle on the subscription task so that `_on_disconnect`
        # can cancel it.
        self._sub_task = self.tasks.create(self.subscription_queue_loop())

    async def _start_write_task(self, handle_write):
        """ Start a write handler, and return the task. """
        return self.tasks.create(handle_write())

    async def _on_disconnect(self):
        """Run the superclass disconnect hook, close the socket and cancel
        the subscription task (if one was started)."""
        await super()._on_disconnect()
        self._raw_client.close()
        if self._sub_task is not None:
            await self.tasks.cancel(self._sub_task)
            self._sub_task = None
class Context(_Context):
    """
    Asyncio flavor of the PVAccess server context.

    Parameters
    ----------
    pvdb
        The PV database to serve; handed to the superclass.
    interfaces : optional
        Network interfaces to listen on; handed to the superclass.
    loop : asyncio event loop, optional
        Loop used for socket and endpoint operations.  Defaults to
        ``asyncio.get_event_loop()``.
    """

    CircuitClass = VirtualCircuit
    async_layer = None
    ServerExit = ServerExit
    TaskCancelled = asyncio.CancelledError

    def __init__(self, pvdb, interfaces=None, *, loop=None):
        super().__init__(pvdb, interfaces)
        self.message_bundle_queue = asyncio.Queue()
        self.subscription_queue = asyncio.Queue()
        if loop is None:
            loop = asyncio.get_event_loop()
        self.loop = loop
        self.async_layer = AsyncioAsyncLayer()
        # Tasks spawned per accepted TCP client (see `server_accept_loop`).
        self.server_tasks = _TaskHandler()
        self.tcp_sockets = dict()

    async def server_accept_loop(self, sock):
        """Listen on ``sock`` forever, starting a `tcp_handler` task for
        every accepted client connection."""
        sock.listen()

        while True:
            client_sock, addr = await self.loop.sock_accept(sock)
            self.server_tasks.create(self.tcp_handler(client_sock, addr))

    @property
    def guid(self) -> str:
        """The server GUID, as an upper-case hexadecimal string."""
        raw_guid = self.broadcaster.guid
        # Zero-pad every character to (at least) two hex digits.  Plain
        # ``hex(ord(c))[2:]`` would drop the leading zero for code points
        # below 0x10 and yield a shortened, ambiguous GUID string.
        return ''.join(format(ord(c), '02X') for c in raw_guid)

    async def run(self, *, log_pv_names=False):
        """
        Start the server.

        Binds the TCP and UDP sockets, starts the background loops and the
        user startup methods, then waits on all of them.  On exit (normal,
        cancelled or failed) the shutdown methods are run and all sockets
        are closed.

        Parameters
        ----------
        log_pv_names : bool, optional
            Log all PV names to `self.log` after starting up.
        """
        self.log.info('Asyncio server starting up...')
        self.log.info('Server GUID is: 0x%s', self.guid)

        async def make_socket(interface, port):
            # Non-blocking TCP socket bound to (interface, port).
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.setblocking(False)
            s.bind((interface, port))
            return s

        self.port, self.tcp_sockets = (
            await self._bind_tcp_sockets_with_consistent_port_number(
                make_socket)
        )
        self.broadcaster.server_port = self.port

        tasks = _TaskHandler()
        for interface, sock in self.tcp_sockets.items():
            self.log.info("Listening on %s:%d", interface, self.port)
            self.broadcaster.server_addresses.append((interface, self.port))
            tasks.create(self.server_accept_loop(sock))

        class ConnectedTransportWrapper:
            """Make an asyncio transport something you can call send on."""

            def __init__(self, transport, address):
                self.transport = transport
                self.address = address

            async def send(self, bytes_to_send):
                try:
                    self.transport.sendto(bytes_to_send, self.address)
                except OSError as exc:
                    host, port = self.address
                    raise ca.CaprotoNetworkError(
                        f"Failed to send to {host}:{port}") from exc

            def close(self):
                return self.transport.close()

        reuse_port = (sys.platform not in ('win32', ) and
                      hasattr(socket, 'SO_REUSEPORT'))
        for address in ca.get_beacon_address_list(
                protocol=ca.Protocol.PVAccess):
            transport, _ = await self.loop.create_datagram_endpoint(
                functools.partial(_DatagramProtocol, parent=self,
                                  recv_func=self._datagram_received),
                remote_addr=address, allow_broadcast=True,
                reuse_port=reuse_port)
            wrapped_transport = ConnectedTransportWrapper(transport, address)
            # NOTE(review): `interface` is whatever value the TCP listening
            # loop above left behind (and is unbound if `tcp_sockets` is
            # empty); the right value is not known from this file.
            self.beacon_socks[address] = (interface,   # TODO; this is incorrect
                                          wrapped_transport)

        for interface in self.interfaces:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
                                 socket.IPPROTO_UDP)
            # Python says this is unsafe, but we need it to have
            # multiple servers live on the same host.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if reuse_port:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
            sock.setblocking(False)
            sock.bind((interface, self.pva_broadcast_port))

            transport, _ = await self.loop.create_datagram_endpoint(
                functools.partial(_DatagramProtocol, parent=self,
                                  recv_func=self._datagram_received),
                sock=sock)
            self.udp_socks[interface] = _TransportWrapper(transport)
            self.log.debug('UDP socket bound on %s:%d', interface,
                           self.pva_broadcast_port)

        tasks.create(self.broadcaster_queue_loop())
        tasks.create(self.subscription_queue_loop())
        tasks.create(self.broadcast_beacon_loop())

        async_lib = AsyncioAsyncLayer()
        for name, method in self.startup_methods.items():
            self.log.debug('Calling startup method %r', name)
            tasks.create(method(async_lib))
        self.log.info('Server startup complete.')
        if log_pv_names:
            self.log.info('PVs available:\n%s', '\n'.join(self.pvdb))

        try:
            await asyncio.gather(*tasks.tasks)
        except asyncio.CancelledError:
            # Cancellation is treated as a request to shut down: cancel
            # every task we know about and return instead of re-raising.
            self.log.info('Server task cancelled. Will shut down.')
            await tasks.cancel_all()
            await self.server_tasks.cancel_all()
            for circuit in self.circuits:
                await circuit.tasks.cancel_all()
            return
        except Exception:
            self.log.exception('Server error. Will shut down')
            raise
        finally:
            self.log.info('Server exiting....')
            shutdown_tasks = []
            for name, method in self.shutdown_methods.items():
                self.log.debug('Calling shutdown method %r', name)
                task = self.loop.create_task(method(async_lib))
                shutdown_tasks.append(task)
            await asyncio.gather(*shutdown_tasks)
            for sock in self.tcp_sockets.values():
                sock.close()
            for sock in self.udp_socks.values():
                sock.close()
            for _interface, sock in self.beacon_socks.values():
                sock.close()

    def _datagram_received(self, pair):
        """Parse one received UDP datagram and queue the resulting messages.

        ``pair`` is a ``(bytes_received, address)`` tuple.  Packets that the
        broadcaster rejects with `ca.RemoteProtocolError` are logged and
        dropped.
        """
        bytes_received, address = pair
        try:
            messages = self.broadcaster.recv(bytes_received, address)
        except ca.RemoteProtocolError:
            self.log.exception('Broadcaster received bad packet')
        else:
            self.message_bundle_queue.put_nowait((address, messages))
async def start_server(pvdb, *, interfaces=None, log_pv_names=False):
    """
    Start an asyncio server with a given PV database.

    Builds a `Context` for ``pvdb`` / ``interfaces``, awaits its ``run``
    coroutine, and returns whatever ``run`` returns.
    """
    context = Context(pvdb, interfaces)
    return await context.run(log_pv_names=log_pv_names)
def run(pvdb, *, interfaces=None, log_pv_names=False):
    """
    A synchronous function that wraps start_server and exits cleanly.

    Parameters
    ----------
    pvdb
        The PV database, passed through to `start_server`.
    interfaces : optional
        Passed through to `start_server`.
    log_pv_names : bool, optional
        Passed through to `start_server`.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # Python 3.14+ no longer creates an event loop implicitly and raises
        # RuntimeError when none is set; create and install one ourselves.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    task = loop.create_task(
        start_server(pvdb, interfaces=interfaces, log_pv_names=log_pv_names))
    try:
        loop.run_until_complete(task)
    except KeyboardInterrupt:
        # Ctrl-C: fall through to the cancellation below.
        ...
    finally:
        # Cancelling an already-finished task is a no-op; otherwise this
        # requests shutdown, and we run the loop until the task completes.
        task.cancel()
        loop.run_until_complete(task)