Source code for caproto.pva.sync.client

import collections
import copy
import dataclasses
import getpass
import logging
import socket
import time
import typing
from typing import Dict, Tuple

# REBASE TODO this will go away - need to rebase
from caproto import (MAX_UDP_RECV, bcast_socket, get_client_address_list,
                     get_environment_variables, pva)
from caproto.pva import (CLIENT, CONNECTED, DISCONNECTED, NEED_DATA,
                         AddressTuple, Broadcaster, CaprotoError,
                         ChannelFieldInfoResponse, ChannelGetResponse,
                         ChannelMonitorResponse, ChannelPutResponse,
                         ClientChannel, ClientVirtualCircuit,
                         ConnectionValidatedResponse,
                         ConnectionValidationRequest, CreateChannelResponse,
                         ErrorResponseReceived, MonitorSubcommand, QOSFlags,
                         SearchResponse, Subcommand, VirtualCircuit)

from ..._utils import safe_getsockname
from .._dataclass import (dataclass_from_field_desc, fill_dataclass,
                          is_pva_dataclass_instance)

# Registry of open TCP sockets, keyed by the virtual circuit they serve.
# ``send`` and ``recv`` look sockets up here by circuit.
sockets: Dict[VirtualCircuit, socket.socket] = {}
# Registry of virtual circuits keyed by server address, consulted by
# ``make_channel`` before creating a new circuit.
global_circuits: Dict[AddressTuple, VirtualCircuit] = {}

# Result of get_environment_variables(), evaluated once at import time.
env = get_environment_variables()
# General-purpose logger for this client.
logger = logging.getLogger('caproto.pva.ctx')
# Separate logger used by ``send`` to dump raw outgoing bytes at DEBUG level.
serialization_logger = logging.getLogger('caproto.pva.serialization_debug')


# Convenience functions that do both transport and caproto validation/ingest.
def send(circuit, command, pv_name=None):
    """
    Serialize ``command`` through ``circuit`` and transmit it on its socket.

    Parameters
    ----------
    circuit : VirtualCircuit
        The circuit to send on; its TCP socket is looked up in ``sockets``.
    command
        The message to hand to ``circuit.send``.
    pv_name : str, optional
        When given, forwarded to the circuit as ``extra={'pv': pv_name}``.

    Raises
    ------
    KeyError
        If no socket is registered for ``circuit``.
    """
    tags = {'pv': pv_name} if pv_name is not None else None
    buffers_to_send = circuit.send(command, extra=tags)
    to_send = b''.join(buffers_to_send)

    # sendall() loops until every byte has been handed to the kernel (or
    # raises).  A single sendmsg() call may transmit only part of the
    # payload, which would silently truncate the message on the wire, and
    # sendmsg() is not available on every platform.
    sockets[circuit].sendall(to_send)

    if serialization_logger.isEnabledFor(logging.DEBUG):
        serialization_logger.debug('-> %d bytes: %r', len(to_send), to_send)


def recv(circuit):
    """
    Read one chunk from the circuit's socket and parse it into commands.

    Up to 4096 bytes are read from the socket registered for ``circuit`` and
    fed to ``circuit.recv``.  Each parsed command is passed to
    ``circuit.process_command`` and collected, stopping at the first
    ``NEED_DATA`` marker.

    Returns
    -------
    collections.deque
        The commands parsed from this chunk, in arrival order.
    """
    payload = sockets[circuit].recv(4096)
    received = collections.deque()

    for command, _remaining in circuit.recv(payload):
        if command is NEED_DATA:
            # Nothing more can be parsed from this chunk.
            return received
        circuit.process_command(command)
        received.append(command)

    return received


def make_broadcaster_socket() -> Tuple[socket.socket, int]:
    """
    Create a broadcast-capable UDP socket bound to an OS-chosen port.

    Returns
    -------
    udp_sock : socket.socket
        The UDP socket, bound to ``('', 0)``.

    port : int
        The port number the socket ended up bound to.
    """
    sock = bcast_socket()
    # Port 0 lets the operating system pick any free port.
    sock.bind(('', 0))
    bound_address = sock.getsockname()
    bound_port = bound_address[1]
    logger.debug('Bound to UDP port %d for search', bound_port)
    return sock, bound_port


def search(pv, udp_sock, udp_port, timeout, max_retries=2):
    """
    Search for a PV over the network by broadcasting over UDP.

    The search request is sent to every address returned by
    ``get_client_address_list(protocol='PVA')`` and re-sent every
    ``timeout / max(max_retries, 1)`` seconds until a positive response
    arrives or ``timeout`` seconds have elapsed.

    Parameters
    ----------
    pv
        Passed through to ``Broadcaster.search``.  ``make_channel`` in this
        module passes a one-element list of PV names.
    udp_sock : socket.socket
        The UDP socket to search with.  Its timeout is changed temporarily
        and restored before returning.
    udp_port : int
        The port ``udp_sock`` is bound to.
    timeout : float
        Overall time, in seconds, to wait for a positive response.
    max_retries : int, optional
        Number of intervals ``timeout`` is divided into for re-sending the
        request.  Values below 1 are treated as 1.

    Returns
    -------
    (host, port) : tuple
        The server address and port from the first positive response.

    Raises
    ------
    TimeoutError
        If no server reports the PV as found within ``timeout`` seconds.
    """
    broadcaster = Broadcaster(our_role=CLIENT, broadcast_port=udp_port)
    broadcaster.client_address = safe_getsockname(udp_sock)

    def send_search(message):
        bytes_to_send = broadcaster.send(message)
        for host, port in get_client_address_list(protocol='PVA'):
            udp_sock.sendto(bytes_to_send, (host, port))
            logger.debug('Search request sent to %r.', (host, port))
            logger.debug('%s', bytes_to_send)

    def check_timeout():
        nonlocal retry_at

        now = time.monotonic()
        # Check the overall deadline first so that no further request is
        # broadcast when we are about to give up anyway.
        if now - t > timeout:
            raise TimeoutError(
                f"Timed out while awaiting a response searching for {pv!r}"
            )

        if now >= retry_at:
            send_search(search_req)
            retry_at = time.monotonic() + retry_timeout

    # Initial search attempt
    pv_to_cid, search_req = broadcaster.search(pv)
    cid_to_pv = {cid: name for name, cid in pv_to_cid.items()}
    send_search(search_req)

    # Await a search response, and keep track of registration status
    retry_timeout = timeout / max(max_retries, 1)
    t = time.monotonic()
    retry_at = t + retry_timeout

    # Read the original timeout before entering the ``try`` so the
    # ``finally`` clause can never hit an unbound name.
    orig_timeout = udp_sock.gettimeout()
    try:
        udp_sock.settimeout(retry_timeout)
        while True:
            try:
                bytes_received, address = udp_sock.recvfrom(MAX_UDP_RECV)
            except socket.timeout:
                check_timeout()
                continue

            # Unrelated traffic must not be able to postpone the deadline.
            check_timeout()

            commands = broadcaster.recv(bytes_received, address)
            broadcaster.process_commands(commands)
            response_commands = [command for command in commands
                                 if isinstance(command, SearchResponse)]
            for command in response_commands:
                response_pvs = [cid_to_pv.get(cid, None)
                                for cid in command.search_instance_ids]
                if not any(response_pvs):
                    # A response to somebody else's search request.
                    continue

                if command.found:
                    host_port = (command.server_address,
                                 command.server_port)
                    logger.debug('Found %r at %r.', response_pvs,
                                 host_port)
                    return host_port
                else:
                    logger.debug('Server responded: not found %r.',
                                 response_pvs)
    finally:
        udp_sock.settimeout(orig_timeout)


def make_channel(pv_name, udp_sock, udp_port, timeout):
    """
    Locate ``pv_name``, connect to its server and create a channel for it.

    The server is found with ``search``.  A circuit for that address is taken
    from ``global_circuits`` or created and registered there, and a TCP
    socket is opened and registered in ``sockets`` if the circuit has none.
    The function then answers the server's connection validation request and
    issues the create-channel request.

    Parameters
    ----------
    pv_name : str
        The PV name.
    udp_sock : socket.socket
        UDP socket used for the search.
    udp_port : int
        The port ``udp_sock`` is bound to.
    timeout : float
        Timeout, in seconds, applied to the search, to the TCP connection
        and to waiting for the server's responses.

    Returns
    -------
    ClientChannel
        The channel, returned once a ``CreateChannelResponse`` is received.

    Raises
    ------
    TimeoutError
        If the search or the connection handshake times out.
    Exception
        Any other error is re-raised after the circuit's socket has been
        closed and both registries have been cleaned up.
    """
    address = search([pv_name], udp_sock, udp_port, timeout)
    try:
        circuit = global_circuits[address]
    except KeyError:
        circuit = ClientVirtualCircuit(
            our_role=CLIENT, address=address,
            priority=QOSFlags.encode(priority=0, flags=0)
        )
        global_circuits[address] = circuit

    chan = ClientChannel(pv_name, circuit)

    try:
        if chan.circuit not in sockets:
            sock = socket.create_connection(chan.circuit.address, timeout)
            sockets[chan.circuit] = sock
            circuit.our_address = sock.getsockname()

        # _receive_commands only stops by raising, so this loop is left
        # either through the ``return`` below or through an exception.
        for command in _receive_commands(circuit, timeout=timeout):
            if isinstance(command, ConnectionValidationRequest):
                # Pick an authentication method the server offers:
                # 'ca' first, then 'anonymous', else an empty method name.
                if command.auth_nz and 'ca' in command.auth_nz:
                    auth_method = 'ca'
                    auth_data = pva.ChannelAccessAuthentication(
                        user=getpass.getuser(),
                        host=socket.gethostname(),
                    )
                elif command.auth_nz and 'anonymous' in command.auth_nz:
                    auth_method = 'anonymous'
                    auth_data = None
                else:
                    auth_method = ''
                    auth_data = None

                response = circuit.validate_connection(
                    buffer_size=command.server_buffer_size,
                    registry_size=command.server_registry_size,
                    connection_qos=0,
                    auth_nz=auth_method,
                    data=auth_data,
                )
                send(circuit, response)
            elif isinstance(command, ConnectionValidatedResponse):
                logger.debug('Connection validated! Creating channel.')
                create_chan = chan.create()
                send(circuit, create_chan)
            elif isinstance(command, CreateChannelResponse):
                logger.debug('Channel created.')
                return chan
    except Exception:
        # Drop the failed connection from both registries so a later call
        # for the same server starts from scratch instead of reusing a
        # closed socket or a half-initialized circuit.
        failed_sock = sockets.pop(chan.circuit, None)
        if failed_sock is not None:
            failed_sock.close()
        global_circuits.pop(address, None)
        raise


def _receive_commands(circuit, timeout):
    """
    Yield commands received on ``circuit`` until an error ends the wait.

    This generator never finishes normally: the consumer either stops
    iterating or one of the exceptions below is raised.

    Parameters
    ----------
    circuit : VirtualCircuit
        The circuit to receive on.
    timeout : float or None
        Maximum total time in seconds, measured from the first iteration.
        ``None`` disables the check.

    Raises
    ------
    CaprotoError
        After ``DISCONNECTED`` has been yielded to the consumer.
    TimeoutError
        When more than ``timeout`` seconds have elapsed.
    """
    started = time.monotonic()
    while True:
        try:
            batch = recv(circuit)
        except socket.timeout:
            # Nothing arrived within the socket's own timeout.
            batch = ()

        for item in batch:
            yield item
            if item is DISCONNECTED:
                raise CaprotoError(
                    'Disconnected while waiting for a response'
                )

        if timeout is None:
            continue

        elapsed = time.monotonic() - started
        if elapsed > timeout:
            raise TimeoutError("Timeout while awaiting reading.")


def _read(chan, timeout, pvrequest):
    """
    Perform a get on an already-created channel.

    An interface request and a read request are sent.  When the server
    acknowledges the read request (INIT), the request is switched to its GET
    form and sent again.  The data in the GET response is loaded into a
    dataclass built from the interface that accompanies it.

    Parameters
    ----------
    chan : ClientChannel
        The channel to read from.
    timeout : float or None
        Passed to ``_receive_commands``.
    pvrequest : str
        The PVRequest, such as 'field(value)'.

    Returns
    -------
    response : ChannelGetResponse
        The GET response.
    instance
        Dataclass instance populated with the response data.

    Raises
    ------
    ErrorResponseReceived
        If a get response carries an unsuccessful status.
    """
    circuit = chan.circuit

    send(circuit, chan.read_interface())

    get_request = chan.read(pvrequest=pvrequest)
    send(circuit, get_request)

    for reply in _receive_commands(circuit, timeout):
        # Field-info responses are received but not used; anything that is
        # not a get response is ignored as well.
        if (isinstance(reply, ChannelFieldInfoResponse)
                or not isinstance(reply, ChannelGetResponse)):
            continue

        status = reply.status
        if not status.is_successful:
            raise ErrorResponseReceived(str(status))
        if status.message:
            logger.info('Message from server: %s', status)

        if reply.subcommand == Subcommand.INIT:
            get_request.to_get()
            send(circuit, get_request)
        elif reply.subcommand == Subcommand.GET:
            pv_data = reply.pv_data
            structure_cls = dataclass_from_field_desc(pv_data.interface)
            payload = pv_data.data
            result = structure_cls()
            fill_dataclass(result, payload)
            return reply, result


def read(pv_name, *, pvrequest='field()', verbose=False, timeout=1):
    """
    Read a Channel.

    Parameters
    ----------
    pv_name : str
        The PV name.
    pvrequest : str, optional
        The PVRequest, such as 'field(value)'.  Defaults to 'field()' for
        retrieving all data.
    verbose : boolean, optional
        Verbose logging. Default is False.  Not used by this function at
        present.
    timeout : float, optional
        Default is 1 second.  Applied to the search, to the connection and
        to waiting for the read response.

    Returns
    -------
    response : ChannelGetResponse
        The GET response received from the server.
    instance
        Dataclass instance populated with the data from the response.

    Examples
    --------
    Get the value of a Channel named 'cat'.

    >>> read('cat')  # doctest: +SKIP
    """
    udp_sock, udp_port = make_broadcaster_socket()
    try:
        udp_sock.settimeout(timeout)
        chan = make_channel(pv_name, udp_sock, udp_port, timeout)
    finally:
        # The UDP socket is only needed for the search.
        udp_sock.close()

    try:
        return _read(chan, timeout, pvrequest=pvrequest)
    finally:
        try:
            if chan.states[CLIENT] is CONNECTED:
                send(chan.circuit, chan.disconnect())
        finally:
            # Tear the connection down even if the disconnect request
            # could not be sent.
            sockets[chan.circuit].close()
            del sockets[chan.circuit]
            del global_circuits[chan.circuit.address]
def _monitor(chan, timeout, pvrequest, maximum_events):
    """
    Monitor a channel, using pvrequest, up to maximum_events.

    Parameters
    ----------
    chan : ClientChannel
        The channel to subscribe to.
    timeout : float
        Accepted for signature compatibility; not used.  Waiting for monitor
        events is not bounded by a timeout.
    pvrequest : str
        The PVRequest, such as 'field(value)'.
    maximum_events : int or None
        Stop after this many events have been yielded.  ``None`` means no
        limit.

    Yields
    ------
    response : ChannelMonitorResponse
        The monitor response carrying the event.
    instance
        A deep copy of the dataclass instance after the event data has been
        applied to it.

    Raises
    ------
    ErrorResponseReceived
        If the INIT response carries an unsuccessful status.
    """
    request: pva.ChannelMonitorRequest = chan.subscribe(pvrequest=pvrequest)
    send(chan.circuit, request)

    instance = None
    event_count = 0
    for response in _receive_commands(chan.circuit, timeout=None):
        if not isinstance(response, ChannelMonitorResponse):
            continue

        response: ChannelMonitorResponse
        if response.subcommand == MonitorSubcommand.INIT:
            if not response.status.is_successful:
                raise ErrorResponseReceived(str(response.status))
            if response.status.message:
                logger.info('Message from server: %s', response.status)

            # Subscription acknowledged: ask the server to start sending
            # events and build the structure that events will be applied to.
            send(chan.circuit, request.to_start())
            field_desc = response.pv_structure_if
            dataclass = dataclass_from_field_desc(field_desc)
            instance = dataclass()
        else:
            event_data = response.pv_data.data
            # and 'field'
            fill_dataclass(instance, event_data)
            # The same instance is updated on every event, so hand out a
            # copy that later events cannot modify.
            yield response, copy.deepcopy(instance)
            if maximum_events is not None:
                event_count += 1
                if event_count >= maximum_events:
                    break
def monitor(pv_name, *, pvrequest='field()', verbose=False, timeout=1,
            maximum_events=None):
    """
    Monitor a Channel.

    This is a generator function: nothing is searched for or connected until
    the first item is requested.  The connection is torn down when the
    generator finishes or is closed.

    Parameters
    ----------
    pv_name : str
        The PV name.
    pvrequest : str
        The PVRequest, such as 'field(value)'.
    verbose : boolean, optional
        Verbose logging. Default is False.  Not used by this function at
        present.
    timeout : float, optional
        Default is 1 second.  Applied to the search and to creating the
        channel; waiting for monitor events is not bounded by it.
    maximum_events : int, optional
        Stop after this many events.  The default, ``None``, means no limit.

    Yields
    ------
    response : ChannelMonitorResponse
        The monitor response carrying the event.
    instance
        Dataclass instance holding the data as of that event.
    """
    udp_sock, udp_port = make_broadcaster_socket()
    try:
        udp_sock.settimeout(timeout)
        chan = make_channel(pv_name, udp_sock, udp_port, timeout)
    finally:
        # The UDP socket is only needed for the search.
        udp_sock.close()

    try:
        yield from _monitor(chan, timeout, pvrequest=pvrequest,
                            maximum_events=maximum_events)
    finally:
        try:
            if chan.states[CLIENT] is CONNECTED:
                send(chan.circuit, chan.disconnect())
        finally:
            # Tear the connection down even if the disconnect request
            # could not be sent.
            sockets[chan.circuit].close()
            del sockets[chan.circuit]
            del global_circuits[chan.circuit.address]
def _read_and_write(chan, timeout, value, pvrequest='field()',
                    cancel_on_keyboardinterrupt=False):
    """
    Read then write structured data to the given channel.

    Parameters
    ----------
    chan : ClientChannel
        The channel to write to.
    timeout : float or None
        Passed to ``_receive_commands``.
    value
        The data to write.  A PVA dataclass instance is converted to a dict
        with ``dataclasses.asdict``; anything that is then not a dict is
        wrapped as ``{'value': value}``.
    pvrequest : str, optional
        The PVRequest, such as 'field(value)'.
    cancel_on_keyboardinterrupt : bool, optional
        Send a cancel request for the operation if a KeyboardInterrupt
        arrives after the INIT response has been received.

    Returns
    -------
    old_value
        Dataclass instance filled from the GET response, or ``None`` if no
        GET response was received before the write was acknowledged.
    command : ChannelPutResponse
        The response acknowledging the write.

    Raises
    ------
    ErrorResponseReceived
        If any put response carries an unsuccessful status.
    KeyboardInterrupt
        Re-raised after the optional cancel request has been sent.
    """
    # TODO: validate dictionary keys against the interface.
    request: pva.ChannelPutRequest = chan.write(pvrequest=pvrequest)
    send(chan.circuit, request)

    if is_pva_dataclass_instance(value):
        value = dataclasses.asdict(value)

    if not isinstance(value, dict):
        value = {'value': value}

    dataclass = None
    old_value = None
    ioid = None
    try:
        for command in _receive_commands(chan.circuit, timeout):
            if not isinstance(command, ChannelPutResponse):
                continue

            command = typing.cast(ChannelPutResponse, command)
            if not command.status.is_successful:
                raise ErrorResponseReceived(str(command.status))
            if command.status.message:
                logger.info('Message from server: %s', command.status)

            if command.subcommand == Subcommand.INIT:
                ioid = command.ioid
                # Get the latest value with this request
                send(chan.circuit, request.to_get())

                # Then perform the write request
                dataclass = dataclass_from_field_desc(command.put_structure_if)
                instance = dataclass()
                # TODO logic can move up to the circuit?
                bitset = fill_dataclass(instance, value)
                request.to_default(
                    put_data=pva.DataWithBitSet(data=instance, bitset=bitset)
                )
                send(chan.circuit, request)
            elif command.subcommand == Subcommand.GET:
                # Response to the "get" issued above: the value on the
                # server before our write.
                old_value = dataclass()
                fill_dataclass(old_value, command.put_data.data)
            elif command.subcommand == Subcommand.DEFAULT:
                return old_value, command
    except KeyboardInterrupt:
        if ioid is not None and cancel_on_keyboardinterrupt:
            send(chan.circuit, chan.cancel(ioid))
        raise
def read_write_read(pv_name: str, data: dict, *,
                    options: typing.Optional[dict] = None,
                    pvrequest: str = 'field()',
                    cancel_on_keyboardinterrupt: bool = False,
                    timeout=1):
    """
    Write to a Channel, but sandwich the write between two reads.

    Parameters
    ----------
    pv_name : str
        The PV name.
    data : dict or Mapping
        The structured data to write.
    pvrequest : str, optional
        The PVRequest, such as 'field(value)'.  Defaults to 'field()' for
        retrieving all data.
    options : dict, optional
        Options to use in the pvRequest.  (TODO not yet implemented; the
        argument is currently ignored.)
    timeout : float, optional
        Timeout for the operation.  Applied separately to the search, the
        connection, the write and the final read.
    cancel_on_keyboardinterrupt : bool, optional
        Cancel the write in the event of a KeyboardInterrupt.

    Returns
    -------
    initial
        The value read back as part of the write operation, before the
        write was applied.
    res : ChannelPutResponse
        The response acknowledging the write.
    final
        Dataclass instance holding the value read after the write.
    """
    udp_sock, udp_port = make_broadcaster_socket()
    try:
        udp_sock.settimeout(timeout)
        chan = make_channel(pv_name, udp_sock, udp_port, timeout)
    finally:
        # The UDP socket is only needed for the search.
        udp_sock.close()

    try:
        initial, res = _read_and_write(
            chan, timeout, data, pvrequest=pvrequest,
            cancel_on_keyboardinterrupt=cancel_on_keyboardinterrupt
        )
        _, final = _read(chan, timeout, pvrequest=pvrequest)
    finally:
        try:
            if chan.states[CLIENT] is CONNECTED:
                send(chan.circuit, chan.disconnect())
        finally:
            # Tear the connection down even if the disconnect request
            # could not be sent.
            sockets[chan.circuit].close()
            del sockets[chan.circuit]
            del global_circuits[chan.circuit.address]

    return initial, res, final