# Source code for caproto.threading.client

# Regarding threads...
# The SharedBroadcaster has:
# - UDP socket SelectorThread
# - UDP command processing
# - forever retrying search requests for disconnected PV
# The Context has:
# - process search results
# - TCP socket SelectorThread
# - restart subscriptions
# The VirtualCircuit has:
# - ThreadPoolExecutor for processing user callbacks on read, write, subscribe
import array
import concurrent.futures
import errno
import functools
import getpass
import inspect
import logging
import random
import selectors
import socket
import threading
import time
import weakref
from collections import defaultdict, deque
from inspect import Parameter, Signature
from queue import Empty, Queue

import caproto as ca

from .._constants import (MAX_ID, RESPONSIVENESS_TIMEOUT,
                          SEARCH_MAX_DATAGRAM_BYTES, STALE_SEARCH_EXPIRATION)
from .._utils import (CaprotoError, CaprotoKeyError, CaprotoNetworkError,
                      CaprotoRuntimeError, CaprotoTimeoutError,
                      CaprotoTypeError, CaprotoValueError, ThreadsafeCounter,
                      adapt_old_callback_signature, batch_requests,
                      safe_getsockname, socket_bytes_available)
from ..client import common

# Module-level loggers for the 'caproto.ch' and 'caproto.bcast.search'
# logging namespaces.
ch_logger, search_logger = (
    logging.getLogger(_logger_name)
    for _logger_name in ('caproto.ch', 'caproto.bcast.search')
)


class DeadCircuitError(CaprotoError):
    """Raised when an operation is attempted on a circuit that has died."""


def ensure_connected(func):
    """
    Decorate a PV or Subscription method so it runs on a connected channel.

    Before calling ``func`` the wrapper:

    * if the PV is idle, waits for ``pv.circuit_ready`` and re-creates the
      Channel on the PV's circuit manager;
    * waits for ``pv.channel_ready``;
    * calls ``func``, retrying up to ``common.CIRCUIT_DEATH_ATTEMPTS`` times
      if the circuit dies underneath it (``DeadCircuitError``, or a
      ``TimeoutError`` while the circuit manager's ``dead`` flag is set).

    The ``timeout`` keyword argument (default ``pv.timeout``; ``None`` means
    wait forever) bounds the total time across all attempts. When it is not
    ``None``, the remaining budget is passed on to ``func`` as ``timeout``.

    NOTE(review): if every attempt ends in a dead circuit, the loop runs out
    and the wrapper returns ``None`` without raising.

    Raises
    ------
    CaprotoTypeError
        If used on a method of anything other than PV or Subscription.
    CaprotoTimeoutError
        If the circuit or the channel is not ready within the timeout.
    """
    @functools.wraps(func)
    def inner(self, *args, **kwargs):
        if isinstance(self, PV):
            pv = self
        elif isinstance(self, Subscription):
            pv = self.pv
        else:
            raise CaprotoTypeError("ensure_connected is intended to decorate "
                                   "methods of PV and Subscription.")
        # timeout may be decremented during disconnection-retry loops below.
        # Keep a copy of the original 'raw_timeout' for use in error messages.
        # (Only a *keyword* timeout is honored here; a positional one is not
        # visible to this wrapper.)
        raw_timeout = timeout = kwargs.get('timeout', pv.timeout)
        if timeout is not None:
            deadline = time.monotonic() + timeout
        with pv._in_use:
            # If needed, reconnect. Do this inside the lock so that we don't
            # try to do this twice. (No other threads that need this lock
            # can proceed until the connection is ready anyway!)
            if pv._idle:
                # The Context should have been maintaining a working circuit
                # for us while this was idle. We just need to re-create the
                # Channel.
                ready = pv.circuit_ready.wait(timeout=timeout)
                if not ready:
                    raise CaprotoTimeoutError(
                        f"{pv} could not connect within "
                        f"{float(raw_timeout):.3}-second timeout.")
                with pv.component_lock:
                    cm = pv.circuit_manager
                    cid = cm.circuit.new_channel_id()
                    chan = ca.ClientChannel(pv.name, cm.circuit, cid=cid)
                    cm.channels[cid] = chan
                    cm.pvs[cid] = pv
                    cm.send(chan.create(), extra={'pv': pv.name})
                    # Clear the flag on the PV itself. When decorating a
                    # Subscription method, ``self`` is the Subscription;
                    # setting ``self._idle`` would leave the PV marked idle
                    # and the channel would be re-created on every call.
                    pv._idle = False
            # increment the usage at the very end in case anything
            # goes wrong in the block of code above this.
            pv._usages += 1

        try:
            for _ in range(common.CIRCUIT_DEATH_ATTEMPTS):
                # On each iteration, subtract the time we already spent on any
                # previous attempts.
                if timeout is not None:
                    timeout = deadline - time.monotonic()
                ready = pv.channel_ready.wait(timeout=timeout)
                if not ready:
                    raise CaprotoTimeoutError(
                        f"{pv} could not connect within "
                        f"{float(raw_timeout):.3}-second timeout.")
                if timeout is not None:
                    timeout = deadline - time.monotonic()
                    kwargs['timeout'] = timeout

                cm = pv.circuit_manager
                try:
                    return func(self, *args, **kwargs)
                except DeadCircuitError:
                    # Something in func tried operate on the circuit after
                    # it died. The context will automatically build us a
                    # new circuit. Try again.
                    self.log.debug('Caught DeadCircuitError. '
                                   'Retrying %s.', func.__name__)
                    continue
                except TimeoutError:
                    # The circuit may have died after func was done calling
                    # methods on it but before we received some response we
                    # were expecting. The context will automatically build
                    # us a new circuit. Try again.
                    if cm.dead.is_set():
                        self.log.debug('Caught TimeoutError due to dead '
                                       'circuit. '
                                       'Retrying %s.', func.__name__)
                        continue
                    # The circuit is fine -- this is a real error.
                    raise

        finally:
            # Release our usage and wake anyone waiting for the PV to be
            # unused.
            with pv._in_use:
                pv._usages -= 1
                pv._in_use.notify_all()
    return inner


class ThreadingClientException(CaprotoError):
    """Base class for exceptions specific to the threading client."""


class DisconnectedError(ThreadingClientException):
    """Threading-client exception for disconnection errors."""


class ContextDisconnectedError(ThreadingClientException):
    """Threading-client exception for a disconnected Context."""


class SelectorThread:
    """
    Poll registered sockets on a background thread and dispatch their data.

    This is used internally by the Context and the VirtualCircuitManager.

    Each socket added with :meth:`add_socket` is paired with a target object.
    The target must provide a ``received(bytes_recv, address)`` method and a
    ``log`` attribute. Targets are held by weak reference only. When a target
    is garbage-collected, its socket is removed automatically.

    Parameters
    ----------
    parent : object, optional
        If given, the selector is stopped when ``parent`` is garbage-collected.
    """
    def __init__(self, *, parent=None):
        self.thread = None  # set by the `start` method
        self._close_event = threading.Event()
        self.selector = selectors.DefaultSelector()

        # Set to wake the poll loop while it idles with no sockets.
        self._register_event = threading.Event()
        self._socket_map_lock = threading.RLock()
        self.objects = weakref.WeakValueDictionary()  # {object_id: target}
        self.socket_to_id = {}  # {socket: object_id}

        # Pending changes, applied to the selector by the poll-loop thread.
        self._register_sockets = {}  # {socket: object_id}
        self._unregister_sockets = set()
        self._object_id = 0
        self._socket_count = 0

        if parent is not None:
            # Stop the selector if the parent goes out of scope
            self._parent = weakref.ref(parent, lambda obj: self.stop())

    @property
    def running(self):
        '''Selector thread is running'''
        return not self._close_event.is_set()

    def stop(self):
        '''Ask the poll loop (if running) to exit.'''
        self._close_event.set()

        # In case we're waiting for the first socket to be added:
        self._register_event.set()

    def start(self):
        '''
        Start the poll loop on a daemon thread named "selector".

        Raises
        ------
        CaprotoRuntimeError
            If the selector has already been stopped.
        '''
        if self._close_event.is_set():
            raise CaprotoRuntimeError("Cannot be restarted once stopped.")
        self.thread = threading.Thread(target=self, daemon=True,
                                       name='selector')
        self.thread.start()

    def add_socket(self, sock, target_obj):
        '''
        Watch ``sock`` and deliver data arriving on it to ``target_obj``.

        The socket is switched to non-blocking mode. The poll-loop thread
        registers it with the underlying selector asynchronously.

        Raises
        ------
        CaprotoValueError
            If ``sock`` has already been added.
        '''
        assert isinstance(sock, socket.socket)
        with self._socket_map_lock:
            if sock in self.socket_to_id:
                raise CaprotoValueError('Socket already added')

            sock.setblocking(False)

            # assumption: only one sock per object
            self._object_id += 1
            self.objects[self._object_id] = target_obj
            self.socket_to_id[sock] = self._object_id
            self._register_sockets[sock] = self._object_id
            # Drop the socket automatically once the target is collected.
            weakref.finalize(target_obj,
                             lambda sock=sock: self.remove_socket(sock))

        # Wake the poll loop if it is idling with no sockets. The new socket
        # is then registered promptly instead of after the idle timeout.
        self._register_event.set()

    def remove_socket(self, sock):
        '''
        Stop watching ``sock``. This is a no-op if it is not currently added.

        If the target object is still alive, it is notified with
        ``received(b'', None)``. The socket itself is not closed.
        '''
        with self._socket_map_lock:
            if sock not in self.socket_to_id:
                return
            obj_id = self.socket_to_id.pop(sock)
            obj = self.objects.pop(obj_id, None)
            if obj is not None:
                obj.received(b'', None)

            if sock in self._register_sockets:
                # Removed before the poll loop ever registered it.
                del self._register_sockets[sock]
            else:
                # Already registered: let the poll loop unregister it.
                self._unregister_sockets.add(sock)

    def _sync_registrations(self):
        '''Apply pending (un)registrations to the selector (lock held).'''
        for sock in self._unregister_sockets:
            self.selector.unregister(sock)
        self._socket_count -= len(self._unregister_sockets)
        self._unregister_sockets.clear()

        for sock, obj_id in self._register_sockets.items():
            self.selector.register(sock, selectors.EVENT_READ, data=obj_id)
        self._socket_count += len(self._register_sockets)
        self._register_sockets.clear()

    def _handle_readable(self, obj, sock, avail_buf):
        '''Read whatever is available on ``sock`` and hand it to ``obj``.'''
        # TODO: consider thread pool for recv and command_loop
        try:
            bytes_available = socket_bytes_available(
                sock, available_buffer=avail_buf)
            bytes_recv, address = sock.recvfrom(bytes_available)
        except ConnectionResetError as ex:
            if sock.type == socket.SOCK_DGRAM:
                # Win32: "On a UDP-datagram socket this error indicates
                # a previous send operation resulted in an ICMP Port
                # Unreachable message."
                #
                # https://docs.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recvfrom
                obj.log.debug(
                    "UDP socket indicates previous send failed %s: %s",
                    obj,
                    ex
                )
                return

            obj.log.error("Removing %s due to %s (%s)", obj, ex, ex.errno)
            self.remove_socket(sock)
            # Nothing was received; do not fall through to obj.received().
            return
        except OSError as ex:
            if ex.errno not in (errno.EAGAIN, errno.EWOULDBLOCK):
                # register as a disconnection
                obj.log.error('Removing %s due to %s (%s)', obj, ex,
                              ex.errno)
                self.remove_socket(sock)
            return

        try:
            # Let objects handle disconnection by return value
            if obj.received(bytes_recv, address) is ca.DISCONNECTED:
                obj.log.debug('Removing %s = %s after DISCONNECTED '
                              'return value', sock, obj)
                self.remove_socket(sock)
                # TODO: consider adding specific DISCONNECTED instead
                # of b'' sent to disconnected sockets
        except Exception as ex:
            if sock.type == socket.SOCK_DGRAM:
                obj.log.exception(
                    'UDP socket for %s failed on receipt of '
                    'new data: %s', obj, ex)
            else:
                obj.log.exception(
                    'Removing %s due to an internal error on receipt of '
                    'new data: %s', obj, ex)
                self.remove_socket(sock)

    def __call__(self):
        '''Selector poll loop'''
        avail_buf = array.array('i', [0])
        try:
            while self.running:
                with self._socket_map_lock:
                    self._sync_registrations()

                if self._socket_count == 0:
                    if self._register_event.wait(timeout=0.1):
                        self._register_event.clear()
                    continue

                events = self.selector.select(timeout=0.1)
                with self._socket_map_lock:
                    if self._unregister_sockets:
                        # some sockets may be affected here; try again
                        continue

                    # Targets are weakly held and may already be gone; skip
                    # those rather than letting a KeyError kill this thread.
                    ready = []
                    for key, _mask in events:
                        obj = self.objects.get(key.data)
                        if obj is not None:
                            ready.append((obj, key.fileobj))

                for obj, sock in ready:
                    if sock in self._unregister_sockets:
                        continue
                    self._handle_readable(obj, sock, avail_buf)
        finally:
            # Release the OS-level selector handle (e.g. epoll fd).
            self.selector.close()


[docs]class SharedBroadcaster: def __init__(self, *, registration_retry_time=10.0): ''' A broadcaster client which can be shared among multiple Contexts Parameters ---------- registration_retry_time : float, optional The time, in seconds, between attempts made to register with the repeater. Default is 10. ''' self.environ = ca.get_environment_variables() self.ca_server_port = self.environ['EPICS_CA_SERVER_PORT'] self.udp_sock = ca.bcast_socket() # Must bind or getsocketname() will raise on Windows. # See https://github.com/caproto/caproto/issues/514. self.udp_sock.bind(('', 0)) self._search_lock = threading.RLock() self._retry_unanswered_searches_thread = None # This Event ensures that we send a registration request before our # first search request. self._searching_enabled = threading.Event() # This Event lets us nudge the search thread when the user asks for new # PVs (via Context.get_pvs). self._search_now = threading.Event() self.search_results = {} # map name to (time, address) # map search id (cid) to [name, queue, last_search_time, retirement_deadline] self.unanswered_searches = {} self.server_protocol_versions = {} # map address to protocol version self._id_counter = ThreadsafeCounter( initial_value=random.randint(0, MAX_ID), dont_clash_with=self.unanswered_searches, ) self.listeners = weakref.WeakSet() self.broadcaster = ca.Broadcaster(our_role=ca.CLIENT) self.broadcaster.client_address = safe_getsockname(self.udp_sock) self.log = logging.LoggerAdapter( self.broadcaster.log, {'role': 'CLIENT'}) self.search_log = logging.LoggerAdapter( logging.getLogger('caproto.bcast.search'), {'role': 'CLIENT'}) self.command_bundle_queue = Queue() self.last_beacon = {} self.last_beacon_interval = {} # an event to tear down and clean up the broadcaster self._close_event = threading.Event() self.selector = SelectorThread(parent=self) self.selector.add_socket(self.udp_sock, self) self.selector.start() self._command_thread = threading.Thread(target=self.command_loop, 
daemon=True, name='command') self._command_thread.start() self._check_for_unresponsive_servers_thread = threading.Thread( target=self._check_for_unresponsive_servers, daemon=True, name='check_for_unresponsive_servers') self._check_for_unresponsive_servers_thread.start() self._registration_retry_time = registration_retry_time self._registration_last_sent = 0 try: # Always attempt registration on initialization, but allow failures self._register() except Exception: self.log.exception('Broadcaster registration failed on init') def _should_attempt_registration(self): 'Whether or not a registration attempt should be tried' if self.udp_sock is None: # This broadcaster does not currently support being revived from # a disconnected state. Do not attempt registration if the # __init__-defined socket has been removed. return False if (self.broadcaster.registered or self._registration_retry_time is None): return False since_last_attempt = time.monotonic() - self._registration_last_sent if since_last_attempt < self._registration_retry_time: return False return True def _register(self): 'Send a registration request to the repeater' self._registration_last_sent = time.monotonic() commands = [self.broadcaster.register()] bytes_to_send = self.broadcaster.send(*commands) addr = (ca.get_local_address(), self.environ['EPICS_CA_REPEATER_PORT']) tags = { 'role': 'CLIENT', 'our_address': self.broadcaster.client_address, 'direction': '--->>>', 'their_address': addr, } self.broadcaster.log.debug( '%d commands %dB', len(commands), len(bytes_to_send), extra=tags) try: self.udp_sock.sendto(bytes_to_send, addr) except (OSError, AttributeError) as ex: host, specified_port = addr self.log.exception('%s while sending %d bytes to %s:%d', ex, len(bytes_to_send), host, specified_port) self._searching_enabled.set()
[docs] def new_id(self): return self._id_counter()
[docs] def add_listener(self, listener): with self._search_lock: if self._retry_unanswered_searches_thread is None: self._retry_unanswered_searches_thread = threading.Thread( target=self._retry_unanswered_searches, daemon=True, name='retry') self._retry_unanswered_searches_thread.start() self.listeners.add(listener)
[docs] def remove_listener(self, listener): try: self.listeners.remove(listener) except KeyError: pass if not self.listeners: self.disconnect()
[docs] def disconnect(self, *, wait=True): if self.udp_sock is not None: self.selector.remove_socket(self.udp_sock) self.udp_sock.close() self.udp_sock = None self._close_event.set() with self._search_lock: self.search_results.clear() self._registration_last_sent = 0 self._searching_enabled.clear() self.broadcaster.disconnect() self.selector.stop() if wait: self._command_thread.join() self.selector.thread.join() self._retry_unanswered_searches_thread.join()
[docs] def send(self, *commands): """ Process a command and transport it over the UDP socket. """ bytes_to_send = self.broadcaster.send(*commands) tags = {'role': 'CLIENT', 'our_address': self.broadcaster.client_address, 'direction': '--->>>'} for host_tuple in ca.get_client_address_list(): tags['their_address'] = host_tuple self.broadcaster.log.debug( '%d commands %dB', len(commands), len(bytes_to_send), extra=tags) sock = self.udp_sock if sock is None: return try: sock.sendto(bytes_to_send, host_tuple) except OSError as ex: host, specified_port = host_tuple raise CaprotoNetworkError( f'{ex} while sending {len(bytes_to_send)} bytes to ' f'{host}:{specified_port}') from ex
[docs] def get_cached_search_result(self, name, *, threshold=STALE_SEARCH_EXPIRATION): 'Returns address if found, raises KeyError if missing or stale.' with self._search_lock: address, timestamp = self.search_results[name] # this block of code is only to re-fresh the time found on # any PVs. If we can find any context which has any circuit which # has any channel talking to this PV name then it is not stale so # re-up the timestamp to now. if time.monotonic() - timestamp > threshold: # TODO this is very inefficient for context in self.listeners: for cm in context.circuit_managers.values(): if cm.connected and name in cm.all_created_pvnames: # A valid connection exists in one of our clients, so # ignore the stale result status with self._search_lock: self.search_results[name] = (address, time.monotonic()) # TODO verify that addr matches address return address with self._search_lock: # Clean up expired result. self.search_results.pop(name, None) raise CaprotoKeyError(f'{name!r}: stale search result') return address
[docs] def search(self, results_queue, names, *, timeout=2): """ Search for PV names. The ``results_queue`` will receive ``(address, names)`` (the address of a server and a list of name(s) that it has) when results are received. If a cached result is already known, it will be put immediately into ``results_queue`` from this thread during this method's execution. If not, a SearchRequest will be sent from another thread. If necessary, the request will be re-sent periodically. When a matching response is received (by yet another thread) ``(address, names)`` will be put into the ``results_queue``. """ if self._should_attempt_registration(): self._register() new_id = self.new_id unanswered_searches = self.unanswered_searches with self._search_lock: # We have have already searched for these names recently. # Filter `pv_names` down to a subset, `needs_search`. needs_search = [] use_cached_search = defaultdict(list) for name in names: try: address = self.get_cached_search_result(name) except KeyError: needs_search.append(name) else: use_cached_search[address].append(name) for address, names in use_cached_search.items(): results_queue.put((address, names)) # Generate search_ids and stash them on Context state so they can # be used to match SearchResponses with SearchRequests. search_ids = [] # Search requests that are past their retirement deadline with no # results will be searched for less frequently. retirement_deadline = time.monotonic() + common.SEARCH_RETIREMENT_AGE for name in needs_search: search_id = new_id() search_ids.append(search_id) # The value is a list because we mutate it to update the # retirement deadline sometimes. unanswered_searches[search_id] = [name, results_queue, 0, retirement_deadline] self._search_now.set()
[docs] def cancel(self, *names): """ Cancel searches for these names. Parameters ---------- *names : strings any number of PV names Any PV instances that were awaiting these results will be stuck until :meth:`get_pvs` is called again. """ with self._search_lock: for search_id, item in list(self.unanswered_searches.items()): if item[0] in names: del self.unanswered_searches[search_id]
[docs] def search_now(self): """ Force the Broadcaster to reissue all unanswered search requests now. Left to its own devices, the Broadcaster will do this at regular intervals automatically. This method is intended primarily for debugging and should not be needed in normal use. """ self._search_now.set()
[docs] def received(self, bytes_recv, address): "Receive and process and next command broadcasted over UDP." if bytes_recv: commands = self.broadcaster.recv(bytes_recv, address) if commands: self.command_bundle_queue.put(commands) return 0
[docs] def command_loop(self): # Receive commands in 'bundles' (corresponding to the contents of one # UDP datagram). Match SearchResponses to their SearchRequests, and # put (address, (name1, name2, name3, ...)) into a queue. The receiving # end of that queue is held by Context._process_search_results. # Save doing a 'self' lookup in the inner loop. search_results = self.search_results server_protocol_versions = self.server_protocol_versions unanswered_searches = self.unanswered_searches queues = defaultdict(list) results_by_cid = deque(maxlen=1000) self.log.debug('Broadcaster command loop is running.') while not self._close_event.is_set(): try: commands = self.command_bundle_queue.get(timeout=0.5) except Empty: # By restarting the loop, we will first check that we are not # supposed to shut down the thread before we go back to # waiting on the queue again. continue try: self.broadcaster.process_commands(commands) except ca.CaprotoError as ex: self.log.warning('Broadcaster command error', exc_info=ex) continue queues.clear() now = time.monotonic() tags = {'role': 'CLIENT', 'our_address': self.broadcaster.client_address, 'direction': '<<<---'} for command in commands: if isinstance(command, ca.Beacon): now = time.monotonic() address = (command.address, command.server_port) tags['their_address'] = address if address not in self.last_beacon: # We made a new friend! self.broadcaster.log.info("Watching Beacons from %s:%d", *address, extra=tags) self._new_server_found() else: interval = now - self.last_beacon[address] if interval < self.last_beacon_interval.get(address, 0) / 4: # Beacons are arriving *faster*? The server at this # address may have restarted. 
self.broadcaster.log.info( "Beacon anomaly: %s:%d may have restarted.", *address, extra=tags) self._new_server_found() self.last_beacon_interval[address] = interval self.last_beacon[address] = now elif isinstance(command, ca.SearchResponse): cid = command.cid try: with self._search_lock: name, queue, *_ = unanswered_searches.pop(cid) except KeyError: # This is a redundant response, which the EPICS # spec tells us to ignore. (The first responder # to a given request wins.) try: _, name = next(r for r in results_by_cid if r[0] == cid) except StopIteration: continue else: with self._search_lock: if name in search_results: accepted_address, _ = search_results[name] new_address = ca.extract_address(command) if new_address != accepted_address: search_logger.warning( "PV %s with cid %d found on multiple " "servers. Accepted address is %s:%d. " "Also found on %s:%d", name, cid, *accepted_address, *new_address, extra={'pv': name, 'their_address': accepted_address, 'our_address': self.broadcaster.client_address}) else: results_by_cid.append((cid, name)) address = ca.extract_address(command) queues[queue].append(name) # Cache this to save time on future searches. # (Entries expire after STALE_SEARCH_EXPIRATION.) with self._search_lock: search_results[name] = (address, now) server_protocol_versions[address] = command.version # Send the search results to the Contexts that asked for # them. This is probably more general than is has to be but # I'm playing it safe for now. if queues: for queue, names in queues.items(): queue.put((address, names)) self.log.debug('Broadcaster command loop has exited.')
def _new_server_found(self): # Bring all the unanswered seraches out of retirement # to see if we have a new match. retirement_deadline = time.monotonic() + common.SEARCH_RETIREMENT_AGE with self._search_lock: for item in self.unanswered_searches.values(): # give new age-out deadline item[-1] = retirement_deadline
[docs] def time_since_last_heard(self): """ Map each known server address to seconds since its last message. The time is reset to 0 whenever we receive a TCP message related to user activity *or* a Beacon. Servers are expected to send Beacons at regular intervals. If we do not receive either a Beacon or TCP message, we initiate an Echo over TCP, to which the server is expected to promptly respond. Therefore, the time reported here should not much exceed ``EPICS_CA_CONN_TMO`` (default 30 seconds unless overriden by that environment variable) if the server is healthy. If the server fails to send a Beacon on schedule *and* fails to reply to an Echo, the server is assumed dead. A warning is issued, and all PVs are disconnected to initiate a reconnection attempt. """ return {address: time.monotonic() - t for address, t in self._last_heard.items()}
def _check_for_unresponsive_servers(self):
    """
    Detect servers that have gone silent and disconnect their circuits.

    Runs until ``self._close_event`` is set. About every half second it
    groups the circuit managers of every Context in ``self.listeners`` by
    server address. For each server it computes when we last heard from it:
    the later of its last Beacon and the last TCP receipt on any of its
    circuits. A server that has been silent for longer than
    ``EPICS_CA_CONN_TMO`` (plus a margin) is sent an EchoRequest. If it is
    still silent ``RESPONSIVENESS_TIMEOUT`` seconds later, all of its
    connected circuits are disconnected so their PVs can reconnect.
    """
    self.log.debug('Broadcaster check for unresponsive servers loop is running.')

    MARGIN = 1  # extra time (seconds) allowed between Beacons

    checking = {}  # map address to deadline for check to resolve
    servers = defaultdict(weakref.WeakSet)  # map address to VirtualCircuitManagers
    last_heard = {}  # map address to time of last response
    self._last_heard = last_heard

    # Make locals to save getattr lookups in the loop.
    last_beacon = self.last_beacon
    listeners = self.listeners

    while not self._close_event.is_set():
        servers.clear()
        last_heard.clear()
        now = time.monotonic()

        # We are interested in identifying servers that we have not heard
        # from since some time cutoff in the past.
        cutoff = now - (self.environ['EPICS_CA_CONN_TMO'] + MARGIN)

        # Map each server address to VirtualCircuitManagers connected to
        # that address, across all Contexts ("listeners"). Other threads add
        # and remove listeners and circuit managers concurrently, so iterate
        # over snapshots.
        for listener in list(listeners):
            for (address, _), circuit_manager in list(listener.circuit_managers.items()):
                servers[address].add(circuit_manager)

        # When is the last time we heard from each server, either via a
        # Beacon or from TCP packets related to user activity or any
        # circuit?
        for address, circuit_managers in servers.items():
            # A circuit that has not received any bytes yet still has
            # last_tcp_receipt None; it cannot be compared with a float.
            tcp_receipts = [
                receipt
                for receipt in (cm.last_tcp_receipt for cm in circuit_managers)
                if receipt is not None
            ]
            last_heard[address] = max((last_beacon.get(address, 0), *tcp_receipts))

            # If it has been too long --- and if we aren't already checking
            # on this server --- try to prompt a response over TCP by
            # sending an EchoRequest.
            if last_heard[address] < cutoff and address not in checking:
                # Record that we are checking on this address and set a
                # deadline for a response.
                checking[address] = now + RESPONSIVENESS_TIMEOUT
                tags = {'role': 'CLIENT',
                        'their_address': address,
                        'our_address': self.broadcaster.client_address,
                        'direction': '--->>>'}

                self.broadcaster.log.debug(
                    "Missed Beacons from %s:%d. Sending EchoRequest to "
                    "check that server is responsive.", *address, extra=tags)
                # Send on all circuits. One might be less backlogged
                # with queued commands than the others and thus able to
                # respond faster. In the majority of cases there will only
                # be one circuit per server anyway, so this is a minor
                # distinction.
                for circuit_manager in circuit_managers:
                    try:
                        circuit_manager.send(ca.EchoRequest())
                    except Exception:
                        # Send failed. Server is likely dead, but we'll
                        # catch that shortly; no need to handle it
                        # specially here.
                        pass

        # Check to see if any of our ongoing checks have resolved or
        # failed to resolve within the allowed response window.
        for address, deadline in list(checking.items()):
            heard = last_heard.get(address)
            if heard is None:
                # No Context has a circuit to this server any more (its
                # circuits were disconnected and removed), so there is
                # nothing left to check on.
                checking.pop(address)
            elif heard > cutoff:
                # It's alive!
                checking.pop(address)
            elif deadline < now:
                # No circuit connected to the server at this address has
                # sent Beacons or responded to the EchoRequest. We assume
                # it is unresponsive. The EPICS specification says the
                # behavior is undefined at this point. We choose to
                # disconnect all circuits from that server so that PVs can
                # attempt to connect to a new server, such as a failover
                # backup.
                for circuit_manager in servers[address]:
                    if circuit_manager.connected:
                        circuit_manager.log.warning(
                            "Server at %s:%d is unresponsive. "
                            "Disconnecting circuit manager %r. PVs will "
                            "automatically begin attempting to reconnect "
                            "to a responsive server.",
                            *address, circuit_manager)
                        circuit_manager._disconnected()
                checking.pop(address)
            # else:
            #     # We are still waiting to give the server time to respond
            #     # to the EchoRequest.

        # Pause between passes, but return promptly on shutdown.
        self._close_event.wait(0.5)

    self.log.debug('Broadcaster check for unresponsive servers loop has exited.')


def _retry_unanswered_searches(self):
    """
    Periodically (re-)send a SearchRequest for all unanswered searches.
    """
    # Each time new searches are added, the self._search_now Event is set,
    # and we reissue *all* unanswered searches.
    #
    # We then frequently retry the unanswered searches that are younger
    # than SEARCH_RETIREMENT_AGE, backing off from an interval of
    # MIN_RETRY_SEARCHES_INTERVAL to MAX_RETRY_SEARCHES_INTERVAL. The
    # interval is reset to MIN_RETRY_SEARCHES_INTERVAL each time new
    # searches are added.
    #
    # For the searches older than SEARCH_RETIREMENT_AGE, we adopt a slower
    # period to minimize network traffic. We only resend every
    # RETRY_RETIRED_SEARCHES_INTERVAL or, again, whenever new searches
    # are added.
    #
    # Each unanswered-search entry is used here as: entry[0] is the PV
    # name, entry[-2] the time it was last sent, entry[-1] its retirement
    # deadline.
    self.log.debug('Broadcaster search-retry thread has started.')
    time_to_check_on_retirees = (time.monotonic() +
                                 common.RETRY_RETIRED_SEARCHES_INTERVAL)
    interval = common.MIN_RETRY_SEARCHES_INTERVAL
    while not self._close_event.is_set():
        # Event.wait() returns False on timeout (it never raises), so act
        # on its result: while searching is disabled, go back and check
        # self._close_event before waiting again.
        if not self._searching_enabled.wait(0.5):
            continue

        t = time.monotonic()

        with self._search_lock:
            if t >= time_to_check_on_retirees:
                items = list(self.unanswered_searches.items())
                # Schedule relative to now so the deadline cannot fall
                # behind (e.g. after searching was disabled for a while) and
                # force a retiree check on every subsequent pass.
                time_to_check_on_retirees = (
                    t + common.RETRY_RETIRED_SEARCHES_INTERVAL)
            else:
                # Skip over searches that haven't gotten any results in
                # SEARCH_RETIREMENT_AGE.
                items = [(search_id, it)
                         for search_id, it in self.unanswered_searches.items()
                         if it[-1] > t]

        # Only send requests that we last sent at least `interval` ago.
        resend_deadline = t - interval
        items = [(sid, it) for sid, it in items if it[-2] < resend_deadline]

        if not self._searching_enabled.is_set():
            continue

        def _construct_search_requests(pending):
            # Lazily build requests, stamping each entry as it goes out.
            for search_id, it in pending:
                yield ca.SearchRequest(it[0], search_id,
                                       ca.DEFAULT_PROTOCOL_VERSION)
                # reset the last time this was sent
                it[-2] = t

        requests = _construct_search_requests(items)

        if items:
            self.search_log.debug('Sending %d SearchRequests', len(items))

        version_req = ca.VersionRequest(0, ca.DEFAULT_PROTOCOL_VERSION)
        for batch in batch_requests(requests,
                                    SEARCH_MAX_DATAGRAM_BYTES - len(version_req)):
            self.send(version_req, *batch)

        wait_time = max(0, interval - (time.monotonic() - t))
        # Double the interval for the next loop.
        interval = min(2 * interval, common.MAX_RETRY_SEARCHES_INTERVAL)
        if self._search_now.wait(wait_time):
            # New searches have been requested. Consume the notification
            # only when we observed it; clearing after a timed-out wait
            # could discard a notification that arrived just afterwards.
            self._search_now.clear()
            # Reset the interval between subsequent searches and force a
            # check on the "retirees".
            time_to_check_on_retirees = t
            interval = common.MIN_RETRY_SEARCHES_INTERVAL

    self.log.debug('Broadcaster search-retry thread has exited.')


def __del__(self):
    # Best-effort cleanup; attributes may be missing if __init__ failed.
    try:
        self.disconnect()
        self.selector = None
    except AttributeError:
        pass
class Context:
    """
    Encapsulates the state and connections of a client

    Parameters
    ----------
    broadcaster : SharedBroadcaster, optional
        If None is specified, a fresh one is instantiated.
    timeout : number or None, optional
        Number of seconds before a CaprotoTimeoutError is raised. This default
        can be overridden at the PV level or for any given operation. If unset,
        the default is 2 seconds. If None, never timeout. A global timeout can
        be specified via an environment variable ``CAPROTO_DEFAULT_TIMEOUT``.
    host_name : string, optional
        uses value of ``socket.gethostname()`` by default
    client_name : string, optional
        uses value of ``getpass.getuser()`` by default
    max_workers : integer, optional
        Number of worker threads *per VirtualCircuit* for executing user
        callbacks. Default is 1. For any number of workers, workers will
        receive updates in the order which they are received from the server.
        That is, work on each update will *begin* in sequential order.
        Work-scheduling internal to the user callback is outside caproto's
        control. If the number of workers is set to greater than 1, the work
        on each update may not *finish* in a deterministic order. For example,
        if workers are writing lines into a file, the only way to guarantee
        that the lines are ordered properly is to use only one worker. If
        ordering matters for your application, think carefully before
        increasing this value from 1.
    """
    def __init__(self, broadcaster=None, *,
                 timeout=common.GLOBAL_DEFAULT_TIMEOUT,
                 host_name=None, client_name=None,
                 max_workers=1):
        if broadcaster is None:
            broadcaster = SharedBroadcaster()
        self.broadcaster = broadcaster
        self.timeout = timeout
        if host_name is None:
            host_name = socket.gethostname()
        self.host_name = host_name
        if client_name is None:
            client_name = getpass.getuser()
        self.max_workers = max_workers
        self.client_name = client_name
        self.log = logging.LoggerAdapter(
            logging.getLogger('caproto.ctx'), {'role': 'CLIENT'})
        self.pv_cache_lock = threading.RLock()
        self.circuit_managers = {}  # keyed on ((host, port), priority)
        self._lock_during_get_circuit_manager = threading.RLock()
        self.pvs = {}  # (name, priority) -> pv
        # name -> set of pvs --- with varied priority
        self.pvs_needing_circuits = defaultdict(set)
        self.broadcaster.add_listener(self)
        self._search_results_queue = Queue()
        # an event to close and clean up the whole context
        self._close_event = threading.Event()
        self.subscriptions_lock = threading.RLock()
        self.subscriptions_to_activate = defaultdict(set)
        self.activate_subscriptions_now = threading.Event()

        self._process_search_results_thread = threading.Thread(
            target=self._process_search_results_loop,
            daemon=True, name='search')
        self._process_search_results_thread.start()

        self._activate_subscriptions_thread = threading.Thread(
            target=self._activate_subscriptions,
            daemon=True, name='activate_subscriptions')
        self._activate_subscriptions_thread.start()

        self.selector = SelectorThread(parent=self)
        self.selector.start()
        self._user_disconnected = False

    def __repr__(self):
        return (f"<Context "
                f"searches_pending={len(self.broadcaster.unanswered_searches)} "
                f"circuits={len(self.circuit_managers)} "
                f"pvs={len(self.pvs)} "
                f"idle={len([1 for pv in self.pvs.values() if pv._idle])}>")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.disconnect(wait=True)

    def get_pvs(self, *names, priority=0, connection_state_callback=None,
                access_rights_callback=None,
                timeout=common.CONTEXT_DEFAULT_TIMEOUT):
        """
        Return a list of PV objects.

        These objects may not be connected at first. Channel creation occurs on
        a background thread.

        PVs are uniquely defined by their name and priority. If a PV with the
        same name and priority is requested twice, the same (cached) object is
        returned. Any callbacks included here are added alongside any existing
        ones.

        Parameters
        ----------
        *names : strings
            any number of PV names
        priority : integer
            Used by the server to triage subscription responses when under high
            load. 0 is lowest; 99 is highest.
        connection_state_callback : callable
            Expected signature: ``f(pv, state)`` where ``pv`` is the instance
            of ``PV`` whose state has changed and ``state`` is a string
        access_rights_callback : callable
            Expected signature: ``f(pv, access_rights)`` where ``pv`` is the
            instance of ``PV`` whose state has changed and ``access_rights`` is
            a member of the caproto ``AccessRights`` enum
        timeout : number or None, optional
            Number of seconds before a CaprotoTimeoutError is raised. This
            default can be overridden for any specific operation. By default,
            fall back to the default timeout set by the Context. If None, never
            timeout.
        """
        if self._user_disconnected:
            raise ContextDisconnectedError("This Context is no longer usable.")

        pvs = []  # list of all PV objects to return
        names_to_search = []  # subset of names that we need to search for
        for name in names:
            with self.pv_cache_lock:
                try:
                    pv = self.pvs[(name, priority)]
                except KeyError:
                    pv = PV(name, priority, self, timeout)
                    names_to_search.append(name)
                    self.pvs[(name, priority)] = pv
                    self.pvs_needing_circuits[name].add(pv)

            if connection_state_callback is not None:
                pv.connection_state_callback.add_callback(
                    connection_state_callback, run=True)

            if access_rights_callback is not None:
                pv.access_rights_callback.add_callback(
                    access_rights_callback, run=True)

            pvs.append(pv)

        # TODO: potential bug?
        # if callback is quick, is there a chance downstream listeners may
        # never receive notification?

        # Ask the Broadcaster to search for every PV for which we do not
        # already have an instance. It might already have a cached search
        # result, but that is the concern of broadcaster.search.
        if names_to_search:
            self.broadcaster.search(self._search_results_queue, names_to_search)
        return pvs

    def reconnect(self, keys):
        """
        Re-search for the PVs identified by ``keys`` ((name, priority) pairs).

        The same PV objects are reused; they will get a new cid. Any cached
        search result for each name is expired first so that the search
        goes out to the network.
        """
        # We will reuse the same PV object but use a new cid.
        names = []
        pvs = []

        for key in keys:
            with self.pv_cache_lock:
                pv = self.pvs[key]
            pvs.append(pv)
            name, _ = key
            names.append(name)
            # If there is a cached search result for this name, expire it.
            with self.broadcaster._search_lock:
                self.broadcaster.search_results.pop(name, None)
            with self.pv_cache_lock:
                self.pvs_needing_circuits[name].add(pv)

        self.broadcaster.search(self._search_results_queue, names)

    def _process_search_results_loop(self):
        # Receive (address, (name1, name2, ...)). The sending side of this
        # queue is held by SharedBroadcaster.command_loop.
        self.log.debug('Context search-results processing loop has '
                       'started.')
        while not self._close_event.is_set():
            try:
                address, names = self._search_results_queue.get(timeout=0.5)
            except Empty:
                # By restarting the loop, we will first check that we are not
                # supposed to shut down the thread before we go back to
                # waiting on the queue again.
                continue

            channels_grouped_by_circuit = defaultdict(list)
            # Assign each PV a VirtualCircuitManager for managing a socket
            # and tracking circuit state, as well as a ClientChannel for
            # tracking channel state.
            for name in names:
                search_logger.debug('Connecting %s on circuit with %s:%d', name, *address,
                                    extra={'pv': name,
                                           'their_address': address,
                                           'our_address': self.broadcaster.broadcaster.client_address,
                                           'direction': '--->>>',
                                           'role': 'CLIENT'})
                # There could be multiple PVs with the same name and
                # different priority. That is what we are looping over
                # here. There could also be NO PVs with this name that need
                # a circuit, because we could be receiving a duplicate
                # search response (which we are supposed to ignore).
                with self.pv_cache_lock:
                    pvs = self.pvs_needing_circuits.pop(name, set())
                for pv in pvs:
                    # Get (make if necessary) a VirtualCircuitManager. This
                    # is where TCP socket creation happens.
                    try:
                        cm = self.get_circuit_manager(address, pv.priority)
                    except Exception:
                        # Connecting can fail (e.g. connection refused, or no
                        # VersionResponse within the timeout). Letting that
                        # propagate would kill this thread and stop every
                        # future PV in this Context from connecting. Instead,
                        # put the PV back in the waiting pool and expire the
                        # possibly-stale cached search result.
                        self.log.exception(
                            "Failed to connect to server at %s:%d for PV %r.",
                            *address, name)
                        with self.pv_cache_lock:
                            self.pvs_needing_circuits[name].add(pv)
                        with self.broadcaster._search_lock:
                            self.broadcaster.search_results.pop(name, None)
                        continue
                    circuit = cm.circuit

                    pv.circuit_manager = cm
                    # TODO: NOTE: we are not following the suggestion to
                    # use the same cid as in the search. This simplifies
                    # things between the broadcaster and Context.
                    cid = cm.circuit.new_channel_id()
                    chan = ca.ClientChannel(name, circuit, cid=cid)
                    cm.channels[cid] = chan
                    cm.pvs[cid] = pv
                    channels_grouped_by_circuit[cm].append(chan)
                    pv.circuit_ready.set()

            # Initiate channel creation with the server.
            for cm, channels in channels_grouped_by_circuit.items():
                commands = [chan.create() for chan in channels]
                try:
                    cm.send(*commands)
                except Exception:
                    if cm.dead.is_set():
                        self.log.debug("Circuit died while we were trying "
                                       "to create the channel. We will "
                                       "keep attempting this until it "
                                       "works.")
                        # When the Context creates a new circuit, we will end
                        # up here again. No big deal.
                        continue
                    raise

        self.log.debug('Context search-results processing thread has exited.')

    def get_circuit_manager(self, address, priority):
        """
        Return a VirtualCircuitManager for this address, priority. (It manages
        a caproto.VirtualCircuit and a TCP socket.)

        Make a new one if necessary.
        """
        with self._lock_during_get_circuit_manager:
            cm = self.circuit_managers.get((address, priority), None)
            if cm is None or cm.dead.is_set():
                version = self.broadcaster.server_protocol_versions[address]
                circuit = ca.VirtualCircuit(
                    our_role=ca.CLIENT,
                    address=address,
                    priority=priority,
                    protocol_version=version)
                cm = VirtualCircuitManager(self, circuit, self.selector)
                self.circuit_managers[(address, priority)] = cm
            return cm

    def _activate_subscriptions(self):
        """
        Periodically send EventAddRequests for subscriptions awaiting
        (re)activation.

        Runs until the close event is set, draining
        ``subscriptions_to_activate`` every ``RESTART_SUBS_PERIOD`` seconds,
        or sooner when ``activate_subscriptions_now`` is set.
        """
        while not self._close_event.is_set():
            t = time.monotonic()
            with self.subscriptions_lock:
                items = list(self.subscriptions_to_activate.items())
                self.subscriptions_to_activate.clear()
            for cm, subs in items:
                # compose_command() returns None if this Subscription is
                # inactive (meaning there are no user callbacks attached).
                # It will send an EventAddRequest on its own if/when the
                # user does add any callbacks, so we can skip it here.
                commands = (command
                            for command in (sub.compose_command() for sub in subs)
                            if command is not None)

                for batch in batch_requests(commands,
                                            common.EVENT_ADD_BATCH_MAX_BYTES):
                    try:
                        cm.send(*batch)
                    except Exception:
                        if cm.dead.is_set():
                            self.log.debug("Circuit died while we were "
                                           "trying to activate "
                                           "subscriptions. We will "
                                           "keep attempting this until it "
                                           "works.")
                        else:
                            # Not a dead circuit: do not swallow silently.
                            self.log.exception(
                                "Failed to activate subscriptions on %r.", cm)
                        # When the Context creates a new circuit, we will
                        # end up here again. No big deal.
                        break

            wait_time = max(0, (common.RESTART_SUBS_PERIOD -
                                (time.monotonic() - t)))
            self.activate_subscriptions_now.wait(wait_time)
            self.activate_subscriptions_now.clear()

        self.log.debug('Context restart-subscriptions thread exiting')

    def disconnect(self, *, wait=True):
        """
        Disconnect all circuits and stop this Context's background threads.

        Parameters
        ----------
        wait : bool, optional
            If True (default), join the background threads before returning.
        """
        self._user_disconnected = True
        try:
            self._close_event.set()
            # Wake the subscription-activation thread so it sees the close
            # event immediately rather than after its current wait.
            self.activate_subscriptions_now.set()
            # disconnect any circuits we have
            circuits = list(self.circuit_managers.values())
            total_circuits = len(circuits)
            disconnected = False
            for idx, circuit in enumerate(circuits, 1):
                if circuit.connected:
                    self.log.debug('Disconnecting circuit %d/%d: %s',
                                   idx, total_circuits, circuit)
                    circuit.disconnect()
                    disconnected = True
            if disconnected:
                self.log.debug('All circuits disconnected')
        finally:
            # Remove from Broadcaster.
            self.broadcaster.remove_listener(self)

            # clear any state about circuits and search results
            self.log.debug('Clearing circuit managers')
            self.circuit_managers.clear()

            self.log.debug("Stopping SelectorThread of the context")
            self.selector.stop()

            if wait:
                self._process_search_results_thread.join()
                self._activate_subscriptions_thread.join()
                self.selector.thread.join()

            self.log.debug('Context disconnection complete')

    def __del__(self):
        try:
            self.disconnect(wait=False)
        except Exception:
            ...
        finally:
            self.selector = None
            self.broadcaster = None
            self.circuit_managers = None
class VirtualCircuitManager:
    """
    Encapsulates a VirtualCircuit, a TCP socket, and additional state

    This object should never be instantiated directly by user code. It is used
    internally by the Context. Its methods may be touched by user code, but
    this is rarely necessary.
    """
    __slots__ = ('context', 'circuit', 'channels', 'ioids', '_ioid_counter',
                 'subscriptions', '_ready', 'log',
                 'socket', 'selector', 'pvs', 'all_created_pvnames',
                 'dead', 'process_queue', 'processing',
                 '_subscriptionid_counter', 'user_callback_executor',
                 'last_tcp_receipt', '__weakref__', '_tags')

    def __init__(self, context, circuit, selector, timeout=common.GLOBAL_DEFAULT_TIMEOUT):
        self.context = context
        self.circuit = circuit  # a caproto.VirtualCircuit
        self.log = circuit.log
        self.channels = {}  # map cid to Channel
        self.pvs = {}  # map cid to PV
        self.ioids = {}  # map ioid to Channel and info dict
        self.subscriptions = {}  # map subscriptionid to Subscription
        self.socket = None
        self.selector = selector
        self.user_callback_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.context.max_workers,
            thread_name_prefix='user-callback-executor'
        )
        self.last_tcp_receipt = None
        # keep track of all PV names that are successfully connected to within
        # this circuit. This is to be cleared upon disconnection:
        self.all_created_pvnames = []
        self.dead = threading.Event()
        self._ioid_counter = ThreadsafeCounter()
        self._subscriptionid_counter = ThreadsafeCounter()
        self._ready = threading.Event()

        # Connect.
        if self.circuit.states[ca.SERVER] is ca.IDLE:
            self.socket = socket.create_connection(self.circuit.address)
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            self.circuit.our_address = self.socket.getsockname()
            # This dict is passed to the loggers.
            self._tags = {'their_address': self.circuit.address,
                          'our_address': self.circuit.our_address,
                          'direction': '<<<---',
                          'role': repr(self.circuit.our_role)}
            self.selector.add_socket(self.socket, self)
            self.send(ca.VersionRequest(self.circuit.priority,
                                        ca.DEFAULT_PROTOCOL_VERSION),
                      ca.HostNameRequest(self.context.host_name),
                      ca.ClientNameRequest(self.context.client_name),
                      extra=self._tags)
        else:
            raise CaprotoRuntimeError("Cannot connect. States are {} "
                                      "".format(self.circuit.states))

        # Old versions of the protocol do not send a VersionResponse at TCP
        # connection time, so set this Event manually rather than waiting for
        # it to be set by receipt of a VersionResponse.
        if self.context.broadcaster.server_protocol_versions[self.circuit.address] < 12:
            self._ready.set()

        ready = self._ready.wait(timeout=timeout)
        if not ready:
            # Release the socket and its selector registration now instead of
            # leaving a half-open circuit behind for the garbage collector.
            self._disconnected(reconnect=False)
            host, port = self.circuit.address
            raise CaprotoTimeoutError(f"Circuit with server at {host}:{port} "
                                      f"did not connect within "
                                      f"{float(timeout):.3}-second timeout.")

    def __repr__(self):
        return (f"<VirtualCircuitManager circuit={self.circuit} "
                f"pvs={len(self.pvs)} ioids={len(self.ioids)} "
                f"subscriptions={len(self.subscriptions)}>")

    @property
    def connected(self):
        """True if the client side of the circuit is in the CONNECTED state."""
        return self.circuit.states[ca.CLIENT] is ca.CONNECTED

    def send(self, *commands, extra=None):
        """
        Serialize ``commands`` through the VirtualCircuit and write them to
        the socket. Does nothing if the socket has already been released.
        """
        # Turn the crank: inform the VirtualCircuit that these commands will
        # be sent, and convert them to buffers.
        sock = self.socket
        if sock is not None:
            buffers_to_send = self.circuit.send(*commands, extra=extra)
            sock.sendall(b"".join(buffers_to_send))

    def received(self, bytes_recv, address):
        """Receive and process the next command from the virtual circuit.

        This will be run on the recv thread. Returns ``ca.DISCONNECTED`` when
        ``bytes_recv`` is empty (telling the selector to remove our socket);
        otherwise returns the number of bytes the circuit still needs."""
        self.last_tcp_receipt = time.monotonic()
        commands, num_bytes_needed = self.circuit.recv(bytes_recv)

        for c in commands:
            self._process_command(c)

        if not bytes_recv:
            # Tell the selector to remove our socket
            return ca.DISCONNECTED
        return num_bytes_needed

    def events_off(self):
        """
        Suspend updates to all subscriptions on this circuit.

        This may be useful if the server produces updates faster than the
        client can process them.
        """
        self.send(ca.EventsOffRequest())

    def events_on(self):
        """
        Reactivate updates to all subscriptions on this circuit.
        """
        self.send(ca.EventsOnRequest())

    def _process_command(self, command):
        try:
            self.circuit.process_command(command)
        except ca.CaprotoError as ex:
            if hasattr(ex, 'channel'):
                channel = ex.channel
                self.log.warning('Invalid command %s for Channel %s in state %s',
                                 command, channel, channel.states,
                                 exc_info=ex)
                # channel exceptions are not fatal
                return
            else:
                self.log.error('Invalid command %s for VirtualCircuit %s in '
                               'state %s', command, self, self.circuit.states,
                               exc_info=ex)
                # circuit exceptions are fatal; exit the loop
                self.disconnect()
                return

        tags = self._tags
        if command is ca.DISCONNECTED:
            self._disconnected()
        elif isinstance(command, (ca.VersionResponse,)):
            assert self.connected  # double check that the state machine agrees
            self._ready.set()
        elif isinstance(command, (ca.ReadNotifyResponse, ca.ReadResponse,
                                  ca.WriteNotifyResponse)):
            ioid_info = self.ioids.pop(command.ioid, None)
            if ioid_info is None:
                # A response we have no record of (for example, a
                # misbehaving server). Raising here would propagate into the
                # receive path, so log it and move on.
                self.log.warning("Ignoring response with unknown ioid=%d.",
                                 command.ioid, extra=tags)
                return
            deadline = ioid_info['deadline']
            pv = ioid_info['pv']
            tags = tags.copy()
            tags['pv'] = pv.name
            if deadline is not None and time.monotonic() > deadline:
                self.log.warning("Ignoring late response with ioid=%d regarding "
                                 "PV named %s because "
                                 "it arrived %.3f seconds after the deadline "
                                 "specified by the timeout.", command.ioid,
                                 pv.name, time.monotonic() - deadline)
                return

            event = ioid_info.get('event')
            if event is not None:
                # If PV.read() or PV.write() are waiting on this response,
                # they hold a reference to ioid_info. We will use that to
                # provide the response to them and then set the Event that they
                # are waiting on.
                ioid_info['response'] = command
                event.set()

            callback = ioid_info.get('callback')
            if callback is not None:
                try:
                    self.user_callback_executor.submit(callback, command)
                except RuntimeError:
                    if self.dead.is_set():
                        # if we are trying to process updates while
                        # shutting down the submit will fail.  In that
                        # case we should drop the exception on the
                        # floor and move on
                        return
                    # otherwise raise and let someone else deal with
                    # the mess
                    raise

        elif isinstance(command, ca.EventAddResponse):
            try:
                sub = self.subscriptions[command.subscriptionid]
            except KeyError:
                # This subscription has been removed. We assume that this
                # response was in flight before the server processed our
                # unsubscription.
                pass
            else:
                # This method submits jobs to the Contexts's
                # ThreadPoolExecutor for user callbacks.
                sub.process(command)
                tags = tags.copy()
                tags['pv'] = sub.pv.name
        elif isinstance(command, ca.AccessRightsResponse):
            pv = self.pvs[command.cid]
            pv.access_rights_changed(command.access_rights)
            tags = tags.copy()
            tags['pv'] = pv.name
        elif isinstance(command, ca.EventCancelResponse):
            # TODO Any way to add the pv name to tags here?
            ...
        elif isinstance(command, ca.CreateChanResponse):
            pv = self.pvs[command.cid]
            chan = self.channels[command.cid]
            self.all_created_pvnames.append(pv.name)
            with pv.component_lock:
                pv.channel = chan
                pv.channel_ready.set()
            pv.connection_state_changed('connected', chan)
            tags = tags.copy()
            tags['pv'] = pv.name
        elif isinstance(command, (ca.ServerDisconnResponse,
                                  ca.ClearChannelResponse)):
            pv = self.pvs[command.cid]
            pv.connection_state_changed('disconnected', None)
            tags = tags.copy()
            tags['pv'] = pv.name
            # NOTE: pv remains valid until server goes down
        elif isinstance(command, ca.EchoResponse):
            # The important effect here is that it will have updated
            # self.last_tcp_receipt when the bytes flowed through
            # self.received.
            ...

        if isinstance(command, ca.Message):
            # Build a new dict: in branches above that did not copy, `tags`
            # is still self._tags, which must not be mutated.
            tags = dict(tags, bytesize=len(command))
            self.log.debug("%r", command, extra=tags)

    def _disconnected(self, *, reconnect=True):
        # Ensure that this method is idempotent.
        if self.dead.is_set():
            return
        tags = {'their_address': self.circuit.address}
        self.log.debug('Virtual circuit with address %s:%d has disconnected.',
                       *self.circuit.address, extra=tags)
        # Update circuit state. This will be reflected on all PVs, which
        # continue to hold a reference to this disconnected circuit.
        self.circuit.disconnect()
        for pv in self.pvs.values():
            pv.channel_ready.clear()
            pv.circuit_ready.clear()
        self.dead.set()
        for ioid_info in self.ioids.values():
            # Un-block any calls to PV.read() or PV.write() that are waiting on
            # responses that we now know will never arrive. They will check on
            # circuit health and raise appropriately.
            event = ioid_info.get('event')
            if event is not None:
                event.set()

        with self.context.broadcaster._search_lock:
            for n in self.all_created_pvnames:
                self.context.broadcaster.search_results.pop(n, None)
        self.all_created_pvnames.clear()
        for pv in self.pvs.values():
            pv.connection_state_changed('disconnected', None)
        # Remove VirtualCircuitManager from Context. The Context keys its
        # managers on (address, priority). Only remove the entry if it is
        # still this manager, so we never evict a fresh replacement. The
        # dict is None once the Context has been garbage-collected.
        managers = self.context.circuit_managers
        key = (self.circuit.address, self.circuit.priority)
        if managers is not None and managers.get(key) is self:
            managers.pop(key, None)
        # Clean up the socket if it has not yet been cleared:
        sock, self.socket = self.socket, None
        if sock is not None:
            self.selector.remove_socket(sock)
            try:
                sock.shutdown(socket.SHUT_WR)
            except OSError:
                pass
            sock.close()
        tags = {'their_address': self.circuit.address}
        if reconnect:
            # Kick off attempt to reconnect all PVs via fresh circuit(s).
            self.log.debug('Kicking off reconnection attempts for %d PVs '
                           'disconnected from %s:%d....',
                           len(self.channels), *self.circuit.address,
                           extra=tags)
            self.context.reconnect(((chan.name, chan.circuit.priority)
                                    for chan in self.channels.values()))
        else:
            self.log.debug('Not attempting reconnection', extra=tags)

    def disconnect(self):
        """
        Disconnect this circuit without attempting reconnection and shut
        down the user-callback executor.
        """
        # Remember whether there was a live socket: _disconnected() always
        # clears self.socket, so checking afterwards could never succeed.
        had_socket = self.socket is not None
        self._disconnected(reconnect=False)
        self.log.debug("Shutting down ThreadPoolExecutor for user callbacks",
                       extra={'their_address': self.circuit.address})
        self.user_callback_executor.shutdown()
        if had_socket:
            self.log.debug('Circuit manager disconnected by user')

    def __del__(self):
        # Best-effort cleanup; attributes may be missing if __init__ failed.
        try:
            self._disconnected(reconnect=False)
        except AttributeError:
            pass
class PV:
    """
    Represents one PV, specified by a name and priority.

    This object may exist prior to connection and persists across any
    subsequent re-connections.

    This object should never be instantiated directly by user code; rather it
    should be created by calling the ``get_pvs`` method on a ``Context``
    object.
    """
    __slots__ = ('name', 'priority', 'context', '_circuit_manager', '_channel',
                 'circuit_ready', 'channel_ready',
                 'access_rights', 'access_rights_callback', 'subscriptions',
                 'command_bundle_queue', 'component_lock', '_idle', '_in_use',
                 '_usages', 'connection_state_callback', 'log',
                 '_timeout', '__weakref__')

    def __init__(self, name, priority, context, timeout):
        """
        These must be instantiated by a Context, never directly.
        """
        self.name = name
        self.priority = priority
        self.context = context
        self.access_rights = None  # will be overwritten with AccessRights
        self.log = logging.LoggerAdapter(ch_logger, {'pv': self.name,
                                                     'role': 'CLIENT'})
        # Use this lock whenever we touch circuit_manager or channel.
        self.component_lock = threading.RLock()
        self.circuit_ready = threading.Event()
        self.channel_ready = threading.Event()
        self.connection_state_callback = CallbackHandler(self)
        self.access_rights_callback = CallbackHandler(self)

        self._timeout = timeout

        self._circuit_manager = None
        self._channel = None
        self.subscriptions = {}
        self._idle = False
        self._in_use = threading.Condition()
        self._usages = 0

    @property
    def timeout(self):
        """
        Effective default timeout.

        Valid values are:
            * CONTEXT_DEFAULT_TIMEOUT (fall back to Context.timeout)
            * a floating-point number
            * None (never timeout)
        """
        if self._timeout is common.CONTEXT_DEFAULT_TIMEOUT:
            return self.context.timeout
        else:
            return self._timeout

    @timeout.setter
    def timeout(self, val):
        self._timeout = val

    @property
    def circuit_manager(self):
        return self._circuit_manager

    @circuit_manager.setter
    def circuit_manager(self, val):
        with self.component_lock:
            self._circuit_manager = val

    @property
    def channel(self):
        return self._channel

    @channel.setter
    def channel(self, val):
        with self.component_lock:
            self._channel = val

    def access_rights_changed(self, rights):
        """Record new access rights and notify access-rights callbacks."""
        self.access_rights = rights
        self.access_rights_callback.process(self, rights)

    def connection_state_changed(self, state, channel):
        """
        Notify connection-state callbacks and manage subscription
        reactivation.

        On 'disconnected', subscriptions that have callbacks are flagged for
        reactivation. On 'connected', flagged subscriptions are queued on the
        Context for re-sending over the current circuit.
        """
        self.log.info('connection state changed to %s.', state)
        self.connection_state_callback.process(self, state)
        # Snapshot: user threads may add subscriptions while this runs on the
        # circuit's receiving thread.
        subscriptions = list(self.subscriptions.values())
        if state == 'disconnected':
            for sub in subscriptions:
                with sub.callback_lock:
                    if sub.callbacks:
                        sub.needs_reactivation = True
        if state == 'connected':
            cm = self.circuit_manager
            ctx = cm.context
            with ctx.subscriptions_lock:
                for sub in subscriptions:
                    with sub.callback_lock:
                        if sub.needs_reactivation:
                            ctx.subscriptions_to_activate[cm].add(sub)
                            sub.needs_reactivation = False

    def __repr__(self):
        # Read the circuit manager once so all checks below see the same one.
        cm = self.circuit_manager
        if self._idle:
            state = "(idle)"
        elif cm is None or cm.dead.is_set():
            state = "(searching....)"
        else:
            state = (f"address={cm.circuit.address}, "
                     f"circuit_state="
                     f"{cm.circuit.states[ca.CLIENT]}")
            if self.connected:
                state += f", channel_state={self.channel.states[ca.CLIENT]}"
            else:
                state += " (creating...)"
        return f"<PV name={self.name!r} priority={self.priority} {state}>"

    @property
    def connected(self):
        """True if this PV has a channel in the CONNECTED state."""
        channel = self.channel
        if channel is None:
            return False
        return channel.states[ca.CLIENT] is ca.CONNECTED

    @ensure_connected
    def wait_for_connection(self, *, timeout=common.PV_DEFAULT_TIMEOUT):
        """
        Wait for this PV to be connected.

        Parameters
        ----------
        timeout : number or None, optional
            Seconds to wait before a CaprotoTimeoutError is raised. Default is
            ``PV.timeout``, which falls back to ``PV.context.timeout`` if not
            set. If None, never timeout.
        """
        # All of the work is done by the @ensure_connected decorator.
        pass

    def go_idle(self):
        """Request to clear this Channel to reduce load on client and server.

        A new Channel will be automatically, silently created the next time any
        method requiring a connection is called. Thus, this saves some memory
        in exchange for making the next request a bit slower, as it has to
        redo the handshake with the server first.

        If there are any subscriptions with callbacks, this request will be
        ignored. If the PV is in the process of connecting, this request will
        be ignored.  If there are any actions in progress (read, write) this
        request will be processed when they are complete.
        """
        # Snapshot: other threads may add subscriptions concurrently.
        if any(sub.callbacks for sub in list(self.subscriptions.values())):
            return

        with self._in_use:
            if not self.channel_ready.is_set():
                return
            # Wait until no other methods that employ @ensure_connected
            # are in process.
            self._in_use.wait_for(lambda: self._usages == 0)
            # No other threads are using the connection, and we are holding the
            # self._in_use Condition's lock, so we can safely close the
            # connection. The next thread to acquire the lock will re-connect
            # after it acquires the lock.

            try:
                self.channel_ready.clear()
                self.circuit_manager.send(self.channel.clear(),
                                          extra={'pv': self.name})
            except OSError:
                # the socket is dead-dead, do nothing
                ...
            self._idle = True
[docs] @ensure_connected def read(self, *, wait=True, callback=None, timeout=common.PV_DEFAULT_TIMEOUT, data_type=None, data_count=None, notify=True): """Request a fresh reading. Can do one or both of: - Block while waiting for the response, and return it. - Pass the response to callback, with or without blocking. Parameters ---------- wait : boolean If True (default) block until a matching response is received from the server. Raises CaprotoTimeoutError if that response is not received within the time specified by the `timeout` parameter. callback : callable or None Called with the response as its argument when received. timeout : number or None, optional Seconds to wait before a CaprotoTimeoutError is raised. Default is ``PV.timeout``, which falls back to ``PV.context.timeout`` if not set. If None, never timeout. data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional Request specific data type or a class of data types, matched to the channel's native data type. Default is Channel's native data type. data_count : integer, optional Requested number of values. Default is the channel's native data count. notify: boolean, optional Send a ``ReadNotifyRequest`` instead of a ``ReadRequest``. True by default. """ if timeout is common.PV_DEFAULT_TIMEOUT: timeout = self.timeout cm, chan = self._circuit_manager, self._channel ioid = cm._ioid_counter() command = chan.read(ioid=ioid, data_type=data_type, data_count=data_count, notify=notify) # Stash the ioid to match the response to the request. event = threading.Event() ioid_info = dict(event=event, pv=self, request=command) if callback is not None: ioid_info['callback'] = callback cm.ioids[ioid] = ioid_info deadline = time.monotonic() + timeout if timeout is not None else None ioid_info['deadline'] = deadline cm.send(command, extra={'pv': self.name}) if not wait: return # The circuit_manager will put a reference to the response into # ioid_info and then set event. 
if not event.wait(timeout=timeout): host, port = cm.circuit.address raise CaprotoTimeoutError( f"Server at {host}:{port} did " f"not respond to attempt to read channel named " f"{self.name!r} within {float(timeout):.3}-second timeout. " f"The ioid of the expected response is {ioid}." ) if cm.dead.is_set(): # This circuit has died sometime during this function call. # The exception raised here will be caught by # @ensure_connected, which will retry the function call a # in hopes of getting a working circuit until our `timeout` has # been used up. raise DeadCircuitError() return ioid_info['response']
[docs] @ensure_connected def write(self, data, *, wait=True, callback=None, timeout=common.PV_DEFAULT_TIMEOUT, notify=None, data_type=None, data_count=None): """ Write a new value. Optionally, request confirmation from the server. Can do one or both of: - Block while waiting for the response, and return it. - Pass the response to callback, with or without blocking. Parameters ---------- data : str, int, or float or any Iterable of these Value(s) to write. wait : boolean If True (default) block until a matching WriteNotifyResponse is received from the server. Raises CaprotoTimeoutError if that response is not received within the time specified by the `timeout` parameter. callback : callable or None Called with the WriteNotifyResponse as its argument when received. timeout : number or None, optional Seconds to wait before a CaprotoTimeoutError is raised. Default is ``PV.timeout``, which falls back to ``PV.context.timeout`` if not set. If None, never timeout. notify : boolean or None If None (default), set to True if wait=True or callback is set. Can be manually set to True or False. Will raise ValueError if set to False while wait=True or callback is set. data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional Write specific data type or a class of data types, matched to the channel's native data type. Default is Channel's native data type. data_count : integer, optional Requested number of values. Default is the channel's native data count. 
""" if timeout is common.PV_DEFAULT_TIMEOUT: timeout = self.timeout cm, chan = self._circuit_manager, self._channel if notify is None: notify = (wait or callback is not None) ioid = cm._ioid_counter() command = chan.write(data, ioid=ioid, notify=notify, data_type=data_type, data_count=data_count) if notify: event = threading.Event() ioid_info = dict(event=event, pv=self, request=command) if callback is not None: ioid_info['callback'] = callback cm.ioids[ioid] = ioid_info deadline = time.monotonic() + timeout if timeout is not None else None ioid_info['deadline'] = deadline # do not need to lock this, locking happens in circuit command else: if wait or callback is not None: raise CaprotoValueError("Must set notify=True in order to use " "`wait` or `callback` because, without a " "notification of 'put-completion' from the " "server, there is nothing to wait on or to " "trigger a callback.") cm.send(command, extra={'pv': self.name}) if not wait: return # The circuit_manager will put a reference to the response into # ioid_info and then set event. if not event.wait(timeout=timeout): if cm.dead.is_set(): # This circuit has died sometime during this function call. # The exception raised here will be caught by # @ensure_connected, which will retry the function call a # in hopes of getting a working circuit until our `timeout` has # been used up. raise DeadCircuitError() host, port = cm.circuit.address raise CaprotoTimeoutError( f"Server at {host}:{port} did " f"not respond to attempt to write to channel named " f"{self.name!r} within {float(timeout):.3}-second timeout. " f"The ioid of the expected response is {ioid}." ) return ioid_info['response']
[docs] def subscribe(self, data_type=None, data_count=None, low=0.0, high=0.0, to=0.0, mask=None): """ Start a new subscription to which user callback may be added. Parameters ---------- data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional Request specific data type or a class of data types, matched to the channel's native data type. Default is Channel's native data type. data_count : integer, optional Requested number of values. Default is the channel's native data count. low, high, to : float, optional deprecated by Channel Access, not yet implemented by caproto mask : SubscriptionType, optional Subscribe to selective updates. Returns ------- subscription : Subscription Examples -------- Define a subscription. >>> sub = pv.subscribe() Add a user callback. The subscription will be transparently activated (i.e. an ``EventAddRequest`` will be sent) when the first user callback is added. >>> sub.add_callback(my_func) Multiple callbacks may be added to the same subscription. >>> sub.add_callback(another_func) See the docstring for :class:`Subscription` for more. """ # A Subscription is uniquely identified by the Signature created by its # args and kwargs. bound = SUBSCRIBE_SIG.bind(data_type, data_count, low, high, to, mask) key = tuple(bound.arguments.items()) try: sub = self.subscriptions[key] except KeyError: sub = Subscription(self, data_type, data_count, low, high, to, mask) self.subscriptions[key] = sub # The actual EPICS messages will not be sent until the user adds # callbacks via sub.add_callback(user_func). return sub
[docs] def unsubscribe_all(self): "Clear all subscriptions. (Remove all user callbacks from them.)" for sub in self.subscriptions.values(): sub.clear()
[docs] @ensure_connected def time_since_last_heard(self, timeout=common.PV_DEFAULT_TIMEOUT): """ Seconds since last message from the server that provides this channel. The time is reset to 0 whenever we receive a TCP message related to user activity *or* a Beacon. Servers are expected to send Beacons at regular intervals. If we do not receive either a Beacon or TCP message, we initiate an Echo over TCP, to which the server is expected to promptly respond. Therefore, the time reported here should not much exceed ``EPICS_CA_CONN_TMO`` (default 30 seconds unless overriden by that environment variable) if the server is healthy. If the server fails to send a Beacon on schedule *and* fails to reply to an Echo, the server is assumed dead. A warning is issued, and all PVs are disconnected to initiate a reconnection attempt. Parameters ---------- timeout : number or None, optional Seconds to wait before a CaprotoTimeoutError is raised. Default is ``PV.timeout``, which falls back to ``PV.context.timeout`` if not set. If None, never timeout. """ address = self.circuit_manager.circuit.address return self.context.broadcaster.time_since_last_heard()[address]
class CallbackHandler:
    """
    Hold weak references to user callbacks and dispatch calls to them.

    :meth:`process` submits every live callback to
    ``pv.circuit_manager.user_callback_executor``. It also remembers the
    arguments it was called with, so that a callback added later with
    ``add_callback(func, run=True)`` can be run with them right away.
    """

    def __init__(self, pv):
        # NOTE: not a WeakValueDictionary or WeakSet as PV is unhashable...
        self.callbacks = {}
        self.pv = pv
        self._callback_id = 0
        self.callback_lock = threading.RLock()
        self._last_call_values = None

    def add_callback(self, func, run=False):
        """
        Register ``func``, held by weak reference, and return its token.

        Parameters
        ----------
        func : callable
            Bound methods are held with ``weakref.WeakMethod``, other
            callables with ``weakref.ref``. When the referent is garbage
            collected, the callback is removed automatically.
        run : bool, optional
            If True and :meth:`process` has been called before, submit
            ``func`` alone with the most recent arguments. Previously
            registered callbacks are not re-run.

        Returns
        -------
        cb_id : int
            Token accepted by :meth:`remove_callback`.
        """
        def removed(_):
            self.remove_callback(cb_id)  # defined below inside the lock

        if inspect.ismethod(func):
            ref = weakref.WeakMethod(func, removed)
        else:
            # TODO: strong reference to non-instance methods?
            ref = weakref.ref(func, removed)

        with self.callback_lock:
            cb_id = self._callback_id
            self._callback_id += 1
            self.callbacks[cb_id] = ref
            last_call_values = self._last_call_values

        if run and last_call_values is not None:
            args, kwargs = last_call_values
            # Only the new callback should see the replayed values; existing
            # callbacks were already notified when process() ran.
            self._submit(func, args, kwargs)
        return cb_id

    def remove_callback(self, token):
        """Unregister the callback identified by ``token`` (no-op if absent)."""
        with self.callback_lock:
            self.callbacks.pop(token, None)

    def _submit(self, callback, args, kwargs):
        """
        Submit one callback to the circuit's user-callback executor.

        Returns False if the executor refused the job because the circuit is
        dead (its executor is gone). Any other RuntimeError is re-raised.
        """
        try:
            self.pv.circuit_manager.user_callback_executor.submit(
                callback, *args, **kwargs)
        except RuntimeError:
            if self.pv.circuit_manager.dead.is_set():
                # if the circuit is dead, so is the executor; forgive.
                return False
            # otherwise raise and let someone else deal with the mess
            raise
        return True

    def process(self, *args, **kwargs):
        """
        This is a fast operation that submits jobs to the Context's
        ThreadPoolExecutor and then returns.
        """
        to_remove = []
        with self.callback_lock:
            callbacks = list(self.callbacks.items())
            self._last_call_values = (args, kwargs)

        for cb_id, ref in callbacks:
            callback = ref()
            if callback is None:
                to_remove.append(cb_id)
                continue
            if not self._submit(callback, args, kwargs):
                # Executor died with the circuit. Stop dispatching, but still
                # fall through to prune the dead references found so far.
                break

        if to_remove:
            with self.callback_lock:
                for remove_id in to_remove:
                    self.callbacks.pop(remove_id, None)
class Subscription(CallbackHandler):
    """
    Represents one subscription, specified by a PV and configurational parameters

    It may fan out to zero, one, or multiple user-registered callback
    functions.

    This object should never be instantiated directly by user code; rather
    it should be made by calling the ``subscribe()`` method on a ``PV``
    object.
    """
    def __init__(self, pv, data_type, data_count, low, high, to, mask):
        super().__init__(pv)
        # Stash everything, but do not send any EPICS messages until the
        # first user callback is attached.
        self.data_type = data_type
        self.data_count = data_count
        self.low = low
        self.high = high
        self.to = to
        self.mask = mask
        self.subscriptionid = None
        self.most_recent_response = None
        self.needs_reactivation = False
        # This is related to back-compat for user callbacks that have the old
        # signature, f(response).
        self.__wrapper_weakrefs = set()

    @property
    def log(self):
        return self.pv.log

    def __repr__(self):
        return f"<Subscription to {self.pv.name!r}, id={self.subscriptionid}>"

    def _subscribe(self, timeout=common.PV_DEFAULT_TIMEOUT):
        """This is called automatically after the first callback is added.
        """
        cm = self.pv.circuit_manager
        if cm is None:
            # We are currently disconnected (perhaps have not yet connected).
            # When the PV connects, this subscription will be added.
            with self.callback_lock:
                self.needs_reactivation = True
        else:
            # We are (or very recently were) connected. In the rare event
            # where cm goes dead in the interim, subscription will be retried
            # by the activation loop.
            ctx = cm.context
            with ctx.subscriptions_lock:
                ctx.subscriptions_to_activate[cm].add(self)
            ctx.activate_subscriptions_now.set()

    @ensure_connected
    def compose_command(self, timeout=common.PV_DEFAULT_TIMEOUT):
        "This is used by the Context to re-subscribe in bulk after dropping."
        with self.callback_lock:
            if not self.callbacks:
                return None
            cm, chan = self.pv._circuit_manager, self.pv._channel
            subscriptionid = cm._subscriptionid_counter()
            command = chan.subscribe(data_type=self.data_type,
                                     data_count=self.data_count,
                                     low=self.low, high=self.high,
                                     to=self.to, mask=self.mask,
                                     subscriptionid=subscriptionid)
            subscriptionid = command.subscriptionid
            self.subscriptionid = subscriptionid
            # The circuit_manager needs to know the subscriptionid so that it
            # can route responses to this request.
            cm.subscriptions[subscriptionid] = self
        return command

    def clear(self):
        """
        Remove all callbacks.
        """
        with self.callback_lock:
            for cb_id in list(self.callbacks):
                self.remove_callback(cb_id)
        # Once self.callbacks is empty, self.remove_callback calls
        # self._unsubscribe for us.

    def _unsubscribe(self, timeout=common.PV_DEFAULT_TIMEOUT):
        """
        This is automatically called if the number of callbacks goes to 0.
        """
        with self.callback_lock:
            if self.subscriptionid is None:
                # Already unsubscribed.
                return
            subscriptionid = self.subscriptionid
            self.subscriptionid = None
            self.most_recent_response = None
        cm = self.pv.circuit_manager
        if cm is None:
            # No circuit to route responses or to send the request over.
            return
        cm.subscriptions.pop(subscriptionid, None)
        chan = self.pv.channel
        if chan and chan.states[ca.CLIENT] is ca.CONNECTED:
            try:
                # Use the channel whose state was just checked, not a fresh
                # read of self.pv.channel that could refer to another object.
                command = chan.unsubscribe(subscriptionid)
            except ca.CaprotoKeyError:
                pass
            else:
                cm.send(command, extra={'pv': self.pv.name})

    def process(self, command):
        """Fan ``command`` out to the callbacks as ``(self, command)``."""
        # TODO here i think we can decouple PV update rates and callback
        # handling rates, if desirable, to not bog down performance.
        # As implemented below, updates are blocking further messages from
        # the CA servers from processing. (-> ThreadPool, etc.)
        pv = self.pv
        super().process(self, command)
        self.log.debug("%r: %r", pv.name, command)
        self.most_recent_response = command

    def add_callback(self, func):
        """
        Add a callback to receive responses.

        Parameters
        ----------
        func : callable
            Expected signature: ``func(sub, response)``.

            The signature ``func(response)`` is also supported for
            backward-compatibility but will issue warnings. Support will be
            removed in a future release of caproto.

        Returns
        -------
        token : int
            Integer token that can be passed to :meth:`remove_callback`.

        .. versionchanged:: 0.5.0

           Changed the expected signature of ``func`` from ``func(response)``
           to ``func(sub, response)``.
        """
        # Handle func with signature func(response) for back-compat.
        func = adapt_old_callback_signature(func, self.__wrapper_weakrefs)
        with self.callback_lock:
            was_empty = not self.callbacks
            cb_id = super().add_callback(func)
            most_recent_response = self.most_recent_response
        if was_empty:
            # This is the first callback. Set up a subscription, which
            # should elicit a response from the server soon giving the
            # current value to this func (and any other funcs added in the
            # mean time).
            self._subscribe()
        else:
            # This callback is piggy-backing onto an existing subscription.
            # Send it the most recent response, unless we are still waiting
            # for that first response from the server.
            if most_recent_response is not None:
                try:
                    func(self, most_recent_response)
                except Exception:
                    self.log.exception(
                        "Exception raised during processing most recent "
                        "response %r with new callback %r",
                        most_recent_response, func)

        return cb_id

    def remove_callback(self, token):
        """
        Remove callback using token that was returned by :meth:`add_callback`.

        Parameters
        ----------

        token : integer
            Token returned by :meth:`add_callback`.
        """
        with self.callback_lock:
            super().remove_callback(token)
            if not self.callbacks:
                # Go dormant.
                self._unsubscribe()
                self.most_recent_response = None
                self.needs_reactivation = False

    def __del__(self):
        try:
            self.clear()
        except TimeoutError:
            pass
class Batch:
    """
    Accumulate requests and then issue them all in batch.

    Parameters
    ----------
    timeout : number or None
        Number of seconds to wait before ignoring late responses. Default
        is 2.

    Examples
    --------
    Read some PVs in batch and stash the readings in a dictionary as they
    come in.

    >>> results = {}
    >>> def stash_result(name, response):
    ...     results[name] = response.data
    ...
    >>> with Batch() as b:
    ...     for pv in pvs:
    ...         b.read(pv, functools.partial(stash_result, pv.name))
    ...     # The requests are sent upon exiting this 'with' block.
    ...

    The ``results`` dictionary will be populated as responses come in.
    """
    def __init__(self, timeout=2):
        self.timeout = timeout
        self._commands = defaultdict(list)  # map each circuit to commands
        self._ioid_infos = []

    def __enter__(self):
        return self

    def _stash_ioid_info(self, circuit_manager, ioid, callback, pv, command):
        """Register ``callback`` so the circuit can route the ioid's response."""
        # The request is stored alongside the callback, as PV.read/PV.write
        # do for their ioid_info; its 'deadline' is filled in by __exit__.
        ioid_info = dict(callback=callback, pv=pv, request=command)
        circuit_manager.ioids[ioid] = ioid_info
        self._ioid_infos.append(ioid_info)

    def read(self, pv, callback, data_type=None, data_count=None):
        """Request a fresh reading as part of a batched request.

        Notice that, unlike :meth:`PV.read`, the callback is required. (There
        is no other way to get the result back from a batched read.)

        Parameters
        ----------
        pv : PV
        callback : callable
            Expected signature: ``f(response)``
        data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional
            Request specific data type or a class of data types, matched to
            the channel's native data type. Default is Channel's native data
            type.
        data_count : integer, optional
            Requested number of values. Default is the channel's native data
            count.
        """
        cm = pv.circuit_manager
        ioid = cm._ioid_counter()
        command = pv.channel.read(ioid=ioid, data_type=data_type,
                                  data_count=data_count, notify=True)
        self._commands[cm].append(command)
        self._stash_ioid_info(cm, ioid, callback, pv, command)

    def write(self, pv, data, callback=None, data_type=None, data_count=None):
        """Write a new value as part of a batched request.

        Parameters
        ----------
        pv : PV
        data : str, int, or float or any Iterable of these
            Value(s) to write.
        callback : callable
            Expected signature: ``f(response)``. If given, a notifying write
            is sent and the response is routed to it.
        data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional
            Request specific data type or a class of data types, matched to
            the channel's native data type. Default is Channel's native data
            type.
        data_count : integer, optional
            Requested number of values. Default is the channel's native data
            count.
        """
        cm = pv.circuit_manager
        ioid = cm._ioid_counter()
        # Use one predicate for both the notify flag and the ioid
        # registration so that every notification requested has a handler.
        notify = callback is not None
        command = pv.channel.write(data=data, ioid=ioid, data_type=data_type,
                                   data_count=data_count, notify=notify)
        self._commands[cm].append(command)
        if notify:
            self._stash_ioid_info(cm, ioid, callback, pv, command)

    def __exit__(self, exc_type, exc_value, traceback):
        """Stamp a shared deadline on all responses and send every command."""
        timeout = self.timeout
        deadline = time.monotonic() + timeout if timeout is not None else None
        for ioid_info in self._ioid_infos:
            ioid_info['deadline'] = deadline
        for circuit_manager, commands in self._commands.items():
            circuit_manager.send(*commands)
# Mirrors caproto._circuit.ClientChannel.subscribe. PV.subscribe binds its
# arguments against this Signature to build a hashable key, so that identical
# subscription requests share a single Subscription object.
SUBSCRIBE_SIG = Signature([
    Parameter(param_name, Parameter.POSITIONAL_OR_KEYWORD, default=param_default)
    for param_name, param_default in (
        ('data_type', None),
        ('data_count', None),
        ('low', 0),
        ('high', 0),
        ('to', 0),
        ('mask', None),
    )
])