import collections
import dataclasses
import enum
import logging
import typing
from collections import defaultdict, deque
from typing import Dict, Optional, Tuple, Type
import caproto as ca
import caproto.pva as pva
from caproto import (CaprotoNetworkError, CaprotoRuntimeError,
RemoteProtocolError, get_environment_variables)
from .._data import DataWithBitSet
from .._dataclass import get_pv_structure, pva_dataclass
from .._fields import BitSet
from .._functools_compat import singledispatchmethod
from .._messages import Message, MonitorSubcommand, Subcommand
class DisconnectedCircuit(Exception):
    """Signals that the client end of a virtual circuit has gone away."""
class LoopExit(Exception):
    """Signals that a circuit's message queue loop should stop."""
@pva_dataclass
class ServerStatus:
    """Server status structure included in each beacon the server sends."""
    running: bool = True
    caproto_version: str = str(ca.__version__)
class AuthOperation(enum.Enum):
    """
    Operations which allow for granular authorization on a per-PV basis.

    The server passes one of these as ``operation`` to the database entry's
    ``authorize`` method:

    * ``read`` -- channel get and monitor (subscription) requests
    * ``read_interface`` -- field/interface information requests
    * ``write`` -- channel put requests
    * ``call`` -- RPC requests
    """
    read = enum.auto()
    read_interface = enum.auto()
    write = enum.auto()
    call = enum.auto()
@dataclasses.dataclass(frozen=True)
class SubscriptionSpec:
    '''
    Subscription specification used to key all subscription updates.

    Instances are frozen (and therefore usable as dictionary keys, provided
    the field values are hashable).

    Attributes
    ----------
    db_entry : DataWrapperInterface
        The database entry.

    bitset : BitSet
        The bitset to monitor.

    options : tuple
        Options for the monitor (tuple(options_dict.items()))
    '''
    db_entry: object
    bitset: BitSet
    options: tuple
@dataclasses.dataclass(frozen=True)
class Subscription:
    '''
    An individual subscription from a client.

    Attributes
    ----------
    spec : SubscriptionSpec
        The subscription specification information.

    circuit : VirtualCircuit
        The associated virtual circuit.

    channel : ServerChannel
        The associated channel.

    ioid : int
        The I/O identifier / request ID.
    '''
    spec: SubscriptionSpec
    circuit: 'VirtualCircuit'
    channel: pva.ServerChannel
    ioid: int
class VirtualCircuit:
    """
    The base VirtualCircuit class.

    Wraps a ``caproto.pva.ServerVirtualCircuit`` (protocol state) together
    with the client socket, dispatches received messages to the handlers
    registered on ``_process_message``, and sends monitor updates back to
    the client.

    Servers are expected to subclass from this, including additional
    attributes, noted in the Attributes section.

    Attributes
    ----------
    QueueFull : class
        Must be implemented in the subclass: the exception type caught when a
        ``put`` onto one of the queues fails because it is full.
        (TODO details)

    TaskCancelled : class
        Referenced by this class (queue loops, write tasks); expected to be
        provided by the subclass, presumably as the async library's
        cancellation exception.

    message_queue : QueueInterface
        Must be implemented in the subclass.

    subscription_queue : QueueInterface
        Must be implemented in the subclass.

    get_from_sub_queue : method
        Must be implemented in the subclass.

    _raw_lock : async lock
        Used by :meth:`send` to serialize socket writes; expected to be
        provided by the subclass.

    _start_write_task : method
        Must be implemented in the subclass.
    """
    context: 'Context'
    connected: bool
    circuit: pva.ServerVirtualCircuit
    log: logging.Logger
    client: object  # socket or similar (TODO socket interface)
    most_recent_updates: Dict
    _tags: Dict[str, str]
    subscriptions: typing.DefaultDict[SubscriptionSpec,
                                      typing.Deque[Subscription]]

    def __init__(self,
                 circuit: pva.ServerVirtualCircuit,
                 client,
                 context: 'Context'
                 ):
        self.connected = True
        self.circuit = circuit  # a caproto.pva.ServerVirtualCircuit
        self.circuit.our_address = client.getsockname()
        self.log = circuit.log
        self.client = client
        self.context = context
        # Subscriptions owned by this circuit, keyed on their spec.
        self.subscriptions = defaultdict(deque)
        self.most_recent_updates = {}
        # This dict is passed to the loggers.
        self._tags = {
            'their_address': self.circuit.address,
            'our_address': self.circuit.our_address,
            'direction': '<<<---',
            'role': repr(self.circuit.our_role),
        }
        # Filled in when the client's ConnectionValidationResponse arrives.
        self.authorization_info = {}

    async def _start_write_task(self, handle_write):
        """
        Start a write handler, and return the task.

        Must be implemented by the subclass, and must return a cancellable task
        instance (API like asyncio, currently).
        """
        raise NotImplementedError()

    async def _on_disconnect(self):
        """
        Executed when disconnection detected.

        Only the first call has an effect (guarded by ``connected``): every
        subscription owned by this circuit is removed from the context-wide
        table and unsubscribed from its database entry.
        """
        if not self.connected:
            return

        self.connected = False
        queue = self.context.subscription_queue
        for sub_spec, subs in self.subscriptions.items():
            for sub in subs:
                self.context.subscriptions[sub_spec].remove(sub)
                await sub_spec.db_entry.unsubscribe(queue, sub)
            # Drop context-wide entries nobody subscribes to anymore.
            if not self.context.subscriptions[sub_spec]:
                self.context.subscriptions.pop(sub_spec)
        self.subscriptions.clear()

    async def send(self, *messages):
        """
        Process messages and transport them over the TCP socket for this
        circuit.

        Does nothing once the circuit is no longer connected.
        """
        if self.connected:
            buffers_to_send = self.circuit.send(*messages, extra=self._tags)
            # send bytes over the wire using some caproto utilities
            async with self._raw_lock:
                await ca.async_send_all(buffers_to_send, self.client.sendmsg)

    async def recv(self):
        """
        Receive bytes over TCP and append them to this circuit's buffer.

        Every message parsed from the buffer is put on ``message_queue``.

        Raises
        ------
        DisconnectedCircuit
            If the connection was closed or reset, or if ``message_queue``
            is full (the circuit is disconnected in that case).
        """
        try:
            bytes_received = await self.client.recv(4096)
        except (ConnectionResetError, ConnectionAbortedError):
            bytes_received = []

        for message, _ in self.circuit.recv(bytes_received):
            try:
                await self.message_queue.put(message)
            except self.QueueFull:
                # This client is fast and we are not keeping up. Better to kill
                # the circuit (and let the client try again) than to let the
                # whole server be OOM-ed.
                self.log.warning(
                    "Circuit %r has a large backlog of received messages, "
                    "evidently cannot keep up with a fast client. Disconnecting"
                    " circuit to avoid letting consume all available memory.",
                    self
                )
                await self._on_disconnect()
                raise DisconnectedCircuit()

        if not bytes_received:
            # Empty read: the peer closed (or reset) the connection.
            await self._on_disconnect()
            raise DisconnectedCircuit()

    def _get_ids_from_message(self, message: Message) -> Tuple[int, int]:
        """
        Return ``(client_chid, server_chid)`` for a message.

        When the message carries only one of the two identifiers, the other
        is looked up in the circuit's channel tables.  Identifiers that
        cannot be determined are returned as None.
        """
        server_chid = getattr(message, 'server_chid', None)
        client_chid = getattr(message, 'client_chid', None)
        try:
            if server_chid is not None and client_chid is None:
                client_chid = self.circuit.channels_sid[server_chid].client_chid
            elif client_chid is not None and server_chid is None:
                server_chid = self.circuit.channels[client_chid].server_chid
        except KeyError:
            # Unknown channel: this is used on error-reporting paths, so
            # return what the message itself carried rather than masking
            # the original error with a KeyError.
            pass
        return client_chid, server_chid

    async def _message_queue_iteration(self, message):
        """
        Coroutine which evaluates one item from the circuit message queue.

        1. Update protocol state via ``self.circuit.process_command``.
           - ``RemoteProtocolError`` ends the loop with ``LoopExit``.
           - Any other failure here is logged and the message is dropped.
        2. Dispatch to the matching ``_process_message`` handler and return
           its list of response messages (or None) for the caller to send.

        Raises
        ------
        DisconnectedCircuit
            If ``message`` is the ``pva.DISCONNECTED`` sentinel.
        LoopExit
            On protocol errors, or on a handler failure after the client has
            disconnected.
        """
        try:
            self.log.debug("%r", message, extra=self._tags)
            self.circuit.process_command(message)
        except RemoteProtocolError:
            client_chid, server_chid = self._get_ids_from_message(message)
            if client_chid is not None:
                try:
                    # TODO: disconnect only the offending channel here
                    # (send a channel-disconnect message to the client).
                    # Until that exists this placeholder re-raises, so the
                    # ``else`` branch below is currently unreachable and a
                    # "recoverable" error still ends the loop.
                    raise
                except Exception:
                    self.log.exception(
                        "Client broke the protocol in a recoverable way, but "
                        "channel disconnection of client_chid=%d server_chid=%d failed.", client_chid,
                        server_chid)
                    raise LoopExit('Recoverable protocol error failure')
                else:
                    self.log.exception(
                        "Client broke the protocol in a recoverable way. "
                        "Disconnected channel client_chid=%d server_chid=%d but keeping the "
                        "circuit alive.", client_chid, server_chid)
                    return
            else:
                self.log.exception(
                    "Client broke the protocol in an unrecoverable way.")
                # TODO: Kill the circuit.
                raise LoopExit('Unrecoverable protocol error')
        except Exception:
            self.log.exception('Circuit message queue evaluation failed')
            # Internal error - ignore for now
            return

        if message is pva.DISCONNECTED:
            raise DisconnectedCircuit()

        try:
            return await self._process_message(message)
        except Exception:
            if not self.connected:
                if not isinstance(message, pva.ChannelDestroyRequest):
                    self.log.exception(
                        'Server error after client disconnection: %s', message
                    )
                raise LoopExit('Server error after client disconnection')

            chan, _ = self._get_db_entry_from_message(message)
            self.log.exception(
                'Server failed to process message (%r): %s',
                chan.name, message
            )
            # TODO: reply to the client with an error status carrying the
            # Python exception, instead of re-raising.
            raise

    async def _newly_connected(self):
        """
        Just connected to the client: send the byte-order setting and a
        connection validation (authentication) request.
        """
        byte_order = self.circuit.set_byte_order(
            pva.EndianSetting.use_server_byte_order
        )
        req = self.circuit.validate_connection(
            buffer_size=32767,
            registry_size=32767,
            authorization_options=self.context.authentication_methods
        )
        await self.send(byte_order, req)

    async def subscription_queue_loop(self):
        """
        Subscription queue loop.

        This is the final spot where we ship updates off to the client.
        Runs until cancelled or until the circuit disconnects.
        """
        def cull_messages(messages):
            """
            Ensure at the last possible moment that we don't send responses for
            Subscriptions that have been canceled at some time after the
            response was queued.
            """
            all_subscription_ids = {
                sub.ioid
                for subs in self.subscriptions.values()
                for sub in subs
            }
            return (
                message for message in messages
                if message.ioid in all_subscription_ids
            )

        while True:
            ref = None
            try:
                ref = await self.subscription_queue.get()
                # ref = await self.get_from_sub_queue(timeout=ca.HIGH_LOAD_TIMEOUT)
                # message = ref()  # TODO: weakref
                await self.send(*cull_messages([ref]))
            except self.TaskCancelled:
                break
            except DisconnectedCircuit:
                await self._on_disconnect()
                self.circuit.disconnect()
                await self.context.circuit_disconnected(self)
                break
            except Exception:
                self.log.exception('Subscription update send failure %s',
                                   ref if ref is not None else '(no message)')

    async def message_queue_loop(self):
        """Reference implementation of the message queue loop

        Note
        ----
        Assumes self.message_queue functions as an async queue with
        awaitable .get()

        Async library implementations can (and should) reimplement this.

        Sends the initial connection handshake, then evaluates items from
        the circuit message queue one at a time, sending any responses.
        """
        await self._newly_connected()
        try:
            while True:
                message = await self.message_queue.get()
                response = await self._message_queue_iteration(message)
                if response is not None:
                    await self.send(*response)
        except DisconnectedCircuit:
            await self._on_disconnect()
            self.circuit.disconnect()
            await self.context.circuit_disconnected(self)
        except self.TaskCancelled:
            ...
        except LoopExit:
            ...

    def _get_db_entry_from_message(self, message):
        """Return ``(channel, db_entry)`` for a message, via its channel name."""
        chan = self.circuit._get_channel_from_message(message)
        db_entry = self.context[chan.name]
        return chan, db_entry

    @singledispatchmethod
    async def _process_message(self, message):
        # Fall-through for non-registered items
        if message is pva.DISCONNECTED:
            raise DisconnectedCircuit()
        self.log.error("Unhandled %r", message, extra=self._tags)
        return []

    @_process_message.register
    async def _(self, message: pva.ConnectionValidationResponse):
        # Remember how the client authenticated; passed to authorize() later.
        self.authorization_info.update(
            method=message.auth_nz,
            data=message.data.data,
        )
        return [self.circuit.validated_connection()]

    @_process_message.register
    async def _(self, message: pva.SearchRequest):
        # TODO message.channels -> searchreply
        return []

    @_process_message.register
    async def _(self, message: pva.CreateChannelRequest):
        to_send = []
        for info in message.channels:
            try:
                cid = info['id']
                name = info['channel_name']
                chan = self.circuit.channels[cid]
            except KeyError:
                # NOTE(review): if the KeyError comes from one of the lookups
                # above, ``name``/``chan`` are unbound (or stale from a
                # previous iteration) here -- confirm the intended handling.
                self.log.debug('Client requested invalid channel name: %s',
                               name)
                to_send.append(
                    chan.create(
                        sid=0,
                        status=pva.Status.create_error(
                            message=f'Invalid channel name {name}',
                        ),
                    )
                )
            else:
                to_send.append(
                    chan.create(sid=self.circuit.new_channel_id())
                )

        return to_send

    @_process_message.register
    async def _(self, message: pva.ChannelFieldInfoRequest):
        chan, db_entry = self._get_db_entry_from_message(message)
        # Raises if not authorized; the interface sent back comes from read().
        await db_entry.authorize(
            operation=AuthOperation.read_interface,
            authorization=self.authorization_info,
        )
        data = await db_entry.read(None)
        return [chan.read_interface(ioid=message.ioid, interface=data)]

    @_process_message.register
    async def _(self, message: pva.ChannelGetRequest):
        subcommand = Subcommand(message.subcommand)
        chan, db_entry = self._get_db_entry_from_message(message)
        ioid_info = self.circuit.ioids[message.ioid]
        response: pva.ChannelGetResponse

        if Subcommand.INIT in subcommand:
            try:
                await db_entry.authorize(
                    AuthOperation.read,
                    authorization=self.authorization_info,
                    request=message.pv_request,
                )
                data = await db_entry.read(
                    request=message.pv_request,
                )
            except Exception as ex:
                self.log.exception('Message response failure %s (%s)',
                                   message, subcommand)
                response = chan.read(
                    ioid=message.ioid, interface=None,
                    status=pva.Status.create_error(
                        message=f'{ex.__class__.__name__}: {ex}',
                    ),
                )
            else:
                ioid_info['pv_request'] = message.pv_request
                ioid_info['interface'] = data
                response = chan.read(ioid=message.ioid, interface=data)

            ioid_info['init_request'] = message
            # Reusable response message for this ioid:
            ioid_info['response'] = response
            return [response]

        if Subcommand.GET in subcommand or subcommand == Subcommand.DEFAULT:
            # NOTE(review): authorization only happens at INIT, and
            # 'init_request' is stored even when INIT failed, so nothing here
            # re-checks that INIT succeeded -- confirm the protocol layer
            # rejects a GET after a failed INIT.
            data = await db_entry.read(
                ioid_info['init_request'].pv_request
            )
            pv_data = DataWithBitSet(data=data,
                                     bitset=BitSet({0}),  # TODO
                                     )
            # TODO: check if interface has changed
            response = ioid_info['response']
            if subcommand == Subcommand.GET:
                response.to_get(pv_data=pv_data)
            else:
                response.to_default(pv_data=pv_data)
            return [response]

    @_process_message.register
    async def _(self, message: pva.ChannelPutRequest):
        subcommand = Subcommand(message.subcommand)
        chan, db_entry = self._get_db_entry_from_message(message)
        ioid_info = self.circuit.ioids[message.ioid]
        response: pva.ChannelPutResponse

        if Subcommand.INIT in subcommand:
            try:
                interface = await db_entry.authorize(
                    AuthOperation.write,
                    request=message.pv_request,
                    authorization=self.authorization_info,
                )
            except Exception as ex:
                self.log.exception('Message response failure %s (%s)',
                                   message, subcommand)
                interface = None
                status = pva.Status.create_error(
                    message=f'{ex.__class__.__name__}: {ex}',
                )
            else:
                status = pva.Status.create_success()

            response = chan.write(
                ioid=message.ioid,
                status=status,
                put_structure_if=interface,
            )
            ioid_info['pv_request'] = message.pv_request
            ioid_info['interface'] = interface
            ioid_info['response'] = response
            ioid_info['write_task'] = None
            return [response.to_init(put_structure_if=interface)]

        if Subcommand.GET in subcommand:
            # This is pretty much a pva-get, using the pvrequest from the
            # put_init
            response = ioid_info['response']
            try:
                pv_request = ioid_info['pv_request']
                read_data = await db_entry.read(pv_request)
                data = DataWithBitSet(
                    data=read_data,
                    bitset=BitSet({0}),  # TODO
                )
            except Exception as ex:
                self.log.exception('Message response failure %s (%s)',
                                   message, subcommand)
                response.status = pva.Status.create_error(
                    message=f'{ex.__class__.__name__}: {ex}',
                )
                data = None
            else:
                response.status = pva.Status.create_success()
            return [response.to_get(data=data)]

        if subcommand == Subcommand.DEFAULT or subcommand == Subcommand.DESTROY:
            # The write runs as a separate (cancellable) task, which sends its
            # own response; this handler itself returns None.
            # NOTE(review): the write is not re-authorized here and INIT
            # failure (interface None) is not checked -- confirm the protocol
            # layer rejects a put after a failed INIT.
            async def handle_write():
                try:
                    response = ioid_info['response']
                    await db_entry.write(message.put_data)
                except self.TaskCancelled:
                    self.log.debug(
                        'Write request by %s(%s) cancelled: %s => %r',
                        self.authorization_info['method'],
                        self.authorization_info['data'],
                        chan.name,
                        message.put_data.data,
                    )
                    response.status = pva.Status.create_error(
                        message='Cancelled',
                    )
                except Exception as ex:
                    self.log.exception(
                        'Write request by %s(%s) failed: %r',
                        self.authorization_info['method'],
                        self.authorization_info['data'],
                        message)
                    response.status = pva.Status.create_error(
                        message=f'{ex.__class__.__name__}: {ex}',
                    )
                else:
                    response.status = pva.Status.create_success()
                finally:
                    ioid_info['write_task'] = None

                await self.send(response.to_default())

            ioid_info['write_task'] = await self._start_write_task(handle_write)

    @_process_message.register
    async def _(self, message: pva.ChannelMonitorRequest):
        subcommand = MonitorSubcommand(message.subcommand)
        chan, db_entry = self._get_db_entry_from_message(message)
        ioid_info = self.circuit.ioids[message.ioid]
        response: pva.ChannelMonitorResponse

        if subcommand == MonitorSubcommand.INIT:
            try:
                data = await db_entry.authorize(
                    AuthOperation.read,
                    authorization=self.authorization_info,
                    request=message.pv_request,
                )
                bitset, options = message.pv_request.to_bitset_and_options(
                    data
                )
                spec = SubscriptionSpec(
                    db_entry=db_entry, bitset=bitset,
                    options=tuple(options.items())
                )
                sub = Subscription(
                    circuit=self, channel=chan, spec=spec, ioid=message.ioid
                )
            except Exception as ex:
                self.log.exception('Message response failure %s (%s)',
                                   message, subcommand)
                # Nothing usable was created; store None so the bookkeeping
                # below cannot hit unbound locals and a later START is a
                # no-op instead of subscribing with partial state.
                data = None
                sub = None
                response = chan.subscribe(
                    ioid=message.ioid, interface=None,
                    status=pva.Status.create_error(
                        message=f'{ex.__class__.__name__}: {ex}',
                    ),
                )
            else:
                response = chan.subscribe(ioid=message.ioid, interface=data)

            ioid_info['pv_request'] = message.pv_request
            ioid_info['interface'] = data
            ioid_info['init_request'] = message
            # Reusable response message for this ioid:
            ioid_info['response'] = response
            ioid_info['sub'] = sub
            ioid_info['monitor_state'] = MonitorSubcommand.INIT
            ioid_info['pipeline_count'] = None
            return [response]

        if MonitorSubcommand.START in subcommand:
            sub = ioid_info.get('sub')
            can_start = ioid_info.get('monitor_state') in {
                MonitorSubcommand.INIT, MonitorSubcommand.STOP}
            if sub is not None and can_start:
                data = await db_entry.subscribe(
                    queue=self.context.subscription_queue,
                    sub=sub,
                )
                self.subscriptions[sub.spec].append(sub)
                self.context.subscriptions[sub.spec].append(sub)
                ioid_info['monitor_state'] = MonitorSubcommand.START
                # It's not impossible this send could happen -after- an update
                response = ioid_info['response']
                response.to_default(
                    pv_data=pva.DataWithBitSet(
                        bitset=BitSet(sub.spec.bitset),
                        interface=get_pv_structure(data),
                        data=data
                    ),
                    overrun_bitset=BitSet({})
                )
                return [response]

        if MonitorSubcommand.PIPELINE in subcommand:
            # TODO: need to track the number of monitors that happen
            ...

        has_destroy = MonitorSubcommand.DESTROY in subcommand
        if subcommand == MonitorSubcommand.STOP or has_destroy:
            if ioid_info.get('monitor_state') == MonitorSubcommand.START:
                sub = ioid_info['sub']
                await db_entry.unsubscribe(
                    queue=self.context.subscription_queue,
                    sub=sub,
                )
                ioid_info['monitor_state'] = MonitorSubcommand.STOP
                self.subscriptions[sub.spec].remove(sub)
                self.context.subscriptions[sub.spec].remove(sub)
                # Drop context-wide entries nobody subscribes to anymore.
                if not self.context.subscriptions[sub.spec]:
                    self.context.subscriptions.pop(sub.spec)

    @_process_message.register
    async def _(self, message: pva.ChannelRpcRequest):
        subcommand = Subcommand(message.subcommand)
        chan, db_entry = self._get_db_entry_from_message(message)
        ioid_info = self.circuit.ioids[message.ioid]
        response: pva.ChannelRpcResponse

        if subcommand == Subcommand.INIT:
            try:
                await db_entry.authorize(
                    AuthOperation.call,
                    authorization=self.authorization_info,
                    request=message.pv_request,
                )
            except Exception as ex:
                self.log.exception('Message response failure %s (%s)',
                                   message, subcommand)
                response = chan.rpc(
                    ioid=message.ioid,
                    status=pva.Status.create_error(
                        message=f'{ex.__class__.__name__}: {ex}',
                    ),
                )
            else:
                ioid_info['pv_request'] = message.pv_request
                response = chan.rpc(ioid=message.ioid)

            ioid_info['init_request'] = message
            # Reusable response message for this ioid:
            ioid_info['response'] = response
            return [response]

        if subcommand == Subcommand.DEFAULT or subcommand == Subcommand.DESTROY:
            # NOTE(review): as with get/put, authorization happens only at
            # INIT and is not re-checked here.
            response = ioid_info['response']
            try:
                pv_response = await db_entry.call(
                    request=ioid_info['init_request'].pv_request,
                    data=message.pv_data,
                )
            except Exception as ex:
                self.log.exception('Message response failure %s (%s)',
                                   message, subcommand)
                response.to_default(
                    pv_response=None,
                    status=pva.Status.create_error(
                        message=f'{ex.__class__.__name__}: {ex}',
                    )
                )
            else:
                response.to_default(
                    pv_response=pva.FieldDescAndData(data=pv_response),
                    status=pva.Status.create_success(),
                )
            return [response]

    @_process_message.register
    async def _(self, message: pva.ChannelDestroyRequest):
        """This is a request to destroy a **channel**."""
        # TODO: cleanup
        chan, db_entry = self._get_db_entry_from_message(message)
        return [chan.disconnect()]

    @_process_message.register
    async def _(self, message: pva.ChannelRequestDestroy):
        """This is a request to destroy a **request**."""
        # The lookup also validates that the channel is known.
        chan, db_entry = self._get_db_entry_from_message(message)
        ioid_info = self.circuit.ioids[message.ioid]
        # Cancel any in-flight put for this request.
        task = ioid_info.pop('write_task', None)
        if task is not None:
            task.cancel()
        return []

    @_process_message.register
    async def _(self, message: pva.ChannelRequestCancel):
        # TODO: this layer should handle canceling the operation
        chan, db_entry = self._get_db_entry_from_message(message)
        ioid_info = self.circuit.ioids[message.ioid]
        task = ioid_info.pop('write_task', None)
        if task is not None:
            task.cancel()
        return []

    @_process_message.register
    async def _(self, message: pva.EchoRequest):
        return [pva.EchoResponse()]
class Context(typing.Mapping):
    """
    Base PVAccess server context.

    Acts as a read-only mapping of PV name to database entry (backed by
    ``pvdb``) and implements the UDP search/beacon handling, TCP client
    handling and subscription publishing shared by async library
    implementations.

    Parameters
    ----------
    pvdb : mapping
        PV name to database entry.
    interfaces : list, optional
        Interfaces to bind to.  Defaults to
        ``ca.get_server_address_list(protocol=ca.Protocol.PVAccess)``.

    Notes
    -----
    ``CircuitClass``, ``TaskCancelled``, ``ServerExit``, ``async_layer`` and
    ``message_bundle_queue`` are referenced here but not defined; subclasses
    are expected to provide them (and to replace ``subscription_queue``,
    which starts out as None).
    """
    # subscription_queue: 'QueueInterface'
    port: Optional[int]
    # TODO

    def __init__(self, pvdb, interfaces=None):
        if interfaces is None:
            interfaces = ca.get_server_address_list(
                protocol=ca.Protocol.PVAccess)
        self.interfaces = interfaces
        self.udp_socks = {}  # map each interface to a UDP socket for searches
        # map each beacon destination address to an (interface, UDP socket)
        # pair
        self.beacon_socks = {}
        self.pvdb = pvdb
        self.log = logging.getLogger('caproto.pva.ctx')
        self.addresses = []
        self.circuits = set()
        self.authentication_methods = {'anonymous', 'ca'}
        self.environ = get_environment_variables()

        # pva_server_port: the default TCP server port from the environment
        # (the first port tried when binding the TCP sockets).
        self.pva_server_port = self.environ['EPICS_PVAS_SERVER_PORT']
        self.pva_broadcast_port = self.environ['EPICS_PVAS_BROADCAST_PORT']
        self.broadcaster = pva.Broadcaster(
            our_role=ca.SERVER,
            broadcast_port=self.pva_broadcast_port,
            server_port=None,  # TBD
        )

        # the specific tcp port in use by this server
        self.port = None

        self.log.debug(
            'EPICS_PVAS_SERVER_PORT set to %s; EPICS_PVAS_BROADCAST_PORT '
            '(the UDP port used for searches) set to %s.',
            self.pva_server_port, self.pva_broadcast_port,
        )

        self.subscription_queue = None
        self.subscriptions = defaultdict(deque)

    async def _core_broadcaster_loop(self, udp_sock):
        """Receive UDP datagrams forever and hand them to the broadcaster."""
        while True:
            try:
                bytes_received, address = await udp_sock.recvfrom(4096 * 16)
            except ConnectionResetError:
                self.log.exception('UDP server connection reset')
                await self.async_layer.library.sleep(0.1)
                continue
            if bytes_received:
                await self._broadcaster_recv_datagram(bytes_received, address)

    async def _broadcaster_recv_datagram(self, bytes_received, address):
        """Parse one datagram and queue ``(address, messages)``."""
        try:
            messages = self.broadcaster.recv(bytes_received, address)
        except RemoteProtocolError:
            self.log.exception('Broadcaster received bad packet')
        else:
            await self.message_bundle_queue.put((address, messages))

    async def broadcaster_queue_loop(self):
        """
        Reference broadcaster queue loop implementation

        Note
        ----
        Assumes self.message_bundle_queue functions as an async queue with
        awaitable .get()

        Async library implementations can (and should) reimplement this.
        """
        while True:
            try:
                addr, messages = await self.message_bundle_queue.get()
                await self._broadcaster_queue_iteration(addr, messages)
            except self.TaskCancelled:
                break
            except Exception as ex:
                self.log.exception('Broadcaster message queue evaluation failed',
                                   exc_info=ex)
                continue

    def __iter__(self):
        # Implemented to support __getitem__ below
        return iter(self.pvdb)

    def __getitem__(self, pvname):
        return self.pvdb[pvname]

    def __len__(self):
        return len(self.pvdb)

    async def _broadcaster_queue_iteration(self, addr, messages):
        """
        Answer search requests for PVs in this context.

        A reply is sent when at least one requested PV is found, or when a
        search request carries an empty channel list.
        """
        self.broadcaster.process_commands(messages)
        found_pv_to_cid = {}
        saw_empty_channel_list = False
        for message in messages:
            if isinstance(message, pva.SearchRequest):
                if len(message.channels) == 0:
                    # This is apparently a special "I'm looking for servers"
                    # message
                    saw_empty_channel_list = True

                for channel in message.channels:
                    try:
                        channel['id']  # only checks that the key exists
                        name = channel['channel_name']
                        self[name]
                    except KeyError:
                        # Malformed entry or a PV we do not serve: skip it.
                        ...
                    else:
                        found_pv_to_cid[name] = channel['id']

        if found_pv_to_cid or saw_empty_channel_list:
            search_replies = [
                self.broadcaster.search_response(
                    pv_to_cid=found_pv_to_cid,
                )
            ]
            bytes_to_send = self.broadcaster.send(*search_replies)

            # TODO: why send this back on all sockets?
            for udp_sock in self.udp_socks.values():
                try:
                    await udp_sock.sendto(bytes_to_send, addr)
                except OSError as exc:
                    # addr[0]/addr[1] also works for 4-element IPv6 addresses
                    host, port = addr[0], addr[1]
                    raise CaprotoNetworkError(f"Failed to send to {host}:{port}") from exc

    async def broadcast_beacon_loop(self):
        """
        Periodically send beacons to every address in ``beacon_socks``.

        The period starts at ``MIN_BEACON_PERIOD`` and doubles after each
        beacon until it reaches ``EPICS_PVAS_BEACON_PERIOD``.  Setting
        ``CAPROTO_PVA_BEACON_DISABLE=1`` disables beacons entirely.
        """
        if self.environ.get('CAPROTO_PVA_BEACON_DISABLE', '') == '1':
            self.log.warning('Beacons disabled for debugging purposes')
            return

        self.log.debug('Will send beacons to %r',
                       [f'{h}:{p}' for h, p in self.beacon_socks.keys()])
        # "RECOMMENDED" by the PVA spec (~15Hz at startup)
        MIN_BEACON_PERIOD = 0.07
        BEACON_BACKOFF = 2
        max_beacon_period = self.environ['EPICS_PVAS_BEACON_PERIOD']
        beacon_period = MIN_BEACON_PERIOD
        server_status = ServerStatus()
        while True:
            beacon = self.broadcaster.beacon(server_status=server_status)
            bytes_to_send = self.broadcaster.send(beacon)
            for address, (interface, sock) in self.beacon_socks.items():
                try:
                    await sock.send(bytes_to_send)
                except IOError:
                    self.log.exception(
                        "Failed to send beacon to %r. Try setting "
                        "EPICS_PVAS_AUTO_BEACON_ADDR_LIST=no and "
                        "EPICS_PVAS_BEACON_ADDR_LIST=<addresses>.", address
                    )
            if beacon_period < max_beacon_period:
                beacon_period = min(max_beacon_period,
                                    beacon_period * BEACON_BACKOFF)
            await self.async_layer.library.sleep(beacon_period)

    async def circuit_disconnected(self, circuit):
        """Notification from circuit that its connection has closed"""
        self.circuits.discard(circuit)

    async def _bind_tcp_sockets_with_consistent_port_number(self, make_socket):
        """
        Find a random port number that is free on all `self.interfaces`, and
        get a bound TCP socket with that port number on each interface. The
        argument `make_socket` is expected to be a coroutine with the signature
        `make_socket(interface, port)` that does whatever library-specific
        incantation is necessary to return a bound socket or raise an IOError.

        Returns ``(port, {interface: socket})``; raises CaprotoRuntimeError
        if no candidate port could be bound on every interface.
        """
        tcp_sockets = {}  # maps interface to bound socket
        stashed_ex = None
        for port in ca.random_ports(100, try_first=self.pva_server_port):
            try:
                for interface in self.interfaces:
                    s = await make_socket(interface, port)
                    tcp_sockets[interface] = s
            except IOError as ex:
                # Partial success is useless: release everything and retry
                # with the next candidate port.
                stashed_ex = ex
                for s in tcp_sockets.values():
                    s.close()
                tcp_sockets.clear()
            else:
                break
        else:
            raise CaprotoRuntimeError(
                'No available ports and/or bind failed'
            ) from stashed_ex
        return port, tcp_sockets

    async def tcp_handler(self, client, addr):
        """Handler for each new TCP client to the server"""
        cavc = pva.ServerVirtualCircuit(ca.SERVER, addr, None)
        circuit = self.CircuitClass(cavc, client, self)
        self.circuits.add(circuit)
        # Log host and port only: IPv6 addresses have 4 elements.
        self.log.info('Connected to new client at %s:%d (total: %d).',
                      addr[0], addr[1], len(self.circuits))

        await circuit.run()

        try:
            while True:
                try:
                    await circuit.recv()
                except DisconnectedCircuit:
                    await self.circuit_disconnected(circuit)
                    break
        except KeyboardInterrupt as ex:
            self.log.debug('TCP handler received KeyboardInterrupt')
            raise self.ServerExit() from ex
        self.log.info('Disconnected from client at %s:%d (total: %d).',
                      addr[0], addr[1], len(self.circuits))

    def stop(self):
        ...

    @property
    def startup_methods(self):
        """Notify all instances of the server startup."""
        return {
            name: instance.server_startup
            for name, instance in self.pvdb.items()
            if getattr(instance, 'server_startup', None) is not None
        }

    @property
    def shutdown_methods(self):
        """Notify all instances of the server shutdown."""
        return {
            name: instance.server_shutdown
            for name, instance in self.pvdb.items()
            if getattr(instance, 'server_shutdown', None) is not None
        }

    async def subscription_queue_loop(self):
        """
        Reference implementation of the subscription queue loop.

        Note
        ----
        Assumes self.subscription_queue functions as an async queue with
        awaitable .get()

        Async library implementations can (and should) reimplement this
        coroutine which evaluates one item from the circuit command queue.
        """
        while True:
            # This queue receives updates that match the SubscriptionSpec of
            # one or more subscriptions.
            item = await self.subscription_queue.get()
            try:
                await self._subscription_queue_iteration(**item)
            except Exception:
                self.log.exception(
                    'Subscription publishing failed for %s',
                    item.get('sub', None)
                )
                raise  # TODO: remove

    async def _subscription_queue_iteration(
            self, sub: Subscription, interface, data, bitset: BitSet):
        """
        Called on every item from the Context subscription queue.

        Builds a monitor update for ``sub`` and queues it on the owning
        circuit.  If that circuit's queue is full, its queued responses are
        dropped.
        """
        circuit = sub.circuit
        cls: Type[pva.ChannelMonitorResponse] = circuit.circuit.messages[
            pva.ApplicationCommand.MONITOR
        ]
        monitor_update = cls(ioid=sub.ioid).to_default(
            pv_data=pva.DataWithBitSet(
                bitset=bitset,
                interface=get_pv_structure(interface),  # TODO
                data=data
            ),
            overrun_bitset=BitSet({})
        )

        try:
            await circuit.subscription_queue.put(monitor_update)
        except circuit.QueueFull:
            # We have hit the overall max for subscription backlog.
            circuit.log.warning(
                "Critically high EventAddResponse load. Dropping all "
                "queued responses on this circuit."
            )
            circuit.subscription_queue.clear()
            # TODO
            # circuit.unexpired_updates.clear()
class DataWrapperBase:
    """
    A base class to wrap dataclasses and support caproto-pva's server API.

    Parameters
    ----------
    name : str
        The associated name of the data.

    data : PvaStruct
        The dataclass holding the data.
    """

    _sub_queues: typing.DefaultDict[typing.FrozenSet[int], typing.Deque]

    def __init__(self, name: str, data):
        self.data = data
        self.name = name
        # Keyed on the subscribed bitset; each value holds the
        # ``(subscription, queue)`` pairs to notify, where each queue
        # belongs to a Context.
        self._sub_queues = collections.defaultdict(collections.deque)

    def __repr__(self) -> str:
        return f'<{self.__class__.__name__} name={self.name}>'

    async def authorize(self,
                        operation: AuthOperation, *,
                        authorization,
                        request=None):
        """
        Authenticate `operation`, given `authorization` information.

        In the event of successful authorization, a dataclass defining the data
        contained here must be returned.

        In the event of a failed authorization, `AuthenticationError` or
        similar should be raised.

        This base implementation allows everything and returns ``self.data``.

        Returns
        -------
        data

        Raises
        ------
        AuthenticationError
        """
        return self.data

    async def read(self, request):
        """A bare ``read`` (``get``) implementation, returning ``self.data``."""
        return self.data

    async def write(self, update: pva.DataWithBitSet):
        """A bare ``write`` (``put``) implementation: commits ``update.data``."""
        await self.commit(update.data)

    async def call(self, request: pva.PVRequest, data: pva.FieldDescAndData):
        """A bare ``call`` (``RPC``) implementation; returns None."""

    async def subscribe(self, queue, sub: Subscription):
        """
        Add a subscription from the server.

        It is unlikely this would need customization in a subclass.

        Parameters
        ----------
        queue : QueueInterface
            The queue to send updates to.

        sub : Subscription
            Subscription information.

        Returns
        -------
        data
        """
        self._sub_queues[sub.spec.bitset].append((sub, queue))
        return self.data

    async def unsubscribe(self, queue, sub: Subscription):
        """
        Remove an already-added subscription.

        It is unlikely this would need customization in a subclass.

        Parameters
        ----------
        queue : QueueInterface
            The queue used for subscriptions.

        sub : Subscription
            Subscription information.

        Raises
        ------
        ValueError
            If ``(sub, queue)`` was not subscribed.
        """
        self._sub_queues[sub.spec.bitset].remove((sub, queue))
        if not self._sub_queues[sub.spec.bitset]:
            self._sub_queues.pop(sub.spec.bitset)

    async def commit(self, changes: dict):
        """
        Commit `changes` to the local dataclass and publish monitors.

        It is unlikely this would need customization in a subclass.

        Parameters
        ----------
        changes : dict
            A nested dictionary of key to value, indicating changes to be
            made to the underlying data.
        """
        changed_bitset = pva.fill_dataclass(self.data, changes)
        # And publish indicating which bits of information have changed:
        await self._publish(changed_bitset)

    async def _publish(self, changed_bitset: BitSet):
        """
        Publish already-committed changes.

        It is unlikely this would need customization in a subclass.

        Parameters
        ----------
        changed_bitset : BitSet
            This indicates which fields have changed.
        """
        # A misplaced description regarding subscription flow:
        # Data written and .commit() called ->
        #   -> _publish
        #       Based on what parts changed
        #   -> Context.subscription_queue
        #       Create monitor update message
        #   -> VirtualCircuit.subscription_queue
        #       Potentially batch messages, remove unsubscribed items
        #   -> Ship remaining messages to client
        data = None
        # Iterate snapshots: ``queue.put`` may suspend this coroutine, and a
        # (un)subscribe in the meantime would otherwise invalidate the live
        # dict/deque iterators ("changed size during iteration").
        for frozen_bitset, queues in list(self._sub_queues.items()):
            matched_bitset = changed_bitset & frozen_bitset
            if not matched_bitset:
                continue

            if data is None:
                # Only create the dict if actually needed
                data = dataclasses.asdict(self.data)
                # TODO/FIXME/BUG: numpy arrays will be shallow copied

            for sub, queue in list(queues):
                # TODO: respect options here?
                item = dict(
                    sub=sub,
                    bitset=matched_bitset,
                    data=data,
                    interface=self.data,
                )
                await queue.put(item)