# Regarding threads...
# The SharedBroadcaster has:
# - UDP socket SelectorThread
# - UDP command processing
# - forever retrying search requests for disconnected PVs
# The Context has:
# - process search results
# - TCP socket SelectorThread
# - restart subscriptions
# The VirtualCircuit has:
# - ThreadPoolExecutor for processing user callbacks on read, write, subscribe
import array
import concurrent.futures
import errno
import functools
import getpass
import inspect
import logging
import random
import selectors
import socket
import threading
import time
import weakref
from collections import defaultdict, deque
from inspect import Parameter, Signature
from queue import Empty, Queue
import caproto as ca
from .._constants import (MAX_ID, RESPONSIVENESS_TIMEOUT,
SEARCH_MAX_DATAGRAM_BYTES, STALE_SEARCH_EXPIRATION)
from .._utils import (CaprotoError, CaprotoKeyError, CaprotoNetworkError,
CaprotoRuntimeError, CaprotoTimeoutError,
CaprotoTypeError, CaprotoValueError, ThreadsafeCounter,
adapt_old_callback_signature, batch_requests,
safe_getsockname, socket_bytes_available)
from ..client import common
# Module-level loggers: one for per-channel messages, one for UDP search
# traffic (used when reporting PVs found on multiple servers).
ch_logger, search_logger = (logging.getLogger(_logger_name) for _logger_name in
                            ('caproto.ch', 'caproto.bcast.search'))
class DeadCircuitError(CaprotoError):
    """Raised when an operation is attempted on a circuit that has died.

    ``ensure_connected`` catches it and retries the operation on the
    replacement circuit.
    """
def ensure_connected(func):
    """
    Decorate a PV or Subscription method so it runs on a connected channel.

    Before calling ``func`` the wrapper:

    * re-creates the channel if the PV is idle (after waiting for
      ``pv.circuit_ready``),
    * registers the call as a usage of the PV (``pv._usages``) under the
      ``pv._in_use`` condition, and
    * waits for ``pv.channel_ready``.

    If ``func`` raises DeadCircuitError, or raises TimeoutError while the
    circuit manager is marked dead, the call is retried up to
    ``common.CIRCUIT_DEATH_ATTEMPTS`` times. The ``timeout`` keyword argument
    (default ``pv.timeout``; ``None`` means wait forever) is one overall
    budget shared by all attempts. The remaining budget is passed to ``func``
    as ``timeout=``.

    Raises
    ------
    CaprotoTypeError
        If applied to something other than a PV or Subscription method.
    CaprotoTimeoutError
        If the circuit or channel is not ready within the timeout, or if the
        circuit died on every retry attempt.
    """
    @functools.wraps(func)
    def inner(self, *args, **kwargs):
        if isinstance(self, PV):
            pv = self
        elif isinstance(self, Subscription):
            pv = self.pv
        else:
            raise CaprotoTypeError("ensure_connected is intended to decorate "
                                   "methods of PV and Subscription.")
        # timeout may be decremented during disconnection-retry loops below.
        # Keep a copy of the original 'raw_timeout' for use in error messages.
        raw_timeout = timeout = kwargs.get('timeout', pv.timeout)
        if timeout is not None:
            deadline = time.monotonic() + timeout
        with pv._in_use:
            # If needed, reconnect. Do this inside the lock so that we don't
            # try to do this twice. (No other threads that need this lock
            # can proceed until the connection is ready anyway!)
            if pv._idle:
                # The Context should have been maintaining a working circuit
                # for us while this was idle. We just need to re-create the
                # Channel.
                ready = pv.circuit_ready.wait(timeout=timeout)
                if not ready:
                    raise CaprotoTimeoutError(
                        f"{pv} could not connect within "
                        f"{float(raw_timeout):.3}-second timeout.")
                with pv.component_lock:
                    cm = pv.circuit_manager
                    cid = cm.circuit.new_channel_id()
                    chan = ca.ClientChannel(pv.name, cm.circuit, cid=cid)
                    cm.channels[cid] = chan
                    cm.pvs[cid] = pv
                    pv.circuit_manager.send(chan.create(), extra={'pv': pv.name})
                    # Clear the flag on the PV itself. ``self`` may be a
                    # Subscription, and setting it there would leave the PV
                    # idle, so the channel would be re-created on every call.
                    pv._idle = False
            # increment the usage at the very end in case anything
            # goes wrong in the block of code above this.
            pv._usages += 1
        try:
            for _ in range(common.CIRCUIT_DEATH_ATTEMPTS):
                # On each iteration, subtract the time we already spent on any
                # previous attempts.
                if timeout is not None:
                    timeout = deadline - time.monotonic()
                ready = pv.channel_ready.wait(timeout=timeout)
                if not ready:
                    raise CaprotoTimeoutError(
                        f"{pv} could not connect within "
                        f"{float(raw_timeout):.3}-second timeout.")
                if timeout is not None:
                    timeout = deadline - time.monotonic()
                    kwargs['timeout'] = timeout
                cm = pv.circuit_manager
                try:
                    return func(self, *args, **kwargs)
                except DeadCircuitError:
                    # Something in func tried operate on the circuit after
                    # it died. The context will automatically build us a
                    # new circuit. Try again.
                    self.log.debug('Caught DeadCircuitError. '
                                   'Retrying %s.', func.__name__)
                    continue
                except TimeoutError:
                    # The circuit may have died after func was done calling
                    # methods on it but before we received some response we
                    # were expecting. The context will automatically build
                    # us a new circuit. Try again.
                    if cm.dead.is_set():
                        self.log.debug('Caught TimeoutError due to dead '
                                       'circuit. '
                                       'Retrying %s.', func.__name__)
                        continue
                    # The circuit is fine -- this is a real error.
                    raise
            # Every attempt hit a dead circuit. Do not silently return None to
            # a caller that expects a result.
            raise CaprotoTimeoutError(
                f"{pv} could not complete {func.__name__!r}: the circuit "
                f"died on each of {common.CIRCUIT_DEATH_ATTEMPTS} attempts.")
        finally:
            with pv._in_use:
                pv._usages -= 1
                pv._in_use.notify_all()
    return inner
class ThreadingClientException(CaprotoError):
    """Common base class for errors raised by the threading client.

    DisconnectedError and ContextDisconnectedError derive from it.
    """
class DisconnectedError(ThreadingClientException):
    """Error signalling that a connection is not available.

    NOTE(review): it is not raised in this excerpt; presumably PV operations
    raise it when the channel is disconnected. Confirm against the callers.
    """
class ContextDisconnectedError(ThreadingClientException):
    """Raised when a Context is used after the user has disconnected it.

    For example, ``Context.get_pvs`` raises it once ``_user_disconnected``
    is set.
    """
class SelectorThread:
    """
    This is used internally by the Context and the VirtualCircuitManager.

    A background thread polls every registered socket for readability. When
    data arrives it calls ``obj.received(data, address)`` on the object
    registered with that socket. When a socket is removed, the object is
    notified with ``obj.received(b'', None)``. Registration changes are queued
    under ``_socket_map_lock`` and applied by the poll loop itself.
    """
    def __init__(self, *, parent=None):
        self.thread = None  # set by the `start` method
        self._close_event = threading.Event()
        self.selector = selectors.DefaultSelector()
        # Set by add_socket() and stop() so that a poll loop that has no
        # sockets to select on wakes up promptly.
        self._register_event = threading.Event()
        self._socket_map_lock = threading.RLock()
        self.objects = weakref.WeakValueDictionary()  # {object_id: obj}
        self.socket_to_id = {}  # {socket: object_id}
        self._register_sockets = {}  # {socket: object_id}
        self._unregister_sockets = set()
        self._object_id = 0
        self._socket_count = 0
        if parent is not None:
            # Stop the selector if the parent goes out of scope
            self._parent = weakref.ref(parent, lambda obj: self.stop())

    @property
    def running(self):
        '''Selector thread is running'''
        return not self._close_event.is_set()

    def stop(self):
        """Ask the poll loop to exit. It cannot be restarted afterwards."""
        self._close_event.set()
        # In case we're waiting for the first socket to be added:
        self._register_event.set()

    def start(self):
        """Start the poll loop on a daemon thread named 'selector'."""
        if self._close_event.is_set():
            raise CaprotoRuntimeError("Cannot be restarted once stopped.")
        self.thread = threading.Thread(target=self, daemon=True,
                                       name='selector')
        self.thread.start()

    def add_socket(self, sock, target_obj):
        """
        Queue ``sock`` for registration. Data received on it is passed to
        ``target_obj.received``.

        The socket is made non-blocking. It is removed automatically when
        ``target_obj`` is garbage-collected.
        """
        assert isinstance(sock, socket.socket)
        with self._socket_map_lock:
            if sock in self.socket_to_id:
                raise CaprotoValueError('Socket already added')

            sock.setblocking(False)

            # assumption: only one sock per object
            self._object_id += 1
            self.objects[self._object_id] = target_obj
            self.socket_to_id[sock] = self._object_id
            self._register_sockets[sock] = self._object_id
            weakref.finalize(target_obj,
                             lambda sock=sock: self.remove_socket(sock))
            # Wake the poll loop if it is idling with no sockets, rather than
            # letting it sleep out its full wait timeout.
            self._register_event.set()

    def remove_socket(self, sock):
        """
        Stop watching ``sock`` and notify its object with ``received(b'', None)``.

        Unknown sockets are ignored.
        """
        with self._socket_map_lock:
            if sock not in self.socket_to_id:
                return
            obj_id = self.socket_to_id.pop(sock)
            obj = self.objects.pop(obj_id, None)
            if obj is not None:
                obj.received(b'', None)

            try:
                # removed before it was even added to the selector
                self._register_sockets.pop(sock)
            except KeyError:
                # already registered; the poll loop will unregister it
                self._unregister_sockets.add(sock)

    def _apply_pending_registrations(self):
        # Apply queued (un)registrations. Caller must hold _socket_map_lock.
        for sock in self._unregister_sockets:
            self.selector.unregister(sock)
        self._socket_count -= len(self._unregister_sockets)
        self._unregister_sockets.clear()

        for sock, obj_id in self._register_sockets.items():
            self.selector.register(sock, selectors.EVENT_READ,
                                   data=obj_id)
        self._socket_count += len(self._register_sockets)
        self._register_sockets.clear()

    def __call__(self):
        '''Selector poll loop'''
        avail_buf = array.array('i', [0])
        while self.running:
            with self._socket_map_lock:
                self._apply_pending_registrations()

            if self._socket_count == 0:
                if self._register_event.wait(timeout=0.1):
                    self._register_event.clear()
                continue

            events = self.selector.select(timeout=0.1)
            with self._socket_map_lock:
                if self._unregister_sockets:
                    # some sockets may be affected here; try again
                    continue

                object_and_socket = []
                for key, _mask in events:
                    obj = self.objects.get(key.data)
                    if obj is None:
                        # The owner was garbage-collected after select()
                        # returned (its finalizer may be blocked on our
                        # lock). Drop the socket instead of crashing the
                        # thread with a KeyError.
                        self.remove_socket(key.fileobj)
                        continue
                    object_and_socket.append((obj, key.fileobj))

                for obj, sock in object_and_socket:
                    if sock in self._unregister_sockets:
                        continue

                    # TODO: consider thread pool for recv and command_loop
                    try:
                        bytes_available = socket_bytes_available(
                            sock, available_buffer=avail_buf)
                        bytes_recv, address = sock.recvfrom(bytes_available)
                    except ConnectionResetError as ex:
                        if sock.type == socket.SOCK_DGRAM:
                            # Win32: "On a UDP-datagram socket this error indicates
                            # a previous send operation resulted in an ICMP Port
                            # Unreachable message."
                            #
                            # https://docs.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recvfrom
                            obj.log.debug(
                                "UDP socket indicates previous send failed %s: %s",
                                obj,
                                ex
                            )
                            continue
                        obj.log.error("Removing %s due to %s (%s)", obj, ex, ex.errno)
                        self.remove_socket(sock)
                        continue
                    except OSError as ex:
                        # "Would block" is transient. BlockingIOError also
                        # covers platform-specific codes such as Winsock's
                        # WSAEWOULDBLOCK. Anything else is a disconnection.
                        if (not isinstance(ex, BlockingIOError)
                                and ex.errno != errno.EAGAIN):
                            obj.log.error('Removing %s due to %s (%s)', obj, ex,
                                          ex.errno)
                            self.remove_socket(sock)
                        continue

                    try:
                        # Let objects handle disconnection by return value
                        if obj.received(bytes_recv, address) is ca.DISCONNECTED:
                            obj.log.debug('Removing %s = %s after DISCONNECTED '
                                          'return value', sock, obj)
                            self.remove_socket(sock)
                            # TODO: consider adding specific DISCONNECTED instead
                            # of b'' sent to disconnected sockets
                    except Exception as ex:
                        if sock.type == socket.SOCK_DGRAM:
                            obj.log.exception(
                                'UDP socket for %s failed on receipt of '
                                'new data: %s', obj, ex)
                        else:
                            obj.log.exception(
                                'Removing %s due to an internal error on receipt of '
                                'new data: %s', obj, ex)
                            self.remove_socket(sock)
class SharedBroadcaster:
    """
    UDP client that can be shared by one or more Contexts.

    Owns the UDP socket, a SelectorThread that reads it, and background
    threads:

    * ``command_loop`` matches SearchResponses to SearchRequests and tracks
      Beacons,
    * ``_check_for_unresponsive_servers`` sends EchoRequests to servers that
      have gone quiet and disconnects the circuits of those that never reply,
    * ``_retry_unanswered_searches`` (started by the first ``add_listener``)
      periodically re-sends unanswered SearchRequests.
    """
    def __init__(self, *, registration_retry_time=10.0):
        '''
        A broadcaster client which can be shared among multiple Contexts

        Parameters
        ----------
        registration_retry_time : float, optional
            The time, in seconds, between attempts made to register with the
            repeater. Default is 10.
        '''
        self.environ = ca.get_environment_variables()
        self.ca_server_port = self.environ['EPICS_CA_SERVER_PORT']

        self.udp_sock = ca.bcast_socket()
        # Must bind or getsocketname() will raise on Windows.
        # See https://github.com/caproto/caproto/issues/514.
        self.udp_sock.bind(('', 0))

        self._search_lock = threading.RLock()
        self._retry_unanswered_searches_thread = None
        # This Event ensures that we send a registration request before our
        # first search request.
        self._searching_enabled = threading.Event()
        # This Event lets us nudge the search thread when the user asks for new
        # PVs (via Context.get_pvs).
        self._search_now = threading.Event()

        self.search_results = {}  # map name to (address, time)
        # map search id (cid) to
        # [name, queue, last_search_time, retirement_deadline]
        self.unanswered_searches = {}
        self.server_protocol_versions = {}  # map address to protocol version
        self._id_counter = ThreadsafeCounter(
            initial_value=random.randint(0, MAX_ID),
            dont_clash_with=self.unanswered_searches,
        )
        self.listeners = weakref.WeakSet()

        self.broadcaster = ca.Broadcaster(our_role=ca.CLIENT)
        self.broadcaster.client_address = safe_getsockname(self.udp_sock)
        self.log = logging.LoggerAdapter(
            self.broadcaster.log, {'role': 'CLIENT'})
        self.search_log = logging.LoggerAdapter(
            logging.getLogger('caproto.bcast.search'), {'role': 'CLIENT'})
        self.command_bundle_queue = Queue()
        self.last_beacon = {}
        self.last_beacon_interval = {}
        # Maps address to time of last response. It is replaced wholesale by
        # _check_for_unresponsive_servers and read by time_since_last_heard.
        # Initialize it here so that reader never sees a missing attribute.
        self._last_heard = {}

        # an event to tear down and clean up the broadcaster
        self._close_event = threading.Event()

        self.selector = SelectorThread(parent=self)
        self.selector.add_socket(self.udp_sock, self)
        self.selector.start()

        self._command_thread = threading.Thread(target=self.command_loop,
                                                daemon=True, name='command')
        self._command_thread.start()

        self._check_for_unresponsive_servers_thread = threading.Thread(
            target=self._check_for_unresponsive_servers,
            daemon=True, name='check_for_unresponsive_servers')
        self._check_for_unresponsive_servers_thread.start()

        self._registration_retry_time = registration_retry_time
        self._registration_last_sent = 0

        try:
            # Always attempt registration on initialization, but allow failures
            self._register()
        except Exception:
            self.log.exception('Broadcaster registration failed on init')

    def _should_attempt_registration(self):
        'Whether or not a registration attempt should be tried'
        if self.udp_sock is None:
            # This broadcaster does not currently support being revived from
            # a disconnected state. Do not attempt registration if the
            # __init__-defined socket has been removed.
            return False

        if (self.broadcaster.registered or
                self._registration_retry_time is None):
            return False

        since_last_attempt = time.monotonic() - self._registration_last_sent
        if since_last_attempt < self._registration_retry_time:
            return False

        return True

    def _register(self):
        'Send a registration request to the repeater'
        self._registration_last_sent = time.monotonic()
        commands = [self.broadcaster.register()]
        bytes_to_send = self.broadcaster.send(*commands)
        addr = (ca.get_local_address(), self.environ['EPICS_CA_REPEATER_PORT'])
        tags = {
            'role': 'CLIENT',
            'our_address': self.broadcaster.client_address,
            'direction': '--->>>',
            'their_address': addr,
        }
        self.broadcaster.log.debug(
            '%d commands %dB', len(commands), len(bytes_to_send), extra=tags)

        try:
            self.udp_sock.sendto(bytes_to_send, addr)
        except (OSError, AttributeError) as ex:
            host, specified_port = addr
            self.log.exception('%s while sending %d bytes to %s:%d',
                               ex, len(bytes_to_send), host, specified_port)

        # Searching is enabled after the first attempt, successful or not.
        self._searching_enabled.set()

    def new_id(self):
        """Return a new search id that does not clash with a pending search."""
        return self._id_counter()

    def add_listener(self, listener):
        """Register a Context; start the search-retry thread on first use."""
        with self._search_lock:
            if self._retry_unanswered_searches_thread is None:
                self._retry_unanswered_searches_thread = threading.Thread(
                    target=self._retry_unanswered_searches, daemon=True,
                    name='retry')
                self._retry_unanswered_searches_thread.start()

            self.listeners.add(listener)

    def remove_listener(self, listener):
        """Unregister a Context; disconnect when no listeners remain."""
        self.listeners.discard(listener)
        if not self.listeners:
            self.disconnect()

    def disconnect(self, *, wait=True):
        """
        Close the UDP socket and stop the background threads.

        If ``wait`` is true, join the command, selector and (if it was
        started) search-retry threads.
        """
        if self.udp_sock is not None:
            self.selector.remove_socket(self.udp_sock)
            self.udp_sock.close()
            self.udp_sock = None

        self._close_event.set()
        with self._search_lock:
            self.search_results.clear()
        self._registration_last_sent = 0
        self._searching_enabled.clear()
        self.broadcaster.disconnect()
        self.selector.stop()
        if wait:
            self._command_thread.join()
            self.selector.thread.join()
            # The retry thread only exists once a listener has been added.
            if self._retry_unanswered_searches_thread is not None:
                self._retry_unanswered_searches_thread.join()

    def send(self, *commands):
        """
        Process a command and transport it over the UDP socket.

        Sends to every address in ``ca.get_client_address_list()``. Does
        nothing if the socket has been closed. Raises CaprotoNetworkError on
        the first failed ``sendto``.
        """
        bytes_to_send = self.broadcaster.send(*commands)
        tags = {'role': 'CLIENT',
                'our_address': self.broadcaster.client_address,
                'direction': '--->>>'}
        for host_tuple in ca.get_client_address_list():
            tags['their_address'] = host_tuple
            self.broadcaster.log.debug(
                '%d commands %dB',
                len(commands), len(bytes_to_send), extra=tags)
            sock = self.udp_sock
            if sock is None:
                return
            try:
                sock.sendto(bytes_to_send, host_tuple)
            except OSError as ex:
                host, specified_port = host_tuple
                raise CaprotoNetworkError(
                    f'{ex} while sending {len(bytes_to_send)} bytes to '
                    f'{host}:{specified_port}') from ex

    def get_cached_search_result(self, name, *,
                                 threshold=STALE_SEARCH_EXPIRATION):
        'Returns address if found, raises KeyError if missing or stale.'
        with self._search_lock:
            address, timestamp = self.search_results[name]

        if time.monotonic() - timestamp <= threshold:
            return address

        # The result is stale. If any context has a connected circuit that
        # has any channel for this PV name, the result is still good, so
        # refresh its timestamp. Snapshot the containers because other
        # threads may mutate them while we iterate.
        # TODO this is very inefficient
        for context in list(self.listeners):
            for cm in list(context.circuit_managers.values()):
                if cm.connected and name in cm.all_created_pvnames:
                    with self._search_lock:
                        self.search_results[name] = (address, time.monotonic())
                    # TODO verify that addr matches address
                    return address

        with self._search_lock:
            # Clean up expired result.
            self.search_results.pop(name, None)
        raise CaprotoKeyError(f'{name!r}: stale search result')

    def search(self, results_queue, names, *, timeout=2):
        """
        Search for PV names.

        The ``results_queue`` will receive ``(address, names)`` (the address of
        a server and a list of name(s) that it has) when results are received.

        If a cached result is already known, it will be put immediately into
        ``results_queue`` from this thread during this method's execution.

        If not, a SearchRequest will be sent from another thread. If necessary,
        the request will be re-sent periodically. When a matching response is
        received (by yet another thread) ``(address, names)`` will be put into
        the ``results_queue``.

        ``timeout`` is accepted for backward compatibility and is unused.
        """
        if self._should_attempt_registration():
            self._register()

        new_id = self.new_id
        unanswered_searches = self.unanswered_searches

        with self._search_lock:
            # We may have already searched for these names recently.
            # Filter the names down to a subset, `needs_search`.
            needs_search = []
            use_cached_search = defaultdict(list)
            for name in names:
                try:
                    address = self.get_cached_search_result(name)
                except KeyError:
                    needs_search.append(name)
                else:
                    use_cached_search[address].append(name)

            for address, cached_names in use_cached_search.items():
                results_queue.put((address, cached_names))

            # Stash new search ids so SearchResponses can be matched to
            # their SearchRequests. Searches that pass their retirement
            # deadline with no results are retried less frequently.
            retirement_deadline = time.monotonic() + common.SEARCH_RETIREMENT_AGE
            for name in needs_search:
                # The value is a list because we mutate it to update the
                # last-sent time and the retirement deadline.
                unanswered_searches[new_id()] = [name, results_queue,
                                                 0, retirement_deadline]
        self._search_now.set()

    def cancel(self, *names):
        """
        Cancel searches for these names.

        Parameters
        ----------
        *names : strings
            any number of PV names

        Any PV instances that were awaiting these results will be stuck until
        :meth:`get_pvs` is called again.
        """
        with self._search_lock:
            for search_id, item in list(self.unanswered_searches.items()):
                if item[0] in names:
                    del self.unanswered_searches[search_id]

    def search_now(self):
        """
        Force the Broadcaster to reissue all unanswered search requests now.

        Left to its own devices, the Broadcaster will do this at regular
        intervals automatically. This method is intended primarily for
        debugging and should not be needed in normal use.
        """
        self._search_now.set()

    def received(self, bytes_recv, address):
        "Receive and process and next command broadcasted over UDP."
        if bytes_recv:
            commands = self.broadcaster.recv(bytes_recv, address)
            if commands:
                self.command_bundle_queue.put(commands)
        return 0

    def command_loop(self):
        """
        Process bundles of received UDP commands until the broadcaster closes.

        Each bundle corresponds to the contents of one UDP datagram.
        SearchResponses are matched to their SearchRequests, and
        ``(address, [name1, name2, ...])`` is put into the requesting queue.
        The receiving end of that queue is held by
        Context._process_search_results.
        """
        # Save doing a 'self' lookup in the inner loop.
        search_results = self.search_results
        server_protocol_versions = self.server_protocol_versions
        unanswered_searches = self.unanswered_searches
        # map (results queue, server address) to the names found there
        queues = defaultdict(list)
        # Recently matched (cid, name) pairs, used to recognize redundant
        # responses to searches that were already answered.
        results_by_cid = deque(maxlen=1000)
        self.log.debug('Broadcaster command loop is running.')

        while not self._close_event.is_set():
            try:
                commands = self.command_bundle_queue.get(timeout=0.5)
            except Empty:
                # By restarting the loop, we will first check that we are not
                # supposed to shut down the thread before we go back to
                # waiting on the queue again.
                continue

            try:
                self.broadcaster.process_commands(commands)
            except ca.CaprotoError as ex:
                self.log.warning('Broadcaster command error', exc_info=ex)
                continue

            queues.clear()
            now = time.monotonic()
            tags = {'role': 'CLIENT',
                    'our_address': self.broadcaster.client_address,
                    'direction': '<<<---'}
            for command in commands:
                if isinstance(command, ca.Beacon):
                    now = time.monotonic()
                    address = (command.address, command.server_port)
                    tags['their_address'] = address
                    if address not in self.last_beacon:
                        # We made a new friend!
                        self.broadcaster.log.info("Watching Beacons from %s:%d",
                                                  *address, extra=tags)
                        self._new_server_found()
                    else:
                        interval = now - self.last_beacon[address]
                        if interval < self.last_beacon_interval.get(address, 0) / 4:
                            # Beacons are arriving *faster*? The server at this
                            # address may have restarted.
                            self.broadcaster.log.info(
                                "Beacon anomaly: %s:%d may have restarted.",
                                *address, extra=tags)
                            self._new_server_found()
                        self.last_beacon_interval[address] = interval
                    self.last_beacon[address] = now
                elif isinstance(command, ca.SearchResponse):
                    cid = command.cid
                    try:
                        with self._search_lock:
                            name, queue, *_ = unanswered_searches.pop(cid)
                    except KeyError:
                        # This is a redundant response, which the EPICS
                        # spec tells us to ignore. (The first responder
                        # to a given request wins.)
                        try:
                            _, name = next(r for r in results_by_cid if r[0] == cid)
                        except StopIteration:
                            continue
                        else:
                            with self._search_lock:
                                if name in search_results:
                                    accepted_address, _ = search_results[name]
                                    new_address = ca.extract_address(command)
                                    if new_address != accepted_address:
                                        search_logger.warning(
                                            "PV %s with cid %d found on multiple "
                                            "servers. Accepted address is %s:%d. "
                                            "Also found on %s:%d",
                                            name, cid, *accepted_address, *new_address,
                                            extra={'pv': name,
                                                   'their_address': accepted_address,
                                                   'our_address': self.broadcaster.client_address})
                    else:
                        results_by_cid.append((cid, name))
                        address = ca.extract_address(command)
                        # Group by address too, so each name is reported with
                        # the address it was actually found at, not with
                        # whatever address the last command in the bundle had.
                        queues[(queue, address)].append(name)
                        # Cache this to save time on future searches.
                        # (Entries expire after STALE_SEARCH_EXPIRATION.)
                        with self._search_lock:
                            search_results[name] = (address, now)
                        server_protocol_versions[address] = command.version

            # Send the search results to the Contexts that asked for them.
            for (queue, address), names in queues.items():
                queue.put((address, names))

        self.log.debug('Broadcaster command loop has exited.')

    def _new_server_found(self):
        # Bring all the unanswered searches out of retirement
        # to see if we have a new match.
        retirement_deadline = time.monotonic() + common.SEARCH_RETIREMENT_AGE
        with self._search_lock:
            for item in self.unanswered_searches.values():
                # give new age-out deadline
                item[-1] = retirement_deadline

    def time_since_last_heard(self):
        """
        Map each known server address to seconds since its last message.

        The time is reset to 0 whenever we receive a TCP message related to
        user activity *or* a Beacon. Servers are expected to send Beacons at
        regular intervals. If we do not receive either a Beacon or TCP message,
        we initiate an Echo over TCP, to which the server is expected to
        promptly respond.

        Therefore, the time reported here should not much exceed
        ``EPICS_CA_CONN_TMO`` (default 30 seconds unless overridden by that
        environment variable) if the server is healthy.

        If the server fails to send a Beacon on schedule *and* fails to reply to
        an Echo, the server is assumed dead. A warning is issued, and all PVs
        are disconnected to initiate a reconnection attempt.
        """
        # The checker thread publishes a complete, never-mutated-again dict,
        # so reading a reference to it is safe.
        last_heard = self._last_heard
        now = time.monotonic()
        return {address: now - t for address, t in last_heard.items()}

    def _check_for_unresponsive_servers(self):
        self.log.debug('Broadcaster check for unresponsive servers loop is running.')

        MARGIN = 1  # extra time (seconds) allowed between Beacons
        checking = dict()  # map address to deadline for check to resolve

        # Make locals to save getattr lookups in the loop.
        last_beacon = self.last_beacon
        listeners = self.listeners

        while not self._close_event.is_set():
            servers = defaultdict(weakref.WeakSet)  # map address to VirtualCircuitManagers
            last_heard = dict()  # map address to time of last response
            now = time.monotonic()

            # We are interested in identifying servers that we have not heard
            # from since some time cutoff in the past.
            cutoff = now - (self.environ['EPICS_CA_CONN_TMO'] + MARGIN)

            # Map each server address to VirtualCircuitManagers connected to
            # that address, across all Contexts ("listeners"). Snapshot the
            # containers: other threads add and remove entries concurrently.
            for listener in list(listeners):
                for (address, _), circuit_manager in list(listener.circuit_managers.items()):
                    servers[address].add(circuit_manager)

            # When is the last time we heard from each server, either via a
            # Beacon or from TCP packets related to user activity or any
            # circuit?
            for address, circuit_managers in servers.items():
                last_tcp_receipt = (cm.last_tcp_receipt for cm in circuit_managers)
                last_heard[address] = max((last_beacon.get(address, 0),
                                           *last_tcp_receipt))

                # If is has been too long --- and if we aren't already checking
                # on this server --- try to prompt a response over TCP by
                # sending an EchoRequest.
                if last_heard[address] < cutoff and address not in checking:
                    # Record that we are checking on this address and set a
                    # deadline for a response.
                    checking[address] = now + RESPONSIVENESS_TIMEOUT
                    tags = {'role': 'CLIENT',
                            'their_address': address,
                            'our_address': self.broadcaster.client_address,
                            'direction': '--->>>'}

                    self.broadcaster.log.debug(
                        "Missed Beacons from %s:%d. Sending EchoRequest to "
                        "check that server is responsive.", *address, extra=tags)
                    # Send on all circuits. One might be less backlogged
                    # with queued commands than the others and thus able to
                    # respond faster. In the majority of cases there will only
                    # be one circuit per server anyway, so this is a minor
                    # distinction.
                    for circuit_manager in circuit_managers:
                        try:
                            circuit_manager.send(ca.EchoRequest())
                        except Exception:
                            # Send failed. Server is likely dead, but we'll
                            # catch that shortly; no need to handle it
                            # specially here.
                            pass

            # Publish a complete snapshot for time_since_last_heard().
            self._last_heard = last_heard

            # Check to see if any of our ongoing checks have resolved or
            # failed to resolve within the allowed response window.
            for address, deadline in list(checking.items()):
                if address not in last_heard:
                    # No circuits to this server remain, so there is nothing
                    # left to check (and indexing last_heard would raise).
                    checking.pop(address)
                elif last_heard[address] > cutoff:
                    # It's alive!
                    checking.pop(address)
                elif deadline < now:
                    # No circuit connected to the server at this address has
                    # sent Beacons or responded to the EchoRequest. We assume
                    # it is unresponsive. The EPICS specification says the
                    # behavior is undefined at this point. We choose to
                    # disconnect all circuits from that server so that PVs can
                    # attempt to connect to a new server, such as a failover
                    # backup.
                    for circuit_manager in servers[address]:
                        if circuit_manager.connected:
                            circuit_manager.log.warning(
                                "Server at %s:%d is unresponsive. "
                                "Disconnecting circuit manager %r. PVs will "
                                "automatically begin attempting to reconnect "
                                "to a responsive server.",
                                *address, circuit_manager)
                            circuit_manager._disconnected()
                    checking.pop(address)
                # else:
                #     # We are still waiting to give the server time to respond
                #     # to the EchoRequest.

            time.sleep(0.5)

        self.log.debug('Broadcaster check for unresponsive servers loop has exited.')

    def _retry_unanswered_searches(self):
        """
        Periodically (re-)send a SearchRequest for all unanswered searches.
        """
        # Each time new searches are added, the self._search_now Event is set,
        # and we reissue *all* unanswered searches.
        #
        # We then frequently retry the unanswered searches that are younger
        # than SEARCH_RETIREMENT_AGE, backing off from an interval of
        # MIN_RETRY_SEARCHES_INTERVAL to MAX_RETRY_SEARCHES_INTERVAL. The
        # interval is reset to MIN_RETRY_SEARCHES_INTERVAL each time new
        # searches are added.
        #
        # For the searches older than SEARCH_RETIREMENT_AGE, we adopt a slower
        # period to minimize network traffic. We only resend every
        # RETRY_RETIRED_SEARCHES_INTERVAL or, again, whenever new searches
        # are added.
        self.log.debug('Broadcaster search-retry thread has started.')
        time_to_check_on_retirees = time.monotonic() + common.RETRY_RETIRED_SEARCHES_INTERVAL
        interval = common.MIN_RETRY_SEARCHES_INTERVAL

        def _construct_search_requests(items, sent_time):
            for search_id, it in items:
                yield ca.SearchRequest(it[0], search_id,
                                       ca.DEFAULT_PROTOCOL_VERSION)
                # reset the last time this was sent
                it[-2] = sent_time

        while not self._close_event.is_set():
            # Event.wait() returns False on timeout; it never raises. Until
            # searching is enabled, loop back and re-check self._close_event.
            if not self._searching_enabled.wait(0.5):
                continue

            t = time.monotonic()

            with self._search_lock:
                if t >= time_to_check_on_retirees:
                    items = list(self.unanswered_searches.items())
                    time_to_check_on_retirees += common.RETRY_RETIRED_SEARCHES_INTERVAL
                else:
                    # Skip over searches that haven't gotten any results in
                    # SEARCH_RETIREMENT_AGE.
                    items = [(search_id, it)
                             for search_id, it in self.unanswered_searches.items()
                             if it[-1] > t]

            # only send requests who we last sent at least interval in the past
            resend_deadline = t - interval
            items = [(sid, it) for sid, it in items if it[-2] < resend_deadline]
            requests = _construct_search_requests(items, t)

            if not self._searching_enabled.is_set():
                # Disabled (e.g. by disconnect) since we last checked.
                continue

            if items:
                self.search_log.debug('Sending %d SearchRequests', len(items))

            version_req = ca.VersionRequest(0, ca.DEFAULT_PROTOCOL_VERSION)
            try:
                for batch in batch_requests(requests,
                                            SEARCH_MAX_DATAGRAM_BYTES - len(version_req)):
                    self.send(version_req, *batch)
            except CaprotoNetworkError as ex:
                # A transient network failure must not kill this thread.
                # Unsent requests keep their old last-sent time and are
                # retried on a later pass.
                self.search_log.warning('Failed to send SearchRequests: %s', ex)

            wait_time = max(0, interval - (time.monotonic() - t))
            # Double the interval for the next loop.
            interval = min(2 * interval, common.MAX_RETRY_SEARCHES_INTERVAL)
            if self._search_now.wait(wait_time):
                # New searches have been requested. Reset the interval between
                # subsequent searches and force a check on the "retirees".
                time_to_check_on_retirees = t
                interval = common.MIN_RETRY_SEARCHES_INTERVAL
            self._search_now.clear()

        self.log.debug('Broadcaster search-retry thread has exited.')

    def __del__(self):
        try:
            self.disconnect()
            self.selector = None
        except AttributeError:
            pass
[docs]class Context:
"""
Encapsulates the state and connections of a client
Parameters
----------
broadcaster : SharedBroadcaster, optional
If None is specified, a fresh one is instantiated.
timeout : number or None, optional
Number of seconds before a CaprotoTimeoutError is raised. This default
can be overridden at the PV level or for any given operation. If unset,
the default is 2 seconds. If None, never timeout. A global timeout can
be specified via an environment variable ``CAPROTO_DEFAULT_TIMEOUT``.
host_name : string, optional
uses value of ``socket.gethostname()`` by default
client_name : string, optional
uses value of ``getpass.getuser()`` by default
max_workers : integer, optional
Number of worker threaders *per VirtualCircuit* for executing user
callbacks. Default is 1. For any number of workers, workers will
receive updates in the order which they are received from the server.
That is, work on each update will *begin* in sequential order.
Work-scheduling internal to the user callback is outside caproto's
control. If the number of workers is set to greater than 1, the work on
each update may not *finish* in a deterministic order. For example, if
workers are writing lines into a file, the only way to guarantee that
the lines are ordered properly is to use only one worker. If ordering
matters for your application, think carefully before increasing this
value from 1.
"""
def __init__(self, broadcaster=None, *,
             timeout=common.GLOBAL_DEFAULT_TIMEOUT,
             host_name=None, client_name=None, max_workers=1):
    # Shared UDP client: use the one given, or create a fresh one.
    self.broadcaster = (SharedBroadcaster() if broadcaster is None
                        else broadcaster)
    self.timeout = timeout
    self.host_name = (socket.gethostname() if host_name is None
                      else host_name)
    self.max_workers = max_workers
    self.client_name = (getpass.getuser() if client_name is None
                        else client_name)
    self.log = logging.LoggerAdapter(
        logging.getLogger('caproto.ctx'), {'role': 'CLIENT'})

    # PV / circuit bookkeeping.
    self.pv_cache_lock = threading.RLock()
    self.circuit_managers = {}  # keyed on ((host, port), priority)
    self._lock_during_get_circuit_manager = threading.RLock()
    self.pvs = {}  # (name, priority) -> pv
    # name -> set of pvs --- with varied priority
    self.pvs_needing_circuits = defaultdict(set)

    self.broadcaster.add_listener(self)
    self._search_results_queue = Queue()
    # an event to close and clean up the whole context
    self._close_event = threading.Event()

    # Subscription bookkeeping.
    self.subscriptions_lock = threading.RLock()
    self.subscriptions_to_activate = defaultdict(set)
    self.activate_subscriptions_now = threading.Event()

    # Background workers: search-result processing, then subscription
    # (re)activation.
    search_worker = threading.Thread(
        target=self._process_search_results_loop,
        daemon=True, name='search')
    self._process_search_results_thread = search_worker
    search_worker.start()

    subscription_worker = threading.Thread(
        target=self._activate_subscriptions,
        daemon=True, name='activate_subscriptions')
    self._activate_subscriptions_thread = subscription_worker
    subscription_worker.start()

    # TCP selector for this Context's circuits.
    self.selector = SelectorThread(parent=self)
    self.selector.start()
    self._user_disconnected = False
def __repr__(self):
    # Summarize pending searches, circuits, PVs, and how many PVs are idle.
    pending = len(self.broadcaster.unanswered_searches)
    circuit_count = len(self.circuit_managers)
    pv_count = len(self.pvs)
    idle_count = sum(1 for pv in self.pvs.values() if pv._idle)
    return (f"<Context "
            f"searches_pending={pending} "
            f"circuits={circuit_count} "
            f"pvs={pv_count} "
            f"idle={idle_count}>")
def __enter__(self):
    """Enter a ``with`` block; the Context itself is bound by ``as``."""
    context = self
    return context
def __exit__(self, exc_type, exc_value, traceback):
    """
    Leave a ``with`` block by disconnecting and joining background threads.

    Returns None, so an exception raised inside the block propagates.
    """
    self.disconnect(wait=True)
def get_pvs(self, *names, priority=0, connection_state_callback=None,
            access_rights_callback=None,
            timeout=common.CONTEXT_DEFAULT_TIMEOUT):
    """
    Return a list of PV objects, one per name, in the order given.

    The returned objects may not be connected yet; channel creation happens
    on a background thread.

    A PV is identified by its name and priority. Requesting the same name
    and priority again returns the same cached object. Any callbacks given
    here are added alongside that object's existing callbacks.

    Parameters
    ----------
    *names : strings
        any number of PV names
    priority : integer
        Used by the server to triage subscription responses when under high
        load. 0 is lowest; 99 is highest.
    connection_state_callback : callable
        Expected signature: ``f(pv, state)`` where ``pv`` is the instance
        of ``PV`` whose state has changed and ``state`` is a string
    access_rights_callback : callable
        Expected signature: ``f(pv, access_rights)`` where ``pv`` is the
        instance of ``PV`` whose state has changed and ``access_rights`` is
        a member of the caproto ``AccessRights`` enum
    timeout : number or None, optional
        Number of seconds before a CaprotoTimeoutError is raised. This
        default can be overridden for any specific operation. By default,
        fall back to the default timeout set by the Context. If None, never
        timeout.

    Raises
    ------
    ContextDisconnectedError
        If ``disconnect()`` has already been called on this Context.
    """
    if self._user_disconnected:
        raise ContextDisconnectedError("This Context is no longer usable.")

    result = []
    new_names = []  # names with no cached PV yet; these need a search
    for name in names:
        key = (name, priority)
        with self.pv_cache_lock:
            pv = self.pvs.get(key)
            if pv is None:
                pv = PV(name, priority, self, timeout)
                new_names.append(name)
                self.pvs[key] = pv
                self.pvs_needing_circuits[name].add(pv)
            if connection_state_callback is not None:
                pv.connection_state_callback.add_callback(
                    connection_state_callback, run=True)
            if access_rights_callback is not None:
                pv.access_rights_callback.add_callback(
                    access_rights_callback, run=True)
            result.append(pv)

    # NOTE(review): if a callback completes quickly, could downstream
    # listeners miss a notification? (carried over from an existing TODO)

    # Only names without an existing PV are handed to the broadcaster;
    # whether a cached search result can answer them is its concern.
    if new_names:
        self.broadcaster.search(self._search_results_queue, new_names)
    return result
def reconnect(self, keys):
    """
    Re-search for existing PVs so that they can be given new channels.

    The same PV objects are reused; each will get a new cid once the search
    is answered and a circuit is assigned.

    Parameters
    ----------
    keys : iterable of (name, priority) tuples
        Keys into ``self.pvs``. May be a one-shot iterator (it is consumed
        exactly once).

    Raises
    ------
    KeyError
        If a key does not correspond to a PV created by this Context.
    """
    names = []
    for key in keys:
        with self.pv_cache_lock:
            pv = self.pvs[key]
        name, _ = key
        names.append(name)
        # If there is a cached search result for this name, expire it so a
        # fresh search is performed rather than reusing the old address.
        with self.broadcaster._search_lock:
            self.broadcaster.search_results.pop(name, None)
        with self.pv_cache_lock:
            self.pvs_needing_circuits[name].add(pv)
    # Mirror get_pvs: do not bother the broadcaster with an empty search.
    if names:
        self.broadcaster.search(self._search_results_queue, names)
def _process_search_results_loop(self):
    """
    Background loop that turns search results into channels.

    Runs on the 'search' thread started in ``__init__`` until
    ``self._close_event`` is set. Each queue item is ``(address, names)``;
    the sending side of the queue is held by
    ``SharedBroadcaster.command_loop``. For every PV still waiting on one of
    those names, a VirtualCircuitManager is obtained (created if necessary)
    and a ClientChannel is allocated. The channel-creation requests are then
    sent in one batch per circuit.

    Failures while opening a circuit or sending the creation requests are
    logged instead of propagating. An exception escaping here would end
    this thread and silently stop every future PV of this Context from
    connecting. A PV whose circuit could not be opened is put back into
    ``pvs_needing_circuits``.
    """
    self.log.debug('Context search-results processing loop has '
                   'started.')
    while not self._close_event.is_set():
        try:
            address, names = self._search_results_queue.get(timeout=0.5)
        except Empty:
            # Loop around so the close event is re-checked periodically
            # instead of blocking on the queue forever.
            continue

        channels_grouped_by_circuit = defaultdict(list)
        # Assign each PV a VirtualCircuitManager for managing a socket
        # and tracking circuit state, as well as a ClientChannel for
        # tracking channel state.
        for name in names:
            search_logger.debug('Connecting %s on circuit with %s:%d', name, *address,
                                extra={'pv': name,
                                       'their_address': address,
                                       'our_address': self.broadcaster.broadcaster.client_address,
                                       'direction': '--->>>',
                                       'role': 'CLIENT'})
            # There could be multiple PVs with the same name and different
            # priority. There could also be NO PVs with this name that need
            # a circuit, because this may be a duplicate search response
            # (which we are supposed to ignore).
            with self.pv_cache_lock:
                pvs = self.pvs_needing_circuits.pop(name, set())
            for pv in pvs:
                # Get (make if necessary) a VirtualCircuitManager. This is
                # where TCP socket creation happens, so it can fail.
                try:
                    cm = self.get_circuit_manager(address, pv.priority)
                except Exception:
                    self.log.exception(
                        'Could not obtain a circuit to %s:%d for PV %r '
                        '(priority %d); returning it to the set of PVs '
                        'awaiting a circuit.', *address, name, pv.priority)
                    with self.pv_cache_lock:
                        self.pvs_needing_circuits[name].add(pv)
                    continue
                circuit = cm.circuit
                pv.circuit_manager = cm
                # NOTE: we are not following the suggestion to use the same
                # cid as in the search. This simplifies things between the
                # broadcaster and Context.
                cid = circuit.new_channel_id()
                chan = ca.ClientChannel(name, circuit, cid=cid)
                cm.channels[cid] = chan
                cm.pvs[cid] = pv
                channels_grouped_by_circuit[cm].append(chan)
                pv.circuit_ready.set()

        # Initiate channel creation with the server.
        for cm, channels in channels_grouped_by_circuit.items():
            commands = [chan.create() for chan in channels]
            try:
                cm.send(*commands)
            except Exception:
                if cm.dead.is_set():
                    self.log.debug("Circuit died while we were trying "
                                   "to create the channel. We will "
                                   "keep attempting this until it "
                                   "works.")
                    # When the Context creates a new circuit, we will end
                    # up here again. No big deal.
                    continue
                # Keep the thread alive; other circuits are unaffected.
                self.log.exception(
                    'Failed to send channel-creation requests on %r', cm)
    self.log.debug('Context search-results processing thread has exited.')
def get_circuit_manager(self, address, priority):
    """
    Return the VirtualCircuitManager for ``(address, priority)``.

    The manager wraps a caproto.VirtualCircuit and a TCP socket. A new one
    is created (opening the socket) when none is registered for this key or
    the registered one is dead. Creation happens while holding
    ``_lock_during_get_circuit_manager``, so concurrent callers never build
    two managers for the same key.
    """
    key = (address, priority)
    with self._lock_during_get_circuit_manager:
        existing = self.circuit_managers.get(key)
        if existing is not None and not existing.dead.is_set():
            return existing
        version = self.broadcaster.server_protocol_versions[address]
        fresh_circuit = ca.VirtualCircuit(our_role=ca.CLIENT,
                                          address=address,
                                          priority=priority,
                                          protocol_version=version)
        manager = VirtualCircuitManager(self, fresh_circuit, self.selector)
        self.circuit_managers[key] = manager
        return manager
def _activate_subscriptions(self):
    """
    Background loop that sends EventAddRequests for pending subscriptions.

    Runs on the 'activate_subscriptions' thread until ``self._close_event``
    is set. Each pass drains ``self.subscriptions_to_activate`` (circuit
    manager -> set of Subscriptions) under ``self.subscriptions_lock``. It
    then sends the resulting requests in batches bounded by
    ``common.EVENT_ADD_BATCH_MAX_BYTES``. Passes are spaced
    ``common.RESTART_SUBS_PERIOD`` seconds apart unless
    ``self.activate_subscriptions_now`` is set to wake the loop early.

    If sending a batch fails, the remaining batches for that circuit are
    skipped. A dead circuit is expected and is logged at debug level. Any
    other failure is logged with its traceback instead of being swallowed
    silently.
    """
    def event_add_requests(subs):
        "Yield an EventAddRequest for each active Subscription in *subs*."
        for sub in subs:
            command = sub.compose_command()
            # compose_command() returns None if this Subscription is
            # inactive (meaning there are no user callbacks attached). It
            # will send an EventAddRequest on its own if/when the user does
            # add any callbacks, so we can skip it here.
            if command is not None:
                yield command

    while not self._close_event.is_set():
        t = time.monotonic()
        with self.subscriptions_lock:
            items = list(self.subscriptions_to_activate.items())
            self.subscriptions_to_activate.clear()
        for cm, subs in items:
            for batch in batch_requests(event_add_requests(subs),
                                        common.EVENT_ADD_BATCH_MAX_BYTES):
                try:
                    cm.send(*batch)
                except Exception:
                    if cm.dead.is_set():
                        self.log.debug("Circuit died while we were "
                                       "trying to activate "
                                       "subscriptions. We will "
                                       "keep attempting this until it "
                                       "works.")
                        # When the Context creates a new circuit, we will
                        # end up here again. No big deal.
                    else:
                        self.log.exception(
                            'Failed to send EventAddRequests on %r', cm)
                    break
        wait_time = max(0, (common.RESTART_SUBS_PERIOD -
                            (time.monotonic() - t)))
        self.activate_subscriptions_now.wait(wait_time)
        self.activate_subscriptions_now.clear()
    self.log.debug('Context restart-subscriptions thread exiting')
def disconnect(self, *, wait=True):
    """
    Shut down this Context; it cannot be used afterwards.

    Marks the Context as user-disconnected (``get_pvs`` will then raise),
    signals the background loops to stop, and disconnects every circuit
    manager that is still connected. Then, even if that failed, it
    unregisters from the broadcaster, clears the circuit managers and stops
    the selector thread.

    Parameters
    ----------
    wait : bool, optional
        If True (default), join the search, subscription-activation and
        selector threads before returning.
    """
    self._user_disconnected = True
    try:
        self._close_event.set()
        managers = list(self.circuit_managers.values())
        count = len(managers)
        any_disconnected = False
        for index, manager in enumerate(managers, start=1):
            if not manager.connected:
                continue
            self.log.debug('Disconnecting circuit %d/%d: %s',
                           index, count, manager)
            manager.disconnect()
            any_disconnected = True
        if any_disconnected:
            self.log.debug('All circuits disconnected')
    finally:
        self.broadcaster.remove_listener(self)

        # Forget all circuit state.
        self.log.debug('Clearing circuit managers')
        self.circuit_managers.clear()

        self.log.debug("Stopping SelectorThread of the context")
        self.selector.stop()

        if wait:
            self._process_search_results_thread.join()
            self._activate_subscriptions_thread.join()
            self.selector.thread.join()

        self.log.debug('Context disconnection complete')
def __del__(self):
    """
    Best-effort cleanup when the Context is garbage collected.

    Errors from ``disconnect`` are ignored. The selector, broadcaster and
    circuit-manager references are always dropped afterwards.
    """
    try:
        try:
            self.disconnect(wait=False)
        except Exception:
            pass  # never let cleanup errors escape a finalizer
    finally:
        for attr in ('selector', 'broadcaster', 'circuit_managers'):
            setattr(self, attr, None)
class VirtualCircuitManager:
    """
    Encapsulates a VirtualCircuit, a TCP socket, and additional state

    This object should never be instantiated directly by user code. It is used
    internally by the Context. Its methods may be touched by user code, but
    this is rarely necessary.
    """
    __slots__ = ('context', 'circuit', 'channels', 'ioids', '_ioid_counter',
                 'subscriptions', '_ready', 'log',
                 'socket', 'selector', 'pvs', 'all_created_pvnames',
                 'dead', 'process_queue', 'processing',
                 '_subscriptionid_counter', 'user_callback_executor',
                 'last_tcp_receipt', '__weakref__', '_tags')

    def __init__(self, context, circuit, selector,
                 timeout=common.GLOBAL_DEFAULT_TIMEOUT):
        """
        Open the TCP connection for *circuit* and perform the handshake.

        Parameters
        ----------
        context : Context
            The owning Context. It supplies the host/client names,
            ``max_workers`` and the broadcaster's record of server protocol
            versions.
        circuit : caproto.VirtualCircuit
            Its SERVER state must be IDLE (not yet connected).
        selector : SelectorThread
            The new socket is registered with it via ``add_socket``.
        timeout : number or None, optional
            Seconds to wait for the handshake to complete.

        Raises
        ------
        OSError
            If the TCP connection cannot be established.
        CaprotoRuntimeError
            If the circuit's SERVER state is not IDLE.
        CaprotoTimeoutError
            If the handshake does not complete within *timeout*. The socket
            and the callback thread pool are released before raising.
        """
        self.context = context
        self.circuit = circuit  # a caproto.VirtualCircuit
        self.log = circuit.log
        self.channels = {}  # map cid to Channel
        self.pvs = {}  # map cid to PV
        # map ioid to an info dict ('pv', 'event', 'request', 'deadline',
        # optionally 'callback'; 'response' is added when it arrives)
        self.ioids = {}
        self.subscriptions = {}  # map subscriptionid to Subscription
        self.socket = None
        self.selector = selector
        self.user_callback_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.context.max_workers,
            thread_name_prefix='user-callback-executor')
        self.last_tcp_receipt = None
        # keep track of all PV names that are successfully connected to within
        # this circuit. This is to be cleared upon disconnection:
        self.all_created_pvnames = []
        self.dead = threading.Event()
        self._ioid_counter = ThreadsafeCounter()
        self._subscriptionid_counter = ThreadsafeCounter()
        self._ready = threading.Event()

        # Connect.
        if self.circuit.states[ca.SERVER] is ca.IDLE:
            self.socket = socket.create_connection(self.circuit.address)
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            self.circuit.our_address = self.socket.getsockname()
            # This dict is passed to the loggers.
            self._tags = {'their_address': self.circuit.address,
                          'our_address': self.circuit.our_address,
                          'direction': '<<<---',
                          'role': repr(self.circuit.our_role)}
            self.selector.add_socket(self.socket, self)
            self.send(ca.VersionRequest(self.circuit.priority,
                                        ca.DEFAULT_PROTOCOL_VERSION),
                      ca.HostNameRequest(self.context.host_name),
                      ca.ClientNameRequest(self.context.client_name),
                      extra=self._tags)
        else:
            raise CaprotoRuntimeError("Cannot connect. States are {} "
                                      "".format(self.circuit.states))

        # Old versions of the protocol do not send a VersionResponse at TCP
        # connection time, so set this Event manually rather than waiting for
        # it to be set by receipt of a VersionResponse.
        if self.context.broadcaster.server_protocol_versions[self.circuit.address] < 12:
            self._ready.set()
        ready = self._ready.wait(timeout=timeout)
        if not ready:
            # Nobody will ever use this manager. Release the socket (still
            # registered with the selector) and the callback thread pool
            # instead of leaking them.
            try:
                self._disconnected(reconnect=False)
            finally:
                self.user_callback_executor.shutdown(wait=False)
            host, port = self.circuit.address
            raise CaprotoTimeoutError(f"Circuit with server at {host}:{port} "
                                      f"did not connect within "
                                      f"{float(timeout):.3}-second timeout.")

    def __repr__(self):
        return (f"<VirtualCircuitManager circuit={self.circuit} "
                f"pvs={len(self.pvs)} ioids={len(self.ioids)} "
                f"subscriptions={len(self.subscriptions)}>")

    @property
    def connected(self):
        "True if the circuit's CLIENT state is CONNECTED."
        return self.circuit.states[ca.CLIENT] is ca.CONNECTED

    def send(self, *commands, extra=None):
        """
        Send *commands* over the TCP socket.

        The VirtualCircuit is informed of the commands (advancing its state
        machine) and converts them to bytes. This does nothing once the
        socket has been released by ``_disconnected``.
        """
        # Turn the crank: inform the VirtualCircuit that these commands will
        # be sent, and convert them to buffers.
        sock = self.socket
        if sock is not None:
            buffers_to_send = self.circuit.send(*commands, extra=extra)
            sock.sendall(b"".join(buffers_to_send))

    def received(self, bytes_recv, address):
        """Receive and process and next command from the virtual circuit.

        This will be run on the recv thread.

        Returns ``ca.DISCONNECTED`` when *bytes_recv* is empty (telling the
        selector to drop the socket); otherwise the number of bytes the
        circuit reports it still needs.
        """
        self.last_tcp_receipt = time.monotonic()
        commands, num_bytes_needed = self.circuit.recv(bytes_recv)
        for c in commands:
            self._process_command(c)
        if not bytes_recv:
            # Tell the selector to remove our socket
            return ca.DISCONNECTED
        return num_bytes_needed

    def events_off(self):
        """
        Suspend updates to all subscriptions on this circuit.

        This may be useful if the server produces updates faster than the
        client can process them.
        """
        self.send(ca.EventsOffRequest())

    def events_on(self):
        """
        Reactivate updates to all subscriptions on this circuit.
        """
        self.send(ca.EventsOnRequest())

    def _process_command(self, command):
        """
        Apply one received *command* to the circuit and dispatch its effects.

        A CaprotoError attributed to a channel is logged and ignored. Any
        other CaprotoError is logged and disconnects this circuit.
        """
        try:
            self.circuit.process_command(command)
        except ca.CaprotoError as ex:
            if hasattr(ex, 'channel'):
                channel = ex.channel
                self.log.warning('Invalid command %s for Channel %s in state %s',
                                 command, channel, channel.states,
                                 exc_info=ex)
                # channel exceptions are not fatal
                return
            else:
                self.log.error('Invalid command %s for VirtualCircuit %s in '
                               'state %s', command, self, self.circuit.states,
                               exc_info=ex)
                # circuit exceptions are fatal; exit the loop
                self.disconnect()
                return

        tags = self._tags
        if command is ca.DISCONNECTED:
            self._disconnected()
        elif isinstance(command, (ca.VersionResponse,)):
            assert self.connected  # double check that the state machine agrees
            self._ready.set()
        elif isinstance(command, (ca.ReadNotifyResponse,
                                  ca.ReadResponse,
                                  ca.WriteNotifyResponse)):
            ioid_info = self.ioids.pop(command.ioid)
            deadline = ioid_info['deadline']
            pv = ioid_info['pv']
            tags = tags.copy()
            tags['pv'] = pv.name
            if deadline is not None and time.monotonic() > deadline:
                self.log.warning("Ignoring late response with ioid=%d regarding "
                                 "PV named %s because "
                                 "it arrived %.3f seconds after the deadline "
                                 "specified by the timeout.", command.ioid,
                                 pv.name, time.monotonic() - deadline)
                return

            event = ioid_info.get('event')
            if event is not None:
                # If PV.read() or PV.write() are waiting on this response,
                # they hold a reference to ioid_info. We will use that to
                # provide the response to them and then set the Event that they
                # are waiting on.
                ioid_info['response'] = command
                event.set()
            callback = ioid_info.get('callback')
            if callback is not None:
                try:
                    self.user_callback_executor.submit(callback, command)
                except RuntimeError:
                    if self.dead.is_set():
                        # if we are trying to process updates while
                        # shutting down the submit will fail. In that
                        # case we should drop the exception on the
                        # floor and move on
                        return
                    # otherwise raise and let someone else deal with
                    # the mess
                    raise
        elif isinstance(command, ca.EventAddResponse):
            try:
                sub = self.subscriptions[command.subscriptionid]
            except KeyError:
                # This subscription has been removed. We assume that this
                # response was in flight before the server processed our
                # unsubscription.
                pass
            else:
                # This method submits jobs to the Contexts's
                # ThreadPoolExecutor for user callbacks.
                sub.process(command)
                tags = tags.copy()
                tags['pv'] = sub.pv.name
        elif isinstance(command, ca.AccessRightsResponse):
            pv = self.pvs[command.cid]
            pv.access_rights_changed(command.access_rights)
            tags = tags.copy()
            tags['pv'] = pv.name
        elif isinstance(command, ca.EventCancelResponse):
            # TODO Any way to add the pv name to tags here?
            ...
        elif isinstance(command, ca.CreateChanResponse):
            pv = self.pvs[command.cid]
            chan = self.channels[command.cid]
            self.all_created_pvnames.append(pv.name)
            with pv.component_lock:
                pv.channel = chan
                pv.channel_ready.set()
            pv.connection_state_changed('connected', chan)
            tags = tags.copy()
            tags['pv'] = pv.name
        elif isinstance(command, (ca.ServerDisconnResponse,
                                  ca.ClearChannelResponse)):
            pv = self.pvs[command.cid]
            pv.connection_state_changed('disconnected', None)
            tags = tags.copy()
            tags['pv'] = pv.name
            # NOTE: pv remains valid until server goes down
        elif isinstance(command, ca.EchoResponse):
            # The important effect here is that it will have updated
            # self.last_tcp_receipt when the bytes flowed through
            # self.received.
            ...

        if isinstance(command, ca.Message):
            tags['bytesize'] = len(command)
            self.log.debug("%r", command, extra=tags)

    def _remove_from_context(self):
        "Drop this manager from Context.circuit_managers if it is registered."
        managers = self.context.circuit_managers
        if managers is None:
            # Context.__del__ has already released its circuit managers.
            return
        # The Context keys its managers on (address, priority).
        key = (self.circuit.address, self.circuit.priority)
        if managers.get(key) is self:
            managers.pop(key, None)

    def _disconnected(self, *, reconnect=True):
        """
        Tear down this circuit after it has been lost or closed.

        Idempotent: returns immediately once ``self.dead`` is set.

        The steps are:

        1. Mark the circuit disconnected and clear each PV's readiness
           events.
        2. Unregister from ``Context.circuit_managers`` and set
           ``self.dead``.
        3. Wake any PV.read/PV.write call that is waiting on a response.
        4. Expire cached search results for names created on this circuit.
        5. Notify the PVs and close the socket.
        6. If *reconnect* is true, ask the Context to re-search for all of
           this circuit's channels.
        """
        # Ensure that this method is idempotent.
        if self.dead.is_set():
            return
        tags = {'their_address': self.circuit.address}
        self.log.debug('Virtual circuit with address %s:%d has disconnected.',
                       *self.circuit.address, extra=tags)
        # Update circuit state. This will be reflected on all PVs, which
        # continue to hold a reference to this disconnected circuit.
        self.circuit.disconnect()
        for pv in self.pvs.values():
            pv.channel_ready.clear()
            pv.circuit_ready.clear()
        # Remove this manager from the Context, so that future calls to
        # Context.get_circuit_manager() create a fresh VirtualCircuit and
        # VirtualCircuitManager. Do this before setting `dead`: until then,
        # get_circuit_manager() cannot have replaced our entry, so a newer
        # manager is never dropped by mistake.
        self._remove_from_context()
        self.dead.set()
        for ioid_info in self.ioids.values():
            # Un-block any calls to PV.read() or PV.write() that are waiting on
            # responses that we now know will never arrive. They will check on
            # circuit health and raise appropriately.
            event = ioid_info.get('event')
            if event is not None:
                event.set()
        with self.context.broadcaster._search_lock:
            for n in self.all_created_pvnames:
                self.context.broadcaster.search_results.pop(n, None)
            self.all_created_pvnames.clear()
        for pv in self.pvs.values():
            pv.connection_state_changed('disconnected', None)

        # Clean up the socket if it has not yet been cleared:
        sock, self.socket = self.socket, None
        if sock is not None:
            self.selector.remove_socket(sock)
            try:
                sock.shutdown(socket.SHUT_WR)
            except OSError:
                pass
            sock.close()

        tags = {'their_address': self.circuit.address}
        if reconnect:
            # Kick off attempt to reconnect all PVs via fresh circuit(s).
            self.log.debug('Kicking off reconnection attempts for %d PVs '
                           'disconnected from %s:%d....',
                           len(self.channels), *self.circuit.address, extra=tags)
            self.context.reconnect(((chan.name, chan.circuit.priority)
                                    for chan in self.channels.values()))
        else:
            self.log.debug('Not attempting reconnection', extra=tags)

    def disconnect(self):
        """
        Close this circuit without reconnecting and stop its callback pool.

        The executor is shut down with its default ``wait=True``, so this
        waits for already-submitted user callbacks to finish.
        """
        # _disconnected() always clears self.socket, so record beforehand
        # whether there was a live socket for the final log message.
        had_socket = self.socket is not None
        self._disconnected(reconnect=False)
        self.log.debug("Shutting down ThreadPoolExecutor for user callbacks",
                       extra={'their_address': self.circuit.address})
        self.user_callback_executor.shutdown()
        if had_socket:
            self.log.debug('Circuit manager disconnected by user')

    def __del__(self):
        # Best effort: attributes may be missing if __init__ failed early,
        # or the Context may already have released its state.
        try:
            self._disconnected(reconnect=False)
        except AttributeError:
            pass
class PV:
    """
    Represents one PV, specified by a name and priority.

    This object may exist prior to connection and persists across any
    subsequent re-connections.

    This object should never be instantiated directly by user code; rather it
    should be created by calling the ``get_pvs`` method on a ``Context``
    object.
    """
    __slots__ = ('name', 'priority', 'context', '_circuit_manager', '_channel',
                 'circuit_ready', 'channel_ready', 'access_rights',
                 'access_rights_callback', 'subscriptions',
                 'command_bundle_queue', 'component_lock', '_idle', '_in_use',
                 '_usages', 'connection_state_callback', 'log',
                 '_timeout',
                 '__weakref__')

    def __init__(self, name, priority, context, timeout):
        """
        These must be instantiated by a Context, never directly.

        Parameters
        ----------
        name : str
        priority : int
        context : Context
            The Context that owns this PV.
        timeout : number, None, or CONTEXT_DEFAULT_TIMEOUT
            Per-PV default timeout; see the ``timeout`` property.
        """
        self.name = name
        self.priority = priority
        self.context = context
        self.access_rights = None  # will be overwritten with AccessRights
        self.log = logging.LoggerAdapter(ch_logger, {'pv': self.name, 'role': 'CLIENT'})
        # Use this lock whenever we touch circuit_manager or channel.
        self.component_lock = threading.RLock()
        # Set once a circuit and channel have been assigned after a search;
        # cleared when the circuit dies.
        self.circuit_ready = threading.Event()
        # Set on CreateChanResponse; cleared by go_idle() or circuit death.
        self.channel_ready = threading.Event()
        self.connection_state_callback = CallbackHandler(self)
        self.access_rights_callback = CallbackHandler(self)

        self._timeout = timeout
        self._circuit_manager = None
        self._channel = None
        self.subscriptions = {}
        # True after go_idle(); @ensure_connected reconnects on next use.
        self._idle = False
        # go_idle() waits on this Condition until _usages drops to 0.
        # _usages presumably counts in-flight @ensure_connected calls.
        self._in_use = threading.Condition()
        self._usages = 0

    @property
    def timeout(self):
        """
        Effective default timeout.

        Valid values are:

        * CONTEXT_DEFAULT_TIMEOUT (fall back to Context.timeout)
        * a floating-point number
        * None (never timeout)
        """
        if self._timeout is common.CONTEXT_DEFAULT_TIMEOUT:
            return self.context.timeout
        else:
            return self._timeout

    @timeout.setter
    def timeout(self, val):
        self._timeout = val

    @property
    def circuit_manager(self):
        return self._circuit_manager

    @circuit_manager.setter
    def circuit_manager(self, val):
        with self.component_lock:
            self._circuit_manager = val

    @property
    def channel(self):
        return self._channel

    @channel.setter
    def channel(self, val):
        with self.component_lock:
            self._channel = val

    def access_rights_changed(self, rights):
        "Record new access *rights* and notify access-rights callbacks."
        self.access_rights = rights
        self.access_rights_callback.process(self, rights)

    def connection_state_changed(self, state, channel):
        """
        Notify connection-state callbacks and update subscription bookkeeping.

        On 'disconnected', every subscription with callbacks is flagged for
        reactivation. On 'connected', flagged subscriptions are queued on the
        Context for an EventAddRequest over the current circuit.
        """
        self.log.info('connection state changed to %s.', state)
        self.connection_state_callback.process(self, state)
        if state == 'disconnected':
            for sub in self.subscriptions.values():
                with sub.callback_lock:
                    if sub.callbacks:
                        sub.needs_reactivation = True
        if state == 'connected':
            cm = self.circuit_manager
            ctx = cm.context
            with ctx.subscriptions_lock:
                for sub in self.subscriptions.values():
                    with sub.callback_lock:
                        if sub.needs_reactivation:
                            ctx.subscriptions_to_activate[cm].add(sub)
                            sub.needs_reactivation = False

    def __repr__(self):
        if self._idle:
            state = "(idle)"
        elif self.circuit_manager is None or self.circuit_manager.dead.is_set():
            state = "(searching....)"
        else:
            state = (f"address={self.circuit_manager.circuit.address}, "
                     f"circuit_state="
                     f"{self.circuit_manager.circuit.states[ca.CLIENT]}")
            if self.connected:
                state += f", channel_state={self.channel.states[ca.CLIENT]}"
            else:
                state += " (creating...)"
        return f"<PV name={self.name!r} priority={self.priority} {state}>"

    @property
    def connected(self):
        "True if a channel exists and its CLIENT state is CONNECTED."
        channel = self.channel
        if channel is None:
            return False
        return channel.states[ca.CLIENT] is ca.CONNECTED

    def wait_for_search(self, *, timeout=common.PV_DEFAULT_TIMEOUT):
        """
        Wait for this PV to be found.

        This does not wait for the PV's Channel to be created; it merely waits
        for an address (and a VirtualCircuit) to be assigned.

        Parameters
        ----------
        timeout : number or None, optional
            Seconds to wait before a CaprotoTimeoutError is raised. Default is
            ``PV.timeout``, which falls back to Context.timeout if not set. If
            None, never timeout.
        """
        if timeout is common.PV_DEFAULT_TIMEOUT:
            timeout = self.timeout
        if not self.circuit_ready.wait(timeout=timeout):
            raise CaprotoTimeoutError("No servers responded to a search for a "
                                      "channel named {!r} within {:.3}-second "
                                      "timeout."
                                      "".format(self.name, float(timeout)))

    @ensure_connected
    def wait_for_connection(self, *, timeout=common.PV_DEFAULT_TIMEOUT):
        """
        Wait for this PV to be connected.

        Parameters
        ----------
        timeout : number or None, optional
            Seconds to wait before a CaprotoTimeoutError is raised. Default is
            ``PV.timeout``, which falls back to ``PV.context.timeout`` if not
            set. If None, never timeout.
        """
        # All of the work is done by @ensure_connected.
        pass

    def go_idle(self):
        """Request to clear this Channel to reduce load on client and server.

        A new Channel will be automatically, silently created the next time any
        method requiring a connection is called. Thus, this saves some memory
        in exchange for making the next request a bit slower, as it has to
        redo the handshake with the server first.

        If there are any subscriptions with callbacks, this request will be
        ignored. If the PV is in the process of connecting, this request will
        be ignored. If there are any actions in progress (read, write) this
        request will be processed when they are complete.
        """
        for sub in self.subscriptions.values():
            if sub.callbacks:
                return

        with self._in_use:
            if not self.channel_ready.is_set():
                return
            # Wait until no other methods that employ @self.ensure_connected
            # are in process.
            self._in_use.wait_for(lambda: self._usages == 0)
            # No other threads are using the connection, and we are holding the
            # self._in_use Condition's lock, so we can safely close the
            # connection. The next thread to acquire the lock will re-connect
            # after it acquires the lock.
            try:
                self.channel_ready.clear()
                self.circuit_manager.send(self.channel.clear(),
                                          extra={'pv': self.name})
            except OSError:
                # the socket is dead-dead, do nothing
                ...
            self._idle = True

    @ensure_connected
    def read(self, *, wait=True, callback=None,
             timeout=common.PV_DEFAULT_TIMEOUT,
             data_type=None, data_count=None, notify=True):
        """Request a fresh reading.

        Can do one or both of:

        - Block while waiting for the response, and return it.
        - Pass the response to callback, with or without blocking.

        Parameters
        ----------
        wait : boolean
            If True (default) block until a matching response is
            received from the server. Raises CaprotoTimeoutError if that
            response is not received within the time specified by the `timeout`
            parameter.
        callback : callable or None
            Called with the response as its argument when received.
        timeout : number or None, optional
            Seconds to wait before a CaprotoTimeoutError is raised. Default is
            ``PV.timeout``, which falls back to ``PV.context.timeout`` if not
            set. If None, never timeout.
        data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional
            Request specific data type or a class of data types, matched to the
            channel's native data type. Default is Channel's native data type.
        data_count : integer, optional
            Requested number of values. Default is the channel's native data
            count.
        notify: boolean, optional
            Send a ``ReadNotifyRequest`` instead of a ``ReadRequest``. True by
            default.

        Returns
        -------
        response : the read response, or None when ``wait=False``

        Raises
        ------
        CaprotoTimeoutError
            If no response arrives within *timeout*.
        DeadCircuitError
            If the circuit died before a response arrived. @ensure_connected
            catches this and retries on a new circuit.
        """
        if timeout is common.PV_DEFAULT_TIMEOUT:
            timeout = self.timeout
        cm, chan = self._circuit_manager, self._channel
        ioid = cm._ioid_counter()
        command = chan.read(ioid=ioid, data_type=data_type,
                            data_count=data_count, notify=notify)
        # Stash the ioid to match the response to the request.
        event = threading.Event()
        deadline = time.monotonic() + timeout if timeout is not None else None
        ioid_info = dict(event=event, pv=self, request=command,
                         deadline=deadline)
        if callback is not None:
            ioid_info['callback'] = callback
        cm.ioids[ioid] = ioid_info

        cm.send(command, extra={'pv': self.name})
        if not wait:
            return

        # The circuit_manager will put a reference to the response into
        # ioid_info and then set event.
        if not event.wait(timeout=timeout):
            host, port = cm.circuit.address
            raise CaprotoTimeoutError(
                f"Server at {host}:{port} did "
                f"not respond to attempt to read channel named "
                f"{self.name!r} within {float(timeout):.3}-second timeout. "
                f"The ioid of the expected response is {ioid}."
            )

        if 'response' not in ioid_info:
            # The event was set by VirtualCircuitManager._disconnected: this
            # circuit died sometime during this function call. The exception
            # raised here will be caught by @ensure_connected, which will
            # retry the function call in hopes of getting a working circuit
            # until our `timeout` has been used up.
            raise DeadCircuitError()
        return ioid_info['response']

    @ensure_connected
    def write(self, data, *, wait=True, callback=None,
              timeout=common.PV_DEFAULT_TIMEOUT,
              notify=None, data_type=None, data_count=None):
        """
        Write a new value. Optionally, request confirmation from the server.

        Can do one or both of:

        - Block while waiting for the response, and return it.
        - Pass the response to callback, with or without blocking.

        Parameters
        ----------
        data : str, int, or float or any Iterable of these
            Value(s) to write.
        wait : boolean
            If True (default) block until a matching WriteNotifyResponse is
            received from the server. Raises CaprotoTimeoutError if that
            response is not received within the time specified by the `timeout`
            parameter.
        callback : callable or None
            Called with the WriteNotifyResponse as its argument when received.
        timeout : number or None, optional
            Seconds to wait before a CaprotoTimeoutError is raised. Default is
            ``PV.timeout``, which falls back to ``PV.context.timeout`` if not
            set. If None, never timeout.
        notify : boolean or None
            If None (default), set to True if wait=True or callback is set.
            Can be manually set to True or False. Will raise ValueError if set
            to False while wait=True or callback is set.
        data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional
            Write specific data type or a class of data types, matched to the
            channel's native data type. Default is Channel's native data type.
        data_count : integer, optional
            Requested number of values. Default is the channel's native data
            count.

        Returns
        -------
        response : WriteNotifyResponse, or None when ``wait=False``

        Raises
        ------
        CaprotoValueError
            If notify is False while wait=True or a callback is given.
        CaprotoTimeoutError
            If no response arrives within *timeout*.
        DeadCircuitError
            If the circuit died before a response arrived. @ensure_connected
            catches this and retries on a new circuit.
        """
        if timeout is common.PV_DEFAULT_TIMEOUT:
            timeout = self.timeout
        if notify is None:
            notify = (wait or callback is not None)
        # Validate before allocating an ioid or composing the request.
        if not notify and (wait or callback is not None):
            raise CaprotoValueError("Must set notify=True in order to use "
                                    "`wait` or `callback` because, without a "
                                    "notification of 'put-completion' from the "
                                    "server, there is nothing to wait on or to "
                                    "trigger a callback.")
        cm, chan = self._circuit_manager, self._channel
        ioid = cm._ioid_counter()
        command = chan.write(data, ioid=ioid, notify=notify,
                             data_type=data_type, data_count=data_count)
        if notify:
            event = threading.Event()
            deadline = time.monotonic() + timeout if timeout is not None else None
            ioid_info = dict(event=event, pv=self, request=command,
                             deadline=deadline)
            if callback is not None:
                ioid_info['callback'] = callback
            # do not need to lock this, locking happens in circuit command
            cm.ioids[ioid] = ioid_info

        cm.send(command, extra={'pv': self.name})
        if not wait:
            return

        # The circuit_manager will put a reference to the response into
        # ioid_info and then set event.
        if not event.wait(timeout=timeout):
            if cm.dead.is_set():
                # This circuit has died sometime during this function call.
                # The exception raised here will be caught by
                # @ensure_connected, which will retry the function call
                # in hopes of getting a working circuit until our `timeout` has
                # been used up.
                raise DeadCircuitError()
            host, port = cm.circuit.address
            raise CaprotoTimeoutError(
                f"Server at {host}:{port} did "
                f"not respond to attempt to write to channel named "
                f"{self.name!r} within {float(timeout):.3}-second timeout. "
                f"The ioid of the expected response is {ioid}."
            )
        if 'response' not in ioid_info:
            # The event was set by VirtualCircuitManager._disconnected, not by
            # a response. Previously this fell through to a KeyError; raise
            # DeadCircuitError so that @ensure_connected retries instead.
            raise DeadCircuitError()
        return ioid_info['response']

    def subscribe(self, data_type=None, data_count=None,
                  low=0.0, high=0.0, to=0.0, mask=None):
        """
        Start a new subscription to which user callback may be added.

        Parameters
        ----------
        data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional
            Request specific data type or a class of data types, matched to the
            channel's native data type. Default is Channel's native data type.
        data_count : integer, optional
            Requested number of values. Default is the channel's native data
            count.
        low, high, to : float, optional
            deprecated by Channel Access, not yet implemented by caproto
        mask : SubscriptionType, optional
            Subscribe to selective updates.

        Returns
        -------
        subscription : Subscription

        Examples
        --------

        Define a subscription.

        >>> sub = pv.subscribe()

        Add a user callback. The subscription will be transparently activated
        (i.e. an ``EventAddRequest`` will be sent) when the first user callback
        is added.

        >>> sub.add_callback(my_func)

        Multiple callbacks may be added to the same subscription.

        >>> sub.add_callback(another_func)

        See the docstring for :class:`Subscription` for more.
        """
        # A Subscription is uniquely identified by the Signature created by its
        # args and kwargs.
        bound = SUBSCRIBE_SIG.bind(data_type, data_count, low, high, to, mask)
        key = tuple(bound.arguments.items())
        try:
            sub = self.subscriptions[key]
        except KeyError:
            sub = Subscription(self,
                               data_type, data_count,
                               low, high, to, mask)
            self.subscriptions[key] = sub
        # The actual EPICS messages will not be sent until the user adds
        # callbacks via sub.add_callback(user_func).
        return sub

    def unsubscribe_all(self):
        "Clear all subscriptions. (Remove all user callbacks from them.)"
        for sub in self.subscriptions.values():
            sub.clear()

    @ensure_connected
    def time_since_last_heard(self, timeout=common.PV_DEFAULT_TIMEOUT):
        """
        Seconds since last message from the server that provides this channel.

        The time is reset to 0 whenever we receive a TCP message related to
        user activity *or* a Beacon. Servers are expected to send Beacons at
        regular intervals. If we do not receive either a Beacon or TCP message,
        we initiate an Echo over TCP, to which the server is expected to
        promptly respond.

        Therefore, the time reported here should not much exceed
        ``EPICS_CA_CONN_TMO`` (default 30 seconds unless overriden by that
        environment variable) if the server is healthy.

        If the server fails to send a Beacon on schedule *and* fails to reply to
        an Echo, the server is assumed dead. A warning is issued, and all PVs
        are disconnected to initiate a reconnection attempt.

        Parameters
        ----------
        timeout : number or None, optional
            Seconds to wait before a CaprotoTimeoutError is raised. Default is
            ``PV.timeout``, which falls back to ``PV.context.timeout`` if not
            set. If None, never timeout.
        """
        address = self.circuit_manager.circuit.address
        return self.context.broadcaster.time_since_last_heard()[address]
class CallbackHandler:
    """
    Fan calls out to a collection of weakly-referenced user callbacks.

    Callbacks are not executed on the calling thread; each one is submitted
    to ``pv.circuit_manager.user_callback_executor``. The arguments of the
    most recent :meth:`process` call are remembered so that a newly added
    callback can be run with them right away (``add_callback(..., run=True)``).
    """

    def __init__(self, pv):
        # NOTE: not a WeakValueDictionary or WeakSet as PV is unhashable...
        # Maps integer token -> weakref.ref / weakref.WeakMethod.
        self.callbacks = {}
        self.pv = pv
        self._callback_id = 0
        self.callback_lock = threading.RLock()
        # (args, kwargs) of the most recent process() call, or None if
        # process() has never been called.
        self._last_call_values = None

    def add_callback(self, func, run=False):
        """
        Register ``func`` and return an integer token usable for removal.

        Only a weak reference to ``func`` is kept; once ``func`` is garbage
        collected it is removed automatically.

        Parameters
        ----------
        func : callable
            Called as ``func(*args, **kwargs)`` with the arguments given to
            :meth:`process`.
        run : bool, optional
            If True and :meth:`process` has been called before, schedule
            ``func`` -- and only ``func`` -- with the most recent arguments.
            Callbacks registered earlier are not re-notified.

        Returns
        -------
        token : int
        """
        def removed(_):
            self.remove_callback(cb_id)  # defined below inside the lock

        if inspect.ismethod(func):
            # A plain weakref to a bound method dies immediately; WeakMethod
            # follows the lifetime of the underlying instance instead.
            ref = weakref.WeakMethod(func, removed)
        else:
            # TODO: strong reference to non-instance methods?
            ref = weakref.ref(func, removed)

        with self.callback_lock:
            cb_id = self._callback_id
            self._callback_id += 1
            self.callbacks[cb_id] = ref
            # Read under the same lock as the insertion so that, combined
            # with process(), func sees each update exactly once.
            last_call_values = self._last_call_values

        if run and last_call_values is not None:
            args, kwargs = last_call_values
            # Schedule only the new callback; calling self.process() here
            # would re-notify every previously registered callback.
            self._submit_callback(func, args, kwargs)
        return cb_id

    def remove_callback(self, token):
        """Forget the callback registered under ``token`` (no-op if absent)."""
        with self.callback_lock:
            self.callbacks.pop(token, None)

    def _submit_callback(self, callback, args, kwargs):
        """
        Submit ``callback(*args, **kwargs)`` to the circuit's executor.

        Returns True if the job was submitted, False if the executor refused
        it because the circuit is dead. Any other RuntimeError propagates.
        """
        circuit_manager = self.pv.circuit_manager
        try:
            circuit_manager.user_callback_executor.submit(
                callback, *args, **kwargs)
        except RuntimeError:
            if circuit_manager.dead.is_set():
                # If the circuit is dead, so is the executor: forgive.
                return False
            # Otherwise raise and let someone else deal with the mess.
            raise
        return True

    def process(self, *args, **kwargs):
        """
        Schedule every live callback with ``(*args, **kwargs)``.

        This is a fast operation that submits jobs to the circuit manager's
        ThreadPoolExecutor and then returns. References whose target has
        been garbage collected are pruned along the way.
        """
        with self.callback_lock:
            callbacks = list(self.callbacks.items())
            self._last_call_values = (args, kwargs)

        to_remove = []
        for cb_id, ref in callbacks:
            callback = ref()
            if callback is None:
                to_remove.append(cb_id)
                continue
            if not self._submit_callback(callback, args, kwargs):
                # Dead circuit: stop submitting, but still prune the dead
                # references collected so far.
                break

        if to_remove:
            with self.callback_lock:
                for remove_id in to_remove:
                    self.callbacks.pop(remove_id, None)
class Subscription(CallbackHandler):
    """
    Represents one subscription, specified by a PV and configurational parameters.

    It may fan out to zero, one, or multiple user-registered callback
    functions.

    This object should never be instantiated directly by user code; rather
    it should be made by calling the ``subscribe()`` method on a ``PV`` object.
    """
    def __init__(self, pv, data_type, data_count, low, high, to, mask):
        super().__init__(pv)
        # Stash everything, but do not send any EPICS messages until the first
        # user callback is attached.
        self.data_type = data_type
        self.data_count = data_count
        self.low = low
        self.high = high
        self.to = to
        self.mask = mask
        self.subscriptionid = None
        self.most_recent_response = None
        self.needs_reactivation = False
        # This is related to back-compat for user callbacks that have the old
        # signature, f(response).
        self.__wrapper_weakrefs = set()

    @property
    def log(self):
        return self.pv.log

    def __repr__(self):
        return f"<Subscription to {self.pv.name!r}, id={self.subscriptionid}>"

    def _subscribe(self, timeout=common.PV_DEFAULT_TIMEOUT):
        """This is called automatically after the first callback is added.
        """
        cm = self.pv.circuit_manager
        if cm is None:
            # We are currently disconnected (perhaps have not yet connected).
            # When the PV connects, this subscription will be added.
            with self.callback_lock:
                self.needs_reactivation = True
        else:
            # We are (or very recently were) connected. In the rare event
            # where cm goes dead in the interim, subscription will be retried
            # by the activation loop.
            ctx = cm.context
            with ctx.subscriptions_lock:
                ctx.subscriptions_to_activate[cm].add(self)
            ctx.activate_subscriptions_now.set()

    @ensure_connected
    def compose_command(self, timeout=common.PV_DEFAULT_TIMEOUT):
        "This is used by the Context to re-subscribe in bulk after dropping."
        with self.callback_lock:
            if not self.callbacks:
                # No listeners: nothing to (re)subscribe.
                return None
            cm, chan = self.pv._circuit_manager, self.pv._channel
            subscriptionid = cm._subscriptionid_counter()
            command = chan.subscribe(data_type=self.data_type,
                                     data_count=self.data_count, low=self.low,
                                     high=self.high, to=self.to,
                                     mask=self.mask,
                                     subscriptionid=subscriptionid)
            subscriptionid = command.subscriptionid
            self.subscriptionid = subscriptionid
        # The circuit_manager needs to know the subscriptionid so that it can
        # route responses to this request.
        cm.subscriptions[subscriptionid] = self
        return command

    def clear(self):
        """
        Remove all callbacks.
        """
        with self.callback_lock:
            for cb_id in list(self.callbacks):
                self.remove_callback(cb_id)
        # Once self.callbacks is empty, self.remove_callback calls
        # self._unsubscribe for us.

    def _unsubscribe(self, timeout=common.PV_DEFAULT_TIMEOUT):
        """
        Cancel the active subscription, if any.

        This is automatically called if the number of callbacks goes to 0.
        Local state is always reset. An unsubscribe command is sent only if
        the PV still has a circuit and its channel is connected.
        """
        with self.callback_lock:
            if self.subscriptionid is None:
                # Already unsubscribed.
                return
            subscriptionid = self.subscriptionid
            self.subscriptionid = None
            self.most_recent_response = None
        # Read once: circuit_manager is None while the PV is disconnected,
        # and it may change underneath us.
        cm = self.pv.circuit_manager
        if cm is None:
            # Disconnected: there is no circuit to deregister from or to
            # send a cancellation on.
            return
        cm.subscriptions.pop(subscriptionid, None)
        chan = self.pv.channel
        if chan and chan.states[ca.CLIENT] is ca.CONNECTED:
            try:
                command = chan.unsubscribe(subscriptionid)
            except ca.CaprotoKeyError:
                pass
            else:
                cm.send(command, extra={'pv': self.pv.name})

    def process(self, command):
        """
        Fan ``command`` out to every user callback as ``func(self, command)``.

        ``most_recent_response`` is updated under the same lock that
        snapshots the callbacks. A callback being added concurrently through
        :meth:`add_callback` therefore cannot miss ``command``. Either it is
        in the snapshot, or it is replayed ``command`` as the most recent
        response.
        """
        # TODO here i think we can decouple PV update rates and callback
        # handling rates, if desirable, to not bog down performance.
        with self.callback_lock:  # RLock: re-acquired inside super().process
            self.most_recent_response = command
            super().process(self, command)
        self.log.debug("%r: %r", self.pv.name, command)

    def add_callback(self, func):
        """
        Add a callback to receive responses.

        Parameters
        ----------
        func : callable
            Expected signature: ``func(sub, response)``.

            The signature ``func(response)`` is also supported for
            backward-compatibility but will issue warnings. Support will be
            removed in a future release of caproto.

        Returns
        -------
        token : int
            Integer token that can be passed to :meth:`remove_callback`.

        .. versionchanged:: 0.5.0

           Changed the expected signature of ``func`` from ``func(response)``
           to ``func(sub, response)``.
        """
        # Handle func with signature func(response) for back-compat.
        func = adapt_old_callback_signature(func, self.__wrapper_weakrefs)

        with self.callback_lock:
            was_empty = not self.callbacks
            cb_id = super().add_callback(func)
            most_recent_response = self.most_recent_response
        if was_empty:
            # This is the first callback. Set up a subscription, which
            # should elicit a response from the server soon giving the
            # current value to this func (and any other funcs added in the
            # mean time).
            self._subscribe()
        elif most_recent_response is not None:
            # This callback is piggy-backing onto an existing subscription.
            # Send it the most recent response, unless we are still waiting
            # for that first response from the server.
            try:
                func(self, most_recent_response)
            except Exception:
                self.log.exception(
                    "Exception raised during processing most recent "
                    "response %r with new callback %r",
                    most_recent_response, func)
        return cb_id

    def remove_callback(self, token):
        """
        Remove callback using token that was returned by :meth:`add_callback`.

        Parameters
        ----------
        token : integer
            Token returned by :meth:`add_callback`.
        """
        with self.callback_lock:
            super().remove_callback(token)
            if not self.callbacks:
                # Go dormant.
                self._unsubscribe()
                self.most_recent_response = None
                self.needs_reactivation = False

    def __del__(self):
        try:
            self.clear()
        except TimeoutError:
            pass
class Batch:
    """
    Accumulate requests and then issue them all in batch.

    Parameters
    ----------
    timeout : number or None
        Number of seconds to wait before ignoring late responses. Default
        is 2.

    Examples
    --------
    Read some PVs in batch and stash the readings in a dictionary as they
    come in.

    >>> results = {}
    >>> def stash_result(name, response):
    ...     results[name] = response.data
    ...
    >>> with Batch() as b:
    ...     for pv in pvs:
    ...         b.read(pv, functools.partial(stash_result, pv.name))
    ...     # The requests are sent upon exiting this 'with' block.
    ...

    The ``results`` dictionary will be populated as responses come in.

    A Batch may be reused; each ``with`` block sends only the requests
    accumulated within it.
    """
    def __init__(self, timeout=2):
        self.timeout = timeout
        self._commands = defaultdict(list)  # map each circuit to commands
        self._ioid_infos = []

    def __enter__(self):
        return self

    def _register_response(self, pv, ioid, callback, command):
        """Route the eventual response for ``ioid`` to ``callback``."""
        # 'request' is not used within Batch itself; it is stored for
        # whoever consumes circuit_manager.ioids (presumably for logging).
        # 'deadline' is filled in by __exit__.
        ioid_info = dict(callback=callback, pv=pv, request=command)
        pv.circuit_manager.ioids[ioid] = ioid_info
        self._ioid_infos.append(ioid_info)

    def read(self, pv, callback, data_type=None, data_count=None):
        """Request a fresh reading as part of a batched request.

        Notice that, unlike :meth:`PV.read`, the callback is required. (There
        is no other way to get the result back from a batched read.)

        Parameters
        ----------
        pv : PV
        callback : callable
            Expected signature: ``f(response)``
        data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional
            Request specific data type or a class of data types, matched to the
            channel's native data type. Default is Channel's native data type.
        data_count : integer, optional
            Requested number of values. Default is the channel's native data
            count.
        """
        ioid = pv.circuit_manager._ioid_counter()
        command = pv.channel.read(ioid=ioid,
                                  data_type=data_type,
                                  data_count=data_count,
                                  notify=True)
        self._commands[pv.circuit_manager].append(command)
        self._register_response(pv, ioid, callback, command)

    def write(self, pv, data, callback=None, data_type=None, data_count=None):
        """Write a new value as part of a batched request.

        Parameters
        ----------
        pv : PV
        data : str, int, or float or any Iterable of these
            Value(s) to write.
        callback : callable, optional
            Expected signature: ``f(response)``. If omitted, the write is
            sent without requesting notification.
        data_type : {'native', 'status', 'time', 'graphic', 'control'} or ChannelType or int ID, optional
            Request specific data type or a class of data types, matched to the
            channel's native data type. Default is Channel's native data type.
        data_count : integer, optional
            Requested number of values. Default is the channel's native data
            count.
        """
        ioid = pv.circuit_manager._ioid_counter()
        command = pv.channel.write(data=data,
                                   ioid=ioid,
                                   data_type=data_type,
                                   data_count=data_count,
                                   notify=callback is not None)
        self._commands[pv.circuit_manager].append(command)
        if callback:
            self._register_response(pv, ioid, callback, command)

    def __exit__(self, exc_type, exc_value, traceback):
        """Stamp response deadlines and send all accumulated requests.

        Exceptions raised inside the ``with`` block are not suppressed.
        """
        timeout = self.timeout
        deadline = time.monotonic() + timeout if timeout is not None else None
        for ioid_info in self._ioid_infos:
            ioid_info['deadline'] = deadline
        # Detach the accumulated state before sending so a reused Batch
        # starts empty and never re-sends requests from an earlier block.
        commands_by_circuit, self._commands = self._commands, defaultdict(list)
        self._ioid_infos = []
        for circuit_manager, commands in commands_by_circuit.items():
            circuit_manager.send(*commands)
# Mirrors the signature of caproto._circuit.ClientChannel.subscribe. Binding a
# Subscription's arguments to it yields ordered (name, value) pairs that serve
# as a unique cache key. Parameter order matters because PV.subscribe binds
# its arguments positionally.
SUBSCRIBE_SIG = Signature([
    Parameter(name, Parameter.POSITIONAL_OR_KEYWORD, default=default)
    for name, default in (
        ('data_type', None),
        ('data_count', None),
        ('low', 0),
        ('high', 0),
        ('to', 0),
        ('mask', None),
    )
])