Source code for caproto.asyncio.client

# This is a channel access client implemented using asyncio.

# It builds on the abstractions used in caproto, adding transport and some
# caches for matching requests with responses.
#
# VirtualCircuit: has a caproto.VirtualCircuit, a socket, and some caches.
# Channel: has a VirtualCircuit and a caproto.ClientChannel.
# Context: has a caproto.Broadcaster, a UDP socket, a cache of
#          search results and a cache of VirtualCircuits.
#
import asyncio
import collections
import functools
import getpass
import inspect
import logging
import socket
import threading
import time
import weakref

import caproto as ca

from .. import _constants as constants
from .._utils import (ThreadsafeCounter, batch_requests,
                      get_environment_variables, safe_getsockname)
from ..client import common
from ..client.search_results import (DuplicateSearchResponse, SearchResults,
                                     UnknownSearchResponse)
from .utils import (AsyncioQueue, _CallbackExecutor, _DatagramProtocol,
                    _TaskHandler, _TransportWrapper, _UdpTransportWrapper,
                    get_running_loop)

ch_logger = logging.getLogger('caproto.ch')
search_logger = logging.getLogger('caproto.bcast.search')


[docs]class SharedBroadcaster: ''' A broadcaster client which can be shared among multiple Contexts Parameters ---------- registration_retry_time : float, optional The time, in seconds, between attempts made to register with the repeater. Default is 10. ''' def __init__(self, *, registration_retry_time=10.0): self.environ = get_environment_variables() self.ca_server_port = self.environ['EPICS_CA_SERVER_PORT'] self._registration_retry_time = registration_retry_time self._registration_last_sent = 0 self.broadcaster = ca.Broadcaster(our_role=ca.CLIENT) self.log = self.broadcaster.log self.wrapped_transport = None self.command_queue = AsyncioQueue() self.receive_queue = AsyncioQueue() self._cleanup_event = asyncio.Event() self._search_now = asyncio.Event() self._searching_enabled = asyncio.Event() self._tasks = _TaskHandler() self.results = SearchResults() self.server_protocol_versions = {} # map address to protocol version self.listeners = weakref.WeakSet() # UDP socket broadcasting to CA servers self.udp_sock = None self._essential_tasks_started = False def _ensure_essential_tasks_running(self): "" if self._essential_tasks_started: return self._essential_tasks_started = True self._tasks.create(self._broadcaster_retry_loop()) self._tasks.create(self._broadcaster_receive_loop()) self._tasks.create(self._check_for_unresponsive_servers_loop())
[docs] def add_listener(self, listener): self.listeners.add(listener)
[docs] async def remove_listener(self, listener): try: self.listeners.remove(listener) except KeyError: pass if not self.listeners: await self.disconnect()
def _get_servers_from_listeners(self): 'Map of address to VirtualCircuitManagers' servers = collections.defaultdict(weakref.WeakSet) # Map each server address to VirtualCircuitManagers connected to that # address, across all Contexts ("listeners"). for listener in self.listeners: for (address, _), circuit_manager in listener.circuit_managers.items(): servers[address].add(circuit_manager) return servers
[docs] async def send(self, *commands): """ Process a command and tranport it over the UDP socket. """ bytes_to_send = self.broadcaster.send(*commands) tags = {'role': 'CLIENT', 'our_address': self.broadcaster.client_address, 'direction': '--->>>'} for host_tuple in ca.get_client_address_list(): tags['their_address'] = host_tuple self.log.debug( '%d commands %dB', len(commands), len(bytes_to_send), extra=tags) try: await self.wrapped_transport.sendto(bytes_to_send, host_tuple) except OSError as ex: host, specified_port = host_tuple raise ca.CaprotoNetworkError( f'{ex} while sending {len(bytes_to_send)} bytes to ' f'{host}:{specified_port}') from ex
[docs] async def disconnect(self): 'Disconnect the broadcaster and stop listening' self._registration_last_sent = 0 await self._tasks.cancel_all(wait=True) self.log.debug('Broadcaster: Closing the UDP socket') self.udp_sock = None try: if self.protocol is not None: if self.protocol.transport is not None: self.protocol.transport.close() except OSError: self.log.exception('Broadcaster transport close error') self.log.debug('Broadcaster disconnect complete')
def _should_attempt_registration(self): 'Whether or not a registration attempt should be tried' if self.udp_sock is None: return True if (self.broadcaster.registered or self._registration_retry_time is None): return False since_last_attempt = time.monotonic() - self._registration_last_sent if since_last_attempt < self._registration_retry_time: return False return True async def _create_socket(self): self.udp_sock = ca.bcast_socket(socket_module=socket) # Must bind or getsocketname() will raise on Windows. # See https://github.com/caproto/caproto/issues/514. self.udp_sock.bind(('', 0)) await self._create_transport() async def _create_transport(self): """Create the _UdpTransportWrapper for the UDP socket""" loop = get_running_loop() transport, self.protocol = await loop.create_datagram_endpoint( functools.partial(_DatagramProtocol, parent=self, identifier='client-search', queue=self.receive_queue), sock=self.udp_sock) self.wrapped_transport = _UdpTransportWrapper(transport) self.broadcaster.client_address = safe_getsockname(self.udp_sock)
[docs] async def register(self): "Register this client with the CA Repeater." self._ensure_essential_tasks_running() self._registration_last_sent = time.monotonic() if self.udp_sock is None: await self._create_socket() await self._register() self._searching_enabled.set()
async def _register(self): commands = [self.broadcaster.register('127.0.0.1')] bytes_to_send = self.broadcaster.send(*commands) addr = (ca.get_local_address(), self.environ['EPICS_CA_REPEATER_PORT']) tags = { 'role': 'CLIENT', 'our_address': self.broadcaster.client_address, 'direction': '--->>>', 'their_address': addr, } tags['their_address'] = addr self.broadcaster.log.debug( '%d commands %dB', len(commands), len(bytes_to_send), extra=tags) try: await self.wrapped_transport.sendto(bytes_to_send, addr) except OSError as ex: host, specified_port = addr self.log.exception('%s while sending %d bytes to %s:%d', ex, len(bytes_to_send), host, specified_port) async def _broadcaster_receive_loop(self): 'Loop which consumes receive_queue datagrams from _DatagramProtocol.' queues = collections.defaultdict(list) while True: _, bytes_received, address = await self.receive_queue.async_get() if isinstance(bytes_received, ConnectionResetError): # Win32: "On a UDP-datagram socket this error indicates a # previous send operation resulted in an ICMP Port Unreachable # message." # # https://docs.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recvfrom self.log.debug( "Broadcaster received ConnectionResetError", exc_info=bytes_received ) await self._create_transport() continue if isinstance(bytes_received, Exception): self.log.exception('Broadcaster receive exception', exc_info=bytes_received) continue try: commands = self.broadcaster.recv(bytes_received, address) except ca.RemoteProtocolError: self.log.exception('Broadcaster received bad packet') continue if commands is ca.DISCONNECTED: break queues.clear() try: self.broadcaster.process_commands(commands) for command in commands: self._process_command(address, command, queues) # Receive commands in 'bundles' (corresponding to the contents # of one UDP datagram). Match SearchResponses to their # SearchRequests, and put (address, (name1, name2, name3, ...)) # into a queue. The receiving end of that queue is held by # Context._process_search_results. Send the search results to # the Contexts that asked for them. for (queue, address), names in queues.items(): queue.put((address, names)) except Exception: self.log.exception('Broadcaster receive loop evaluation') def _process_command(self, addr, command, queues): # if isinstance(command, ca.RepeaterConfirmResponse): # self.registered = True if isinstance(command, ca.Beacon): address = (command.address, command.server_port) self.results.mark_server_alive(address, command.beacon_id) elif isinstance(command, ca.SearchResponse): address = ca.extract_address(command) self.server_protocol_versions[address] = command.version cid = command.cid try: name, queue = self.results.received_search_response( cid, address) except UnknownSearchResponse: self.log.debug('Unknown search response cid=%d', cid) except DuplicateSearchResponse as ex: if len(ex.addresses) <= 1: return name = ex.name accepted_address = ex.addresses[0] other_addresses = ', '.join( '%s:%d' % addr for addr in ex.addresses[1:] ) search_logger.warning( "PV %s with cid %d found on multiple servers. " "Accepted address is %s:%d. Also found on %s", name, cid, *accepted_address, other_addresses, extra={ 'pv': name, 'their_address': accepted_address, 'our_address': self.broadcaster.client_address, }, ) else: queues[(queue, address)].append(name)
[docs] async def search(self, results_queue, *names): "Generate, process, and transport search request(s)" # TODO: where else? is this sufficient? # can't do it on add_listener... self._ensure_essential_tasks_running() if self._should_attempt_registration(): await self.register() # We have have already searched for these names recently. # Filter `pv_names` down to a subset, `needs_search`. use_cached_search, needs_search = self.results.split_cached_results(names) for address, names in use_cached_search.items(): results_queue.put((address, names)) self.results.search( *needs_search, results_queue=results_queue, retirement_deadline=time.monotonic() + common.SEARCH_RETIREMENT_AGE) self._search_now.set()
[docs] def time_since_last_heard(self): """ Map each known server address to seconds since its last message. The time is reset to 0 whenever we receive a TCP message related to user activity *or* a Beacon. Servers are expected to send Beacons at regular intervals. If we do not receive either a Beacon or TCP message, we initiate an Echo over TCP, to which the server is expected to promptly respond. Therefore, the time reported here should not much exceed ``EPICS_CA_CONN_TMO`` (default 30 seconds unless overriden by that environment variable) if the server is healthy. If the server fails to send a Beacon on schedule *and* fails to reply to an Echo, the server is assumed dead. A warning is issued, and all PVs are disconnected to initiate a reconnection attempt. """ return { address: time.monotonic() - t for address, t in self._get_last_heard().items() }
def _get_last_heard(self): """ When is the last time we heard from each server, either via a Beacon or from TCP packets related to user activity or any circuit? Returns ------- dict {addr: last_heard_time} """ beacons = self.results.get_last_beacon_times() last_heard = {} for addr, circuit_managers in self._get_servers_from_listeners().items(): # Aggregate TCP results last_tcp_receipt = (cm.last_tcp_receipt for cm in circuit_managers) last_heard[addr] = max(beacons.get(addr, 0), *last_tcp_receipt) return last_heard async def _check_for_unresponsive_servers_loop(self): self.log.debug('Broadcaster check for unresponsive servers loop is running.') checking = {} while True: now = time.monotonic() cutoff = now - (self.environ['EPICS_CA_CONN_TMO'] + common.BEACON_MARGIN) last_heard = self._get_last_heard() servers = self._get_servers_from_listeners() newly_unresponsive = { addr: circuit_managers for addr, circuit_managers in servers.items() if last_heard.get(addr, 0) < cutoff and addr not in checking } # self.log.debug( # 'Unresponsive checks: \n' # ' servers: %s\n' # ' last heard: %s\n' # ' newly unresponsive: %s\n' # ' checking: %s', # servers, last_heard, newly_unresponsive, checking # ) for address, circuit_managers in newly_unresponsive.items(): # Record that we are checking on this address and set a # deadline for a response. checking[address] = now + constants.RESPONSIVENESS_TIMEOUT tags = { 'role': 'CLIENT', 'their_address': address, 'our_address': self.broadcaster.client_address, 'direction': '--->>>' } self.broadcaster.log.debug( "Missed Beacons from %s:%d. Sending EchoRequest to " "check that server is responsive.", *address, extra=tags) # Send on all circuits. One might be less backlogged with # queued commands than the others and thus able to respond # faster. In the majority of cases there will only be one # circuit per server anyway, so this is a minor distinction. for circuit_manager in circuit_managers: try: await circuit_manager.send(ca.EchoRequest()) except Exception: # Send failed. Server is likely dead, but we'll # catch that shortly; no need to handle it # specially here. pass # Check to see if any of our ongoing checks have resolved or # failed to resolve within the allowed response window. for address, deadline in list(checking.items()): if last_heard[address] > cutoff: # It's alive! checking.pop(address) elif deadline < now: # No circuit connected to the server at this address has # sent Beacons or responded to the EchoRequest. We assume # it is unresponsive. The EPICS specification says the # behavior is undefined at this point. We choose to # disconnect all circuits from that server so that PVs can # attempt to connect to a new server, such as a failover # backup. for circuit_manager in servers[address]: if circuit_manager.connected: circuit_manager.log.warning( "Server at %s:%d is unresponsive. " "Disconnecting circuit manager %r. PVs will " "automatically begin attempting to reconnect " "to a responsive server.", *address, circuit_manager) circuit_manager.disconnect() checking.pop(address) await asyncio.sleep(0.5) self.log.debug('Broadcaster check for unresponsive servers loop has exited.') async def _broadcaster_retry_loop(self): self.log.debug('Broadcaster search-retry thread has started.') interval = common.MIN_RETRY_SEARCHES_INTERVAL time_to_check_on_retirees = ( time.monotonic() + common.RETRY_RETIRED_SEARCHES_INTERVAL) while True: await self._searching_enabled.wait() last_send_time = time.monotonic() if last_send_time >= time_to_check_on_retirees: # Send everything threshold = None else: threshold = last_send_time await self._retry_unanswered_searches( threshold, resend_deadline=last_send_time - interval) wait_time = max(0, interval - (time.monotonic() - last_send_time)) # Double the interval for the next loop. interval = min(2 * interval, common.MAX_RETRY_SEARCHES_INTERVAL) try: await asyncio.wait_for(self._search_now.wait(), timeout=wait_time) except asyncio.TimeoutError: ... else: # New searches have been requested. Reset the interval between # subseqent searches and force a check on the "retirees". time_to_check_on_retirees = last_send_time interval = common.MIN_RETRY_SEARCHES_INTERVAL self._search_now.clear() # self.log.debug('Broadcaster search-retry thread has exited.') async def _retry_unanswered_searches(self, threshold, resend_deadline): """ Periodically (re-)send a SearchRequest for all unanswered searches. Notes ----- Each time new searches are added, the self._search_now Event is set, and we reissue *all* unanswered searches. We then frequently retry the unanswered searches that are younger than SEARCH_RETIREMENT_AGE, backing off from an interval of MIN_RETRY_SEARCHES_INTERVAL to MAX_RETRY_SEARCHES_INTERVAL. The interval is reset to MIN_RETRY_SEARCHES_INTERVAL each time new searches are added. For the searches older than SEARCH_RETIREMENT_AGE, we adopt a slower period to minimize network traffic. We only resend every RETRY_RETIRED_SEARCHES_INTERVAL or, again, whenever new searches are added. """ t = time.monotonic() def _construct_search_requests(items): for search_id, it in items: yield ca.SearchRequest(it.name, search_id, ca.DEFAULT_PROTOCOL_VERSION) it.last_sent = t items = self.results.items_to_retry( threshold, resend_deadline=resend_deadline) requests = _construct_search_requests(items) if items: self.log.debug('Sending %d SearchRequests', len(items)) ver = ca.VersionRequest(0, ca.DEFAULT_PROTOCOL_VERSION) for batch in batch_requests( requests, constants.SEARCH_MAX_DATAGRAM_BYTES - len(ver)): await self.send(ver, *batch)
[docs]class Context: """ Encapsulates the state and connections of a client. Parameters ---------- broadcaster : SharedBroadcaster, optional If None is specified, a fresh one is instantiated. timeout : number or None, optional Number of seconds before a CaprotoTimeoutError is raised. This default can be overridden at the PV level or for any given operation. If unset, the default is 2 seconds. If None, never timeout. A global timeout can be specified via an environment variable ``CAPROTO_DEFAULT_TIMEOUT``. host_name : string, optional uses value of ``socket.gethostname()`` by default client_name : string, optional uses value of ``getpass.getuser()`` by default max_workers : integer, optional Number of worker threaders *per VirtualCircuit* for executing user callbacks. Default is 1. For any number of workers, workers will receive updates in the order which they are received from the server. That is, work on each update will *begin* in sequential order. Work-scheduling internal to the user callback is outside caproto's control. If the number of workers is set to greater than 1, the work on each update may not *finish* in a deterministic order. For example, if workers are writing lines into a file, the only way to guarantee that the lines are ordered properly is to use only one worker. If ordering matters for your application, think carefully before increasing this value from 1. """ def __init__(self, broadcaster=None, *, timeout=common.GLOBAL_DEFAULT_TIMEOUT, host_name=None, client_name=None, max_workers=1): self._user_disconnected = False self.broadcaster = broadcaster or SharedBroadcaster() self.client_name = client_name or getpass.getuser() self.host_name = host_name or socket.gethostname() self.max_workers = max_workers self.timeout = timeout self.log = logging.LoggerAdapter( logging.getLogger('caproto.ctx'), {'role': 'CLIENT'}) self.pv_cache_lock = threading.RLock() # TODO: remove? self.circuit_managers = {} # keyed on ((host, port), priority) self._lock_during_get_circuit_manager = threading.RLock() self.pvs = {} # (name, priority) -> pv # name -> set of pvs --- with varied priority self.pvs_needing_circuits = collections.defaultdict(set) self.broadcaster.add_listener(self) # an event to close and clean up the whole context self.subscriptions_lock = threading.RLock() self.subscriptions_to_activate = collections.defaultdict(set) self.activate_subscriptions_now = asyncio.Event() self._search_results_queue = AsyncioQueue() self._tasks = _TaskHandler() self._essential_tasks_started = False def __repr__(self): return (f"<Context " f"circuits={len(self.circuit_managers)} " f"pvs={len(self.pvs)} " f"idle={len([1 for pv in self.pvs.values() if pv._idle])}>") async def disconnect(self): self._user_disconnected = True try: # disconnect any circuits we have circuits = list(self.circuit_managers.values()) total_circuits = len(circuits) disconnected = False for idx, circuit in enumerate(circuits, 1): if circuit.connected: self.log.debug('Disconnecting circuit %d/%d: %s', idx, total_circuits, circuit) await circuit.disconnect() disconnected = True if disconnected: self.log.debug('All circuits disconnected') finally: # Remove from Broadcaster. await self.broadcaster.remove_listener(self) # clear any state about circuits and search results self.log.debug('Clearing circuit managers') self.circuit_managers.clear() await self._tasks.cancel_all() self.log.debug('Context disconnection complete') async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_value, traceback): await self.disconnect() def get_circuit_manager(self, address, priority): """ Return a VirtualCircuitManager for this address, priority. (It manages a caproto.VirtualCircuit and a TCP socket.) Make a new one if necessary. """ with self._lock_during_get_circuit_manager: cm = self.circuit_managers.get((address, priority), None) if cm is None or cm.dead.is_set(): version = self.broadcaster.server_protocol_versions[address] circuit = ca.VirtualCircuit(our_role=ca.CLIENT, address=address, priority=priority, protocol_version=version) cm = VirtualCircuitManager(self, circuit) self.circuit_managers[(address, priority)] = cm return cm async def search(self, *names): "Generate, process, transport a search request with the broadcaster" await self.broadcaster.search(self._search_results_queue, *names) def _ensure_essential_tasks_running(self): "" if self._essential_tasks_started: return self._essential_tasks_started = True self._tasks.create(self._process_search_results_loop()) self._tasks.create(self._activate_subscriptions_loop())
[docs] async def get_pvs( self, *names, priority=0, connection_state_callback=None, access_rights_callback=None, timeout=common.CONTEXT_DEFAULT_TIMEOUT): """ Return a list of PV objects. These objects may not be connected at first. Channel creation occurs on a background thread. PVs are uniquely defined by their name and priority. If a PV with the same name and priority is requested twice, the same (cached) object is returned. Any callbacks included here are added to added alongside any existing ones. Parameters ---------- *names : strings any number of PV names priority : integer Used by the server to triage subscription responses when under high load. 0 is lowest; 99 is highest. connection_state_callback : callable Expected signature: ``f(pv, state)`` where ``pv`` is the instance of ``PV`` whose state has changed and ``state`` is a string access_rights_callback : callable Expected signature: ``f(pv, access_rights)`` where ``pv`` is the instance of ``PV`` whose state has changed and ``access_rights`` is a member of the caproto ``AccessRights`` enum timeout : number or None, optional Number of seconds before a CaprotoTimeoutError is raised. This default can be overridden for any specific operation. By default, fall back to the default timeout set by the Context. If None, never timeout. """ if self._user_disconnected: raise common.ContextDisconnectedError( "This Context is no longer usable.") self._ensure_essential_tasks_running() pvs = [] # list of all PV objects to return names_to_search = [] # subset of names that we need to search for for name in names: with self.pv_cache_lock: try: pv = self.pvs[(name, priority)] except KeyError: pv = PV(name, priority, self, None, None, timeout) names_to_search.append(name) self.pvs[(name, priority)] = pv self.pvs_needing_circuits[name].add(pv) if connection_state_callback is not None: pv.connection_state_callback.add_callback( connection_state_callback, run=True) if access_rights_callback is not None: pv.access_rights_callback.add_callback( access_rights_callback, run=True) pvs.append(pv) # Ask the Broadcaster to search for every PV for which we do not # already have an instance. It might already have a cached search # result, but that is the concern of broadcaster.search. if names_to_search: await self.broadcaster.search(self._search_results_queue, *names_to_search) return pvs
async def reconnect(self, keys): # We will reuse the same PV object but use a new cid. names = [] pvs = [] for key in keys: with self.pv_cache_lock: pv = self.pvs[key] pvs.append(pv) name, _ = key names.append(name) # If there is a cached search result for this name, expire it. self.broadcaster.results.invalidate_by_name(name) with self.pv_cache_lock: self.pvs_needing_circuits[name].add(pv) if names: await self.broadcaster.search(self._search_results_queue, *names) async def _activate_subscriptions(self): with self.subscriptions_lock: items = list(self.subscriptions_to_activate.items()) self.subscriptions_to_activate.clear() for cm, subs in items: async def requests(subs): "Yield EventAddRequest commands." requests = [] for sub in subs: command = await sub.compose_command() # compose_command() returns None if this Subscription # is inactive (meaning there are no user callbacks # attached). It will send an EventAddRequest on its own # if/when the user does add any callbacks, so we can # skip it here. if command is not None: requests.append(command) return requests for batch in batch_requests(await requests(subs), common.EVENT_ADD_BATCH_MAX_BYTES): try: await cm.send(*batch) except Exception: if cm.dead.is_set(): self.log.debug( "Circuit died while we were trying to activate" " subscriptions. We will keep attempting this" " until it works.") # When the Context creates a new circuit, we will # end up here again. No big deal. break async def _activate_subscriptions_loop(self): while True: t = time.monotonic() await self._activate_subscriptions() elapsed = time.monotonic() - t wait_time = max(0, (common.RESTART_SUBS_PERIOD - elapsed)) try: await asyncio.wait_for(self.activate_subscriptions_now.wait(), timeout=wait_time) except asyncio.TimeoutError: ... self.activate_subscriptions_now.clear() self.log.debug('Context restart-subscriptions thread exiting') async def _process_search_results_loop(self): # Receive (address, (name1, name2, ...)). The sending side of this # queue is held by the SharedBroadcaster. self.log.debug('Context search-results processing loop has started.') while True: address, names = await self._search_results_queue.async_get() channels_grouped_by_circuit = collections.defaultdict(list) # Assign each PV a VirtualCircuitManager for managing a socket # and tracking circuit state, as well as a ClientChannel for # tracking channel state. for name in names: extra = { 'pv': name, 'their_address': address, 'our_address': self.broadcaster.broadcaster.client_address, 'direction': '--->>>', 'role': 'CLIENT' } search_logger.debug( 'Connecting %s on circuit with %s:%d', name, *address, extra=extra) # There could be multiple PVs with the same name and different # priority. That is what we are looping over here. There could # also be NO PVs with this name that need a circuit, because we # could be receiving a duplicate search response (which we are # supposed to ignore). with self.pv_cache_lock: pvs = self.pvs_needing_circuits.pop(name, set()) for pv in pvs: # Get (make if necessary) a VirtualCircuitManager. This # is where TCP socket creation happens. cm = self.get_circuit_manager(address, pv.priority) circuit = cm.circuit pv.circuit_manager = cm # TODO: NOTE: we are not following the suggestion to use # the same cid as in the search. This simplifies things # between the broadcaster and Context. cid = cm.circuit.new_channel_id() chan = ca.ClientChannel(name, circuit, cid=cid) cm.channels[cid] = chan cm.pvs[cid] = pv channels_grouped_by_circuit[cm].append(chan) pv.circuit_ready.set() # Initiate channel creation with the server. for cm, channels in channels_grouped_by_circuit.items(): commands = [chan.create() for chan in channels] try: await cm.send(*commands) except Exception: if cm.dead.is_set(): self.log.debug( "Circuit died while we were trying to create the " "channel. We will keep attempting this until it " "works." ) # When the Context creates a new circuit, we will end # up here again. No big deal. continue raise self.log.debug('Context search-results processing thread has exited.') async def monitor(self, *pv_names, context=None, data_type='time'): """ Monitor pv_names asynchronously, yielding events as they happen. Parameters ---------- *pv_names : str PV names to monitor. data_type : {'time', 'control', 'native'} The subscription type. Yields ------- event : {'subscription', 'connection'} The event type. context : str or Subscription For a 'connection' event, this is the PV name. For a 'subscription' event, this is the `Subscription` instance. data : str or EventAddResponse For a 'subscription' event, the `EventAddResponse` holds the data and timestamp. For a 'connection' event, this is one of ``{'connected', 'disconnected'}``. """ queue = AsyncioQueue() def value_update(sub, event_add_response): queue.put(('subscription', sub, event_add_response)) def connection_state_callback(pv, state): queue.put(('connection', pv, state)) channels = await self.get_pvs( *pv_names, connection_state_callback=connection_state_callback ) subscriptions = [] for channel in channels: sub = channel.subscribe(data_type=data_type) token = sub.add_callback(value_update) subscriptions.append( dict( channel=channel, sub=sub, token=token, ) ) try: while True: event, context, data = await queue.async_get() yield event, context, data finally: for info in subscriptions.values(): await info['sub'].remove_callback(info['token'])
[docs]class VirtualCircuitManager: """ Encapsulates a VirtualCircuit, a TCP socket, and additional state # This object should never be instantiated directly by user code. It is used # internally by the Context. Its methods may be touched by user code, but # this is rarely necessary. # """ def __init__(self, context, circuit, timeout=common.GLOBAL_DEFAULT_TIMEOUT): self.context = context self.circuit = circuit # a caproto.VirtualCircuit self.log = circuit.log self.channels = {} # map cid to Channel self.pvs = {} # map cid to PV self.ioids = {} # map ioid to Channel and info dict self.subscriptions = {} # map subscriptionid to Subscription self.transport = None self._ioid_counter = ThreadsafeCounter() self._ready = asyncio.Event() self._send_on_connection = [] self._send_lock = asyncio.Lock() self._subscriptionid_counter = ThreadsafeCounter() self.command_queue = AsyncioQueue() self.dead = asyncio.Event() self.last_tcp_receipt = 0.0 self.user_callback_executor = _CallbackExecutor(self.log) self._tags = { 'their_address': self.circuit.address, 'our_address': 'unset:0', 'direction': '<<<---', 'role': repr(self.circuit.our_role) } if self.circuit.states[ca.SERVER] is not ca.IDLE: raise ca.CaprotoRuntimeError("Cannot connect. States are {} " "".format(self.circuit.states)) self._tasks = _TaskHandler() self._tasks.create(self._connection_ready_hook()) self._tasks.create(self._connect(timeout=timeout)) async def _transport_receive_loop(self, transport): while True: try: bytes_received = await transport.recv() except ca.CaprotoNetworkError: bytes_received = b'' self.last_tcp_receipt = time.monotonic() commands, _ = self.circuit.recv(bytes_received) for c in commands: self.command_queue.put(c) if not bytes_received: break async def _connect(self, timeout): """Start the connection and spawn tasks.""" try: reader, writer = await asyncio.wait_for( asyncio.open_connection(*self.circuit.address), timeout=timeout, ) except asyncio.TimeoutError: host, port = self.circuit.address raise ca.CaprotoTimeoutError( f"Circuit with server at {host}:{port} did not connect within " f"{float(timeout):.3}-second timeout." ) self.transport = _TransportWrapper(reader, writer) self._tasks.create(self._transport_receive_loop(self.transport)) self.circuit.our_address = self.transport.getsockname() # This dict is passed to the loggers. self._tags = {'their_address': self.circuit.address, 'our_address': self.circuit.our_address, 'direction': '<<<---', 'role': repr(self.circuit.our_role)} self.log.debug( 'Connected to circuit with address %s:%d', *self.circuit.address, extra=self._tags) await self.send( ca.VersionRequest(self.circuit.priority, ca.DEFAULT_PROTOCOL_VERSION), ca.HostNameRequest(self.context.host_name), ca.ClientNameRequest(self.context.client_name), extra=self._tags) # Ensure we don't get any commands back before sending the above, lest # we confuse the state machine: self._tasks.create(self._command_queue_loop()) # Old versions of the protocol do not send a VersionResponse at TCP # connection time, so set this Event manually rather than waiting for # it to be set by receipt of a VersionResponse. if self.server_protocol_version < 12: self._ready.set() try: await asyncio.wait_for(self._ready.wait(), timeout=timeout) except asyncio.TimeoutError: host, port = self.circuit.address raise ca.CaprotoTimeoutError( f"Circuit with server at {host}:{port} did not connect within " f"{float(timeout):.3}-second timeout.") from None @property def server_protocol_version(self): return self.context.broadcaster.server_protocol_versions[ self.circuit.address] def __repr__(self): return (f"<VirtualCircuitManager circuit={self.circuit} " f"pvs={len(self.pvs)} ioids={len(self.ioids)} " f"subscriptions={len(self.subscriptions)} " f"server_protocol_version={self.server_protocol_version}>") @property def connected(self): return self.circuit.states[ca.CLIENT] is ca.CONNECTED async def _connection_ready_hook(self): await self._ready.wait() to_send = list(self._send_on_connection) self._send_on_connection.clear() for commands, extra in to_send: await self.send(*commands, extra=extra)
[docs] async def send(self, *commands, extra=None): if self.dead.is_set(): raise common.DeadCircuitError() if self.transport is None: self._send_on_connection.append((commands, extra)) return # Turn the crank: inform the VirtualCircuit that these commands will be # send, and convert them to buffers. buffers_to_send = self.circuit.send(*commands, extra=extra) async with self._send_lock: for buff in buffers_to_send: self.transport.writer.write(bytes(buff)) await self.transport.writer.drain()
[docs] async def events_off(self): """ Suspend updates to all subscriptions on this circuit. This may be useful if the server produces updates faster than the client can processs them. """ await self.send(ca.EventsOffRequest())
[docs] async def events_on(self): """ Reactive updates to all subscriptions on this circuit. """ await self.send(ca.EventsOnRequest())
async def _process_command(self, command): try: self.circuit.process_command(command) except ca.CaprotoError as ex: if hasattr(ex, 'channel'): channel = ex.channel self.log.warning('Invalid command %s for Channel %s in state %s', command, channel, channel.states, exc_info=ex) # channel exceptions are not fatal return self.log.exception( 'Invalid command %s for VirtualCircuit %s in state %s', command, self, self.circuit.states) # circuit exceptions are fatal; exit the loop await self.disconnect() return tags = self._tags if command is ca.DISCONNECTED: await self._disconnected() elif isinstance(command, (ca.VersionResponse,)): assert self.connected # double-check that the state machine agrees self._ready.set() elif isinstance(command, (ca.ReadNotifyResponse, ca.ReadResponse, ca.WriteNotifyResponse)): ioid_info = self.ioids.pop(command.ioid) deadline = ioid_info['deadline'] pv = ioid_info['pv'] tags = tags.copy() tags['pv'] = pv.name if deadline is not None and time.monotonic() > deadline: self.log.warning("Ignoring late response with ioid=%d regarding " "PV named %s because " "it arrived %.3f seconds after the deadline " "specified by the timeout.", command.ioid, pv.name, time.monotonic() - deadline) return event = ioid_info.get('event') if event is not None: # If PV.read() or PV.write() are waiting on this response, # they hold a reference to ioid_info. We will use that to # provide the response to them and then set the Event that they # are waiting on. ioid_info['response'] = command event.set() callback = ioid_info.get('callback') if callback is not None: self.user_callback_executor.submit(callback, command) elif isinstance(command, ca.EventAddResponse): try: sub = self.subscriptions[command.subscriptionid] except KeyError: # This subscription has been removed. We assume that this # response was in flight before the server processed our # unsubscription. pass else: # This method submits jobs to the Contexts's # ThreadPoolExecutor for user callbacks. sub.process(command) tags = tags.copy() tags['pv'] = sub.pv.name elif isinstance(command, ca.AccessRightsResponse): pv = self.pvs[command.cid] pv.access_rights_changed(command.access_rights) tags = tags.copy() tags['pv'] = pv.name elif isinstance(command, ca.EventCancelResponse): # TODO Any way to add the pv name to tags here? ... elif isinstance(command, ca.CreateChanResponse): pv = self.pvs[command.cid] chan = self.channels[command.cid] self._search_results.mark_channel_created( pv.name, self.circuit.address) with pv.component_lock: pv.channel = chan pv.channel_ready.set() pv.connection_state_changed('connected', chan) tags = tags.copy() tags['pv'] = pv.name elif isinstance(command, (ca.ServerDisconnResponse, ca.ClearChannelResponse)): pv = self.pvs[command.cid] pv.connection_state_changed('disconnected', None) tags = tags.copy() tags['pv'] = pv.name # NOTE: pv remains valid until server goes down # TODO: or do we not assume the server will remove it? # self._search_results.mark_channel_disconnected( # pv.name, self.circuit.address) elif isinstance(command, ca.EchoResponse): # The important effect here is that it will have updated # self.last_tcp_receipt when the bytes flowed through # self.received. ... if not isinstance(command, ca.Message): return # Log each message with the above-gathered tags tags['bytesize'] = len(command) self.log.debug("%r", command, extra=tags) async def _command_queue_loop(self): command = None while True: try: command = await self.command_queue.async_get() await self._process_command(command) if command is ca.DISCONNECTED: self.log.debug('Command queue loop exiting') break except asyncio.CancelledError: break except Exception: self.log.exception('Circuit command evaluation failed: %r', command) continue # async with self.new_command_condition: # self.new_command_condition.notify_all() @property def _search_results(self): return self.context.broadcaster.results async def _disconnected(self, *, reconnect=True): # Ensure that this method is idempotent. if self.dead.is_set(): return tags = {'their_address': self.circuit.address} self.log.debug('Virtual circuit with address %s:%d has disconnected.', *self.circuit.address, extra=tags) # Update circuit state. This will be reflected on all PVs, which # continue to hold a reference to this disconnected circuit. self.circuit.disconnect() for pv in self.pvs.values(): pv.channel_ready.clear() pv.circuit_ready.clear() self.dead.set() for ioid_info in self.ioids.values(): # Un-block any calls to PV.read() or PV.write() that are waiting on # responses that we now know will never arrive. They will check on # circuit health and raise appropriately. event = ioid_info.get('event') if event is not None: event.set() # Remove server + channels marked as created from the search results: self._search_results.mark_server_disconnected(self.circuit.address) for pv in self.pvs.values(): pv.connection_state_changed('disconnected', None) # Remove VirtualCircuitManager from Context. # This will cause all future calls to Context.get_circuit_manager() # to create a fresh VirtualCiruit and VirtualCircuitManager. self.context.circuit_managers.pop(self.circuit.address, None) if self.transport is not None: try: self.transport.close() except OSError: self.log.exception( 'VirtualCircuitManager transport close error' ) finally: self.transport = None tags = {'their_address': self.circuit.address} if reconnect: # Kick off attempt to reconnect all PVs via fresh circuit(s). self.log.debug( 'Kicking off reconnection attempts for %d PVs disconnected ' 'from %s:%d....', len(self.channels), *self.circuit.address, extra=tags) await self.context.reconnect( ((chan.name, chan.circuit.priority) for chan in self.channels.values()) ) else: self.log.debug('Not attempting reconnection', extra=tags)
[docs] async def disconnect(self): await self._disconnected(reconnect=False) self.log.debug("Shutting down ThreadPoolExecutor for user callbacks", extra={'their_address': self.circuit.address}) await self.user_callback_executor.shutdown() if self.transport is None: return self.log.debug('Circuit manager disconnected by user')
def ensure_connected(func): ''' Ensure connected decorator. Parameters ---------- func : coroutine ''' assert inspect.iscoroutinefunction(func) @functools.wraps(func) async def inner(self, *args, **kwargs): if isinstance(self, PV): pv = self elif isinstance(self, Subscription): pv = self.pv else: raise ca.CaprotoTypeError( "ensure_connected is intended to decorate " "methods of PV and Subscription.") # timeout may be decremented during disconnection-retry loops below. # Keep a copy of the original 'raw_timeout' for use in error messages. raw_timeout = timeout = kwargs.get('timeout', pv.timeout) if timeout is not None: deadline = time.monotonic() + timeout async with pv._in_use: # If needed, reconnect. Do this inside the lock so that we don't # try to do this twice. (No other threads that need this lock # can proceed until the connection is ready anyway!) if pv._idle: # The Context should have been maintaining a working circuit # for us while this was idle. We just need to re-create the # Channel. try: await asyncio.wait_for(pv.circuit_ready.wait(), timeout=timeout) except asyncio.TimeoutError: raise ca.CaprotoTimeoutError( f"{pv} could not connect within " f"{float(raw_timeout):.3}-second timeout.") from None with pv.component_lock: cm = pv.circuit_manager cid = cm.circuit.new_channel_id() chan = ca.ClientChannel(pv.name, cm.circuit, cid=cid) cm.channels[cid] = chan cm.pvs[cid] = pv await pv.circuit_manager.send(chan.create(), extra={'pv': pv.name}) self._idle = False # increment the usage at the very end in case anything # goes wrong in the block of code above this. pv._usages += 1 try: for _ in range(common.CIRCUIT_DEATH_ATTEMPTS): # On each iteration, subtract the time we already spent on any # previous attempts. if timeout is not None: timeout = deadline - time.monotonic() try: await asyncio.wait_for(pv.channel_ready.wait(), timeout=timeout) except asyncio.TimeoutError: raise ca.CaprotoTimeoutError( f"{pv} could not connect within " f"{float(raw_timeout):.3}-second timeout.") from None if timeout is not None: timeout = deadline - time.monotonic() kwargs['timeout'] = timeout cm = pv.circuit_manager try: return await func(self, *args, **kwargs) except common.DeadCircuitError: # Something in func tried operate on the circuit after # it died. The context will automatically build us a # new circuit. Try again. self.log.debug('Caught DeadCircuitError. ' 'Retrying %s.', func.__name__) continue except (TimeoutError, asyncio.TimeoutError): # The circuit may have died after func was done calling # methods on it but before we received some response we # were expecting. The context will automatically build # us a new circuit. Try again. if cm.dead.is_set(): self.log.debug('Caught TimeoutError due to dead ' 'circuit. ' 'Retrying %s.', func.__name__) continue # The circuit is fine -- this is a real error. raise finally: async with pv._in_use: pv._usages -= 1 pv._in_use.notify_all() return inner
[docs]class PV: """ Represents one PV, specified by a name and priority. This object may exist prior to connection and persists across any subsequent re-connections. This object should never be instantiated directly by user code; rather it should be created by calling the ``get_pvs`` method on a ``Context`` object. """ def __init__(self, name, priority, context, connection_state_callback, access_rights_callback, timeout): """ These must be instantiated by a Context, never directly. """ self.name = name self.priority = priority self.context = context self.access_rights = None # will be overwritten with AccessRights self.log = logging.LoggerAdapter(ch_logger, {'pv': self.name, 'role': 'CLIENT'}) # Use this lock whenever we touch circuit_manager or channel. self.component_lock = threading.RLock() self.circuit_ready = asyncio.Event() self.channel_ready = asyncio.Event() self.connection_state_callback = CallbackHandler(self) self.access_rights_callback = CallbackHandler(self) self._timeout = timeout if connection_state_callback is not None: self.connection_state_callback.add_callback( connection_state_callback, run=True) if access_rights_callback is not None: self.access_rights_callback.add_callback( access_rights_callback, run=True) self._circuit_manager = None self._channel = None self.subscriptions = {} self._idle = False self._in_use = asyncio.Condition() self._usages = 0 @property def timeout(self): """ Effective default timeout. Valid values are: * CONTEXT_DEFAULT_TIMEOUT (fall back to Context.timeout) * a floating-point number * None (never timeout) """ if self._timeout is common.CONTEXT_DEFAULT_TIMEOUT: return self.context.timeout return self._timeout @timeout.setter def timeout(self, val): self._timeout = val @property def circuit_manager(self): return self._circuit_manager @circuit_manager.setter def circuit_manager(self, val): with self.component_lock: self._circuit_manager = val @property def channel(self): return self._channel @channel.setter def channel(self, val): with self.component_lock: self._channel = val
[docs] def access_rights_changed(self, rights): self.access_rights = rights self.access_rights_callback.process(self, rights)
[docs] def connection_state_changed(self, state, channel): self.log.info('connection state changed to %s.', state) self.connection_state_callback.process(self, state) if state == 'disconnected': for sub in self.subscriptions.values(): with sub.callback_lock: if sub.callbacks: sub.needs_reactivation = True if state == 'connected': cm = self.circuit_manager ctx = cm.context with ctx.subscriptions_lock: for sub in self.subscriptions.values(): with sub.callback_lock: if sub.needs_reactivation: ctx.subscriptions_to_activate[cm].add(sub) sub.needs_reactivation = False
def __repr__(self): if self._idle: state = "(idle)" elif self.circuit_manager is None or self.circuit_manager.dead.is_set(): state = "(searching....)" else: state = (f"address={self.circuit_manager.circuit.address}, " f"circuit_state=" f"{self.circuit_manager.circuit.states[ca.CLIENT]}") if self.connected: state += f", channel_state={self.channel.states[ca.CLIENT]}" else: state += " (creating...)" return f"<PV name={self.name!r} priority={self.priority} {state}>" @property def connected(self): channel = self.channel if channel is None: return False return channel.states[ca.CLIENT] is ca.CONNECTED
[docs] @ensure_connected async def wait_for_connection(self, *, timeout=common.PV_DEFAULT_TIMEOUT): """ Wait for this PV to be connected. Parameters ---------- timeout : number or None, optional Seconds to wait before a CaprotoTimeoutError is raised. Default is ``PV.timeout``, which falls back to ``PV.context.timeout`` if not set. If None, never timeout. """
# NOTE: the logic happens in the wrapper.
[docs] async def go_idle(self): """Request to clear this Channel to reduce load on client and server. A new Channel will be automatically, silently created the next time any method requiring a connection is called. Thus, this saves some memory in exchange for making the next request a bit slower, as it has to redo the handshake with the server first. If there are any subscriptions with callbacks, this request will be ignored. If the PV is in the process of connecting, this request will be ignored. If there are any actions in progress (read, write) this request will be processed when they are complete. """ for sub in self.subscriptions.values(): if sub.callbacks: return async with self._in_use: if not self.channel_ready.is_set(): return # Wait until no other methods that employ @self.ensure_connected # are in process. await self._in_use.wait_for(lambda: self._usages == 0) # No other threads are using the connection, and we are holding the # self._in_use Condition's lock, so we can safely close the # connection. The next thread to acquire the lock will re-connect # after it acquires the lock. try: self.channel_ready.clear() await self.circuit_manager.send(self.channel.clear(), extra={'pv': self.name}) except OSError: # the socket is dead-dead, do nothing ... self._idle = True
[docs] @ensure_connected async def read(self, *, wait=True, callback=None, timeout=common.PV_DEFAULT_TIMEOUT, data_type=None, data_count=None, notify=True): """Request a fresh reading. Can do one or both of: - Block while waiting for the response, and return it. - Pass the response to callback, with or without blocking. Parameters ---------- wait : boolean If True (default) block until a matching response is received from the server. Raises CaprotoTimeoutError if that response is not received within the time specified by the `timeout` parameter. callback : callable or None Called with the response as its argument when received. timeout : number or None, optional Seconds to wait before a CaprotoTimeoutError is raised. Default is ``PV.timeout``, which falls back to ``PV.context.timeout`` if not set. If None, never timeout. data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional Request specific data type or a class of data types, matched to the channel's native data type. Default is Channel's native data type. data_count : integer, optional Requested number of values. Default is the channel's native data count. notify: boolean, optional Send a ``ReadNotifyRequest`` instead of a ``ReadRequest``. True by default. """ if timeout is common.PV_DEFAULT_TIMEOUT: timeout = self.timeout cm, chan = self._circuit_manager, self._channel ioid = cm._ioid_counter() command = chan.read(ioid=ioid, data_type=data_type, data_count=data_count, notify=notify) # Stash the ioid to match the response to the request. event = asyncio.Event() ioid_info = dict(event=event, pv=self, request=command) if callback is not None: ioid_info['callback'] = callback cm.ioids[ioid] = ioid_info deadline = time.monotonic() + timeout if timeout is not None else None ioid_info['deadline'] = deadline await cm.send(command, extra={'pv': self.name}) if not wait: return # The circuit_manager will put a reference to the response into # ioid_info and then set event. try: await asyncio.wait_for(event.wait(), timeout=timeout) except asyncio.TimeoutError: host, port = cm.circuit.address raise ca.CaprotoTimeoutError( f"Server at {host}:{port} did " f"not respond to attempt to read channel named " f"{self.name!r} within {float(timeout):.3}-second timeout. " f"The ioid of the expected response is {ioid}." ) if cm.dead.is_set(): # This circuit has died sometime during this function call. # The exception raised here will be caught by # @ensure_connected, which will retry the function call a # in hopes of getting a working circuit until our `timeout` has # been used up. raise common.DeadCircuitError() return ioid_info['response']
[docs] @ensure_connected async def write(self, data, *, wait=True, callback=None, timeout=common.PV_DEFAULT_TIMEOUT, notify=None, data_type=None, data_count=None): """ Write a new value. Optionally, request confirmation from the server. Can do one or both of: - Block while waiting for the response, and return it. - Pass the response to callback, with or without blocking. Parameters ---------- data : str, int, or float or any Iterable of these Value(s) to write. wait : boolean If True (default) block until a matching WriteNotifyResponse is received from the server. Raises CaprotoTimeoutError if that response is not received within the time specified by the `timeout` parameter. callback : callable or None Called with the WriteNotifyResponse as its argument when received. timeout : number or None, optional Seconds to wait before a CaprotoTimeoutError is raised. Default is ``PV.timeout``, which falls back to ``PV.context.timeout`` if not set. If None, never timeout. notify : boolean or None If None (default), set to True if wait=True or callback is set. Can be manually set to True or False. Will raise ValueError if set to False while wait=True or callback is set. data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional Write specific data type or a class of data types, matched to the channel's native data type. Default is Channel's native data type. data_count : integer, optional Requested number of values. Default is the channel's native data count. """ if timeout is common.PV_DEFAULT_TIMEOUT: timeout = self.timeout cm, chan = self._circuit_manager, self._channel if notify is None: notify = (wait or callback is not None) ioid = cm._ioid_counter() command = chan.write(data, ioid=ioid, notify=notify, data_type=data_type, data_count=data_count) if notify: event = asyncio.Event() ioid_info = dict(event=event, pv=self, request=command) if callback is not None: ioid_info['callback'] = callback cm.ioids[ioid] = ioid_info deadline = time.monotonic() + timeout if timeout is not None else None ioid_info['deadline'] = deadline # do not need to lock this, locking happens in circuit command else: if wait or callback is not None: raise ca.CaprotoValueError( "Must set notify=True in order to use `wait` or `callback`" " because, without a notification of 'put-completion' from" " the server, there is nothing to wait on or to trigger a" " callback.") await cm.send(command, extra={'pv': self.name}) if not wait: return # The circuit_manager will put a reference to the response into # ioid_info and then set event. try: await asyncio.wait_for(event.wait(), timeout=timeout) except asyncio.TimeoutError: if cm.dead.is_set(): # This circuit has died sometime during this function call. # The exception raised here will be caught by # @ensure_connected, which will retry the function call a # in hopes of getting a working circuit until our `timeout` has # been used up. raise common.DeadCircuitError() host, port = cm.circuit.address raise ca.CaprotoTimeoutError( f"Server at {host}:{port} did " f"not respond to attempt to write to channel named " f"{self.name!r} within {float(timeout):.3}-second timeout. " f"The ioid of the expected response is {ioid}." ) return ioid_info['response']
[docs] def subscribe(self, data_type=None, data_count=None, low=0.0, high=0.0, to=0.0, mask=None): """ Start a new subscription to which user callback may be added. Parameters ---------- data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional Request specific data type or a class of data types, matched to the channel's native data type. Default is Channel's native data type. data_count : integer, optional Requested number of values. Default is the channel's native data count. low, high, to : float, optional deprecated by Channel Access, not yet implemented by caproto mask : SubscriptionType, optional Subscribe to selective updates. Returns ------- subscription : Subscription Examples -------- Define a subscription. >>> sub = pv.subscribe() Add a user callback. The subscription will be transparently activated (i.e. an ``EventAddRequest`` will be sent) when the first user callback is added. >>> sub.add_callback(my_func) Multiple callbacks may be added to the same subscription. >>> sub.add_callback(another_func) See the docstring for :class:`Subscription` for more. """ # A Subscription is uniquely identified by the Signature created by its # args and kwargs. bound = common.SUBSCRIBE_SIG.bind(data_type, data_count, low, high, to, mask) key = tuple(bound.arguments.items()) try: sub = self.subscriptions[key] except KeyError: sub = Subscription(self, data_type, data_count, low, high, to, mask) self.subscriptions[key] = sub # The actual EPICS messages will not be sent until the user adds # callbacks via sub.add_callback(user_func). return sub
[docs] async def unsubscribe_all(self): "Clear all subscriptions. (Remove all user callbacks from them.)" for sub in self.subscriptions.values(): await sub.clear()
# @ensure_connected
[docs] def time_since_last_heard(self, timeout=common.PV_DEFAULT_TIMEOUT): """ Seconds since last message from the server that provides this channel. The time is reset to 0 whenever we receive a TCP message related to user activity *or* a Beacon. Servers are expected to send Beacons at regular intervals. If we do not receive either a Beacon or TCP message, we initiate an Echo over TCP, to which the server is expected to promptly respond. Therefore, the time reported here should not much exceed ``EPICS_CA_CONN_TMO`` (default 30 seconds unless overriden by that environment variable) if the server is healthy. If the server fails to send a Beacon on schedule *and* fails to reply to an Echo, the server is assumed dead. A warning is issued, and all PVs are disconnected to initiate a reconnection attempt. Parameters ---------- timeout : number or None, optional Seconds to wait before a CaprotoTimeoutError is raised. Default is ``PV.timeout``, which falls back to ``PV.context.timeout`` if not set. If None, never timeout. """ address = self.circuit_manager.circuit.address return self.context.broadcaster.time_since_last_heard()[address]
class CallbackHandler: def __init__(self, pv): # NOTE: not a WeakValueDictionary or WeakSet as PV is unhashable... self.callbacks = {} self.pv = pv self._callback_id = 0 self.callback_lock = threading.RLock() self._last_call_values = None def add_callback(self, func, run=False): def removed(_): self.remove_callback(cb_id) # defined below inside the lock if inspect.ismethod(func): ref = weakref.WeakMethod(func, removed) else: # TODO: strong reference to non-instance methods? ref = weakref.ref(func, removed) with self.callback_lock: cb_id = self._callback_id self._callback_id += 1 self.callbacks[cb_id] = ref if run and self._last_call_values is not None: with self.callback_lock: args, kwargs = self._last_call_values self.process(*args, **kwargs) return cb_id def remove_callback(self, token): # TODO: async confusion: # sync CallbackHandler.remove_callback # async Subscription.remove_callback with self.callback_lock: self.callbacks.pop(token, None) def process(self, *args, **kwargs): """ This is a fast operation that submits jobs to the Context's ThreadPoolExecutor and then returns. """ to_remove = [] with self.callback_lock: callbacks = list(self.callbacks.items()) self._last_call_values = (args, kwargs) for cb_id, ref in callbacks: callback = ref() if callback is None: to_remove.append(cb_id) continue self.pv.circuit_manager.user_callback_executor.submit( callback, *args, **kwargs) with self.callback_lock: for remove_id in to_remove: self.callbacks.pop(remove_id, None)
[docs]class Subscription(CallbackHandler): """ Represents one subscription, specified by a PV and configurational parameters It may fan out to zero, one, or multiple user-registered callback functions. This object should never be instantiated directly by user code; rather it should be made by calling the ``subscribe()`` method on a ``PV`` object. """ def __init__(self, pv, data_type, data_count, low, high, to, mask): super().__init__(pv) # Stash everything, but do not send any EPICS messages until the first # user callback is attached. self.data_type = data_type self.data_count = data_count self.low = low self.high = high self.to = to self.mask = mask self.subscriptionid = None self.most_recent_response = None self.needs_reactivation = False # This is related to back-compat for user callbacks that have the old # signature, f(response). self.__wrapper_weakrefs = set() @property def log(self): return self.pv.log def __repr__(self): return f"<Subscription to {self.pv.name!r}, id={self.subscriptionid}>" async def __aiter__(self): queue = AsyncioQueue() async def iter_callback(sub, value): await queue.async_put(value) sid = self.add_callback(iter_callback) try: while True: item = await queue.async_get() yield item finally: await self.remove_callback(sid) async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_value, traceback): await self.clear() def _subscribe(self, timeout=common.PV_DEFAULT_TIMEOUT): """This is called automatically after the first callback is added. """ cm = self.pv.circuit_manager if cm is None: # We are currently disconnected (perhaps have not yet connected). # When the PV connects, this subscription will be added. with self.callback_lock: self.needs_reactivation = True else: # We are (or very recently were) connected. In the rare event # where cm goes dead in the interim, subscription will be retried # by the activation loop. ctx = cm.context with ctx.subscriptions_lock: ctx.subscriptions_to_activate[cm].add(self) ctx.activate_subscriptions_now.set() @ensure_connected async def compose_command(self, timeout=common.PV_DEFAULT_TIMEOUT): "This is used by the Context to re-subscribe in bulk after dropping." # TODO: compose_command async due to ensure_connected? with self.callback_lock: if not self.callbacks: return None cm, chan = self.pv._circuit_manager, self.pv._channel subscriptionid = cm._subscriptionid_counter() command = chan.subscribe(data_type=self.data_type, data_count=self.data_count, low=self.low, high=self.high, to=self.to, mask=self.mask, subscriptionid=subscriptionid) subscriptionid = command.subscriptionid self.subscriptionid = subscriptionid # The circuit_manager needs to know the subscriptionid so that it can # route responses to this request. cm.subscriptions[subscriptionid] = self return command
[docs] async def clear(self): """ Remove all callbacks. """ with self.callback_lock: for cb_id in list(self.callbacks): await self.remove_callback(cb_id)
# Once self.callbacks is empty, self.remove_callback calls # self._unsubscribe for us. async def _unsubscribe(self): """ This is automatically called if the number of callbacks goes to 0. """ with self.callback_lock: if self.subscriptionid is None: # Already unsubscribed. return subscriptionid = self.subscriptionid self.subscriptionid = None self.most_recent_response = None self.pv.circuit_manager.subscriptions.pop(subscriptionid, None) chan = self.pv.channel if chan and chan.states[ca.CLIENT] is ca.CONNECTED: try: command = self.pv.channel.unsubscribe(subscriptionid) except ca.CaprotoKeyError: pass except ca.CaprotoValueError: self.log.exception('TODO') else: await self.pv.circuit_manager.send(command, extra={'pv': self.pv.name}) def process(self, command): # TODO here i think we can decouple PV update rates and callback # handling rates, if desirable, to not bog down performance. # As implemented below, updates are blocking further messages from # the CA servers from processing. (-> ThreadPool, etc.) pv = self.pv super().process(self, command) self.log.debug("%r: %r", pv.name, command) self.most_recent_response = command
[docs] def add_callback(self, func): """ Add a callback to receive responses. Parameters ---------- func : callable Expected signature: ``func(sub, response)``. The signature ``func(response)`` is also supported for backward-compatibility but will issue warnings. Support will be removed in a future release of caproto. Returns ------- token : int Integer token that can be passed to :meth:`remove_callback`. .. versionchanged:: 0.5.0 Changed the expected signature of ``func`` from ``func(response)`` to ``func(sub, response)``. """ # Handle func with signature func(response) for back-compat. with self.callback_lock: was_empty = not self.callbacks cb_id = super().add_callback(func) most_recent_response = self.most_recent_response if was_empty: # This is the first callback. Set up a subscription, which # should elicit a response from the server soon giving the # current value to this func (and any other funcs added in the # mean time). self._subscribe() else: # This callback is piggy-backing onto an existing subscription. # Send it the most recent response, unless we are still waiting # for that first response from the server. if most_recent_response is not None: circuit_manager = self.pv.circuit_manager if circuit_manager is not None: circuit_manager.user_callback_executor.submit( func, self, most_recent_response) return cb_id
[docs] async def remove_callback(self, token): """ Remove callback using token that was returned by :meth:`add_callback`. Parameters ---------- token : integer Token returned by :meth:`add_callback`. """ with self.callback_lock: super().remove_callback(token) if not self.callbacks: # Go dormant. await self._unsubscribe() self.most_recent_response = None self.needs_reactivation = False