From cb1d7f014d7e184c8565b3267c1e9170eb606026 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 17 Dec 2025 20:30:27 +0000 Subject: [PATCH 01/24] WIP scruffy first pass at making a user client --- src/blueapi/client/client.py | 289 +++++++++++++++---------- tests/unit_tests/client/test_client.py | 108 --------- 2 files changed, 180 insertions(+), 217 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 0930e240a9..8152dca491 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -1,5 +1,8 @@ import time from concurrent.futures import Future +from functools import singledispatch, singledispatchmethod +from pathlib import Path +from typing import Any, Generic, ParamSpec, Self from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -7,8 +10,13 @@ get_tracer, start_as_current_span, ) +from pydantic import BaseModel, model_serializer -from blueapi.config import ApplicationConfig, MissingStompConfigurationError +from blueapi.config import ( + ApplicationConfig, + ConfigLoader, + MissingStompConfigurationError, +) from blueapi.core.bluesky_types import DataEvent from blueapi.service.authentication import SessionManager from blueapi.service.model import ( @@ -18,14 +26,14 @@ OIDCConfig, PlanModel, PlanResponse, + ProtocolInfo, PythonEnvironmentResponse, SourceInfo, TaskRequest, TaskResponse, - TasksListResponse, WorkerTask, ) -from blueapi.worker import TrackableTask, WorkerEvent, WorkerState +from blueapi.worker import WorkerEvent, WorkerState from blueapi.worker.event import ProgressEvent, TaskStatus from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent @@ -33,12 +41,113 @@ TRACER = get_tracer("client") +import logging + +log = logging.getLogger(__name__) + + +class MissingInstrumentSessionError(Exception): + pass + + +class PlanCache: + def __init__(self, client: "BlueapiClient"): + self._client = client + self._cache = {} + + def _get_device(self, name: str) -> "Plan": + if name.startswith("_"): + raise AttributeError("No such plan") + # TODO: Catch 404 and return AttributeError + if not (plan := self._cache.get(name)): + model = self._client._rest.get_plan(name) + plan = Plan[int]( + name=model.name, + args=model.parameter_schema, + client=self._client, + ) + self._cache[name] = plan + + return plan + + def __getitem__(self, name: str) -> "Plan": + return self._get_device(name) + + def __getattr__(self, name: str) -> "Plan": + return self._get_device(name) + + +class DeviceCache: + def __init__(self, client: "BlueapiClient"): + self._client = client + self._cache = {} + + def _get_device(self, name: str) -> "Device": + if name.startswith("_"): + raise AttributeError("No such device") + if not (device := self._cache.get(name)): + model = self._client._rest.get_device(name) + device = Device(name=model.name, protocols=model.protocols) + self._cache[name] = device + # TODO: Catch 404 and return AttributeError + return device + + def __getitem__(self, name: str) -> "Device": + return self._get_device(name) + + def __getattr__(self, name: str) -> "Device": + return self._get_device(name) + + +class Device(BaseModel): + name: str + protocols: list[ProtocolInfo] + + def __repr__(self): + return f"Device({self.name})" + + @model_serializer(mode="plain") + def _to_json(self): + return self.name + + +PlanArgs = ParamSpec("PlanArgs") + + +class Plan(Generic[PlanArgs]): + def __init__(self, name, args: dict[str, Any], client: "BlueapiClient"): + self._name = name + self._args = args + self._client = client + + def __call__(self, *args: PlanArgs.args, **kwargs: PlanArgs.kwargs): + req = TaskRequest( + name=self._name, + params=self._build_args(args, kwargs), + instrument_session=self._client.instrument_session, + ) + self._client.run_task(req) + + def _build_args(self, args, kwargs): + log.info( + "Building args for %s, using %s and %s", + self._args["properties"], + args, + kwargs, + ) + log.info("Required: %s", self._args["required"]) + return kwargs + + def __repr__(self): + return f"{self._name}({', '.join(self._args['properties'].keys())})" + class BlueapiClient: """Unified client for controlling blueapi""" _rest: BlueapiRestClient _events: EventBusClient | None + _instrument_session: str | None = None def __init__( self, @@ -47,9 +156,32 @@ def __init__( ): self._rest = rest self._events = events + self.plans = PlanCache(self) + self.devices = DeviceCache(self) + + @singledispatchmethod + @classmethod + def from_config(cls, conf) -> Self: + raise ValueError("Unsupported construction arg") + + @from_config.register + @classmethod + def _(cls, config_file: str) -> Self: + return cls.from_config(Path(config_file)) + @from_config.register @classmethod - def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": + def _(cls, config_file: Path) -> Self: + conf = ConfigLoader(ApplicationConfig) + conf.use_values_from_yaml(config_file) + return cls.from_config(conf.load()) + + @from_config.register + @classmethod + def _( + cls, + config: ApplicationConfig, + ) -> Self: session_manager: SessionManager | None = None try: session_manager = SessionManager.from_cache(config.auth_token_path) @@ -71,53 +203,15 @@ def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": else: return cls(rest) - @start_as_current_span(TRACER) - def get_plans(self) -> PlanResponse: - """ - List plans available - - Returns: - PlanResponse: Plans that can be run - """ - return self._rest.get_plans() - - @start_as_current_span(TRACER, "name") - def get_plan(self, name: str) -> PlanModel: - """ - Get details of a single plan - - Args: - name: Plan name - - Returns: - PlanModel: Details of the plan if found - """ - return self._rest.get_plan(name) - - @start_as_current_span(TRACER) - def get_devices(self) -> DeviceResponse: - """ - List devices available - - Returns: - DeviceResponse: Devices that can be used in plans - """ - - return self._rest.get_devices() - - @start_as_current_span(TRACER, "name") - def get_device(self, name: str) -> DeviceModel: - """ - Get details of a single device - - Args: - name: Device name - - Returns: - DeviceModel: Details of the device if found - """ + @property + def instrument_session(self) -> str: + if self._instrument_session is None: + raise MissingInstrumentSessionError() + return self._instrument_session - return self._rest.get_device(name) + @instrument_session.setter + def instrument_session(self, session): + self._instrument_session = session @start_as_current_span(TRACER) def get_state(self) -> WorkerState: @@ -158,31 +252,6 @@ def resume(self) -> WorkerState: return self._rest.set_state(WorkerState.RUNNING, defer=False) - @start_as_current_span(TRACER, "task_id") - def get_task(self, task_id: str) -> TrackableTask: - """ - Get a task stored by the worker - - Args: - task_id: Unique ID for the task - - Returns: - TrackableTask: Task details - """ - assert task_id, "Task ID not provided!" - return self._rest.get_task(task_id) - - @start_as_current_span(TRACER) - def get_all_tasks(self) -> TasksListResponse: - """ - Get a list of all task stored by the worker - - Returns: - TasksListResponse: List of all Trackable Task - """ - - return self._rest.get_all_tasks() - @start_as_current_span(TRACER) def get_active_task(self) -> WorkerTask: """ @@ -221,7 +290,7 @@ def run_task( "Stomp configuration required to run plans is missing or disabled" ) - task_response = self.create_task(task) + task_response = self._rest.create_task(task) task_id = task_response.task_id complete: Future[WorkerEvent] = Future() @@ -255,7 +324,7 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: with self._events: self._events.subscribe_to_all_events(inner_on_event) - self.start_task(WorkerTask(task_id=task_id)) + self._rest.update_worker_task(WorkerTask(task_id=task_id)) return complete.result(timeout=timeout) @start_as_current_span(TRACER, "task") @@ -271,8 +340,10 @@ def create_and_start_task(self, task: TaskRequest) -> TaskResponse: TaskResponse: Acknowledgement of request """ - response = self.create_task(task) - worker_response = self.start_task(WorkerTask(task_id=response.task_id)) + response = self._rest.create_task(task) + worker_response = self._rest.update_worker_task( + WorkerTask(task_id=response.task_id) + ) if worker_response.task_id == response.task_id: return response else: @@ -281,47 +352,47 @@ def create_and_start_task(self, task: TaskRequest) -> TaskResponse: f"but {worker_response.task_id} was started instead" ) - @start_as_current_span(TRACER, "task") - def create_task(self, task: TaskRequest) -> TaskResponse: - """ - Create a new task, does not start execution + # @start_as_current_span(TRACER, "task") + # def create_task(self, task: TaskRequest) -> TaskResponse: + # """ + # Create a new task, does not start execution - Args: - task: Request object for task to create on the worker + # Args: + # task: Request object for task to create on the worker - Returns: - TaskResponse: Acknowledgement of request - """ + # Returns: + # TaskResponse: Acknowledgement of request + # """ - return self._rest.create_task(task) + # return self._rest.create_task(task) - @start_as_current_span(TRACER) - def clear_task(self, task_id: str) -> TaskResponse: - """ - Delete a stored task on the worker + # @start_as_current_span(TRACER) + # def clear_task(self, task_id: str) -> TaskResponse: + # """ + # Delete a stored task on the worker - Args: - task_id: ID for the task + # Args: + # task_id: ID for the task - Returns: - TaskResponse: Acknowledgement of request - """ + # Returns: + # TaskResponse: Acknowledgement of request + # """ - return self._rest.clear_task(task_id) + # return self._rest.clear_task(task_id) - @start_as_current_span(TRACER, "task") - def start_task(self, task: WorkerTask) -> WorkerTask: - """ - Instruct the worker to start a stored task immediately + # @start_as_current_span(TRACER, "task") + # def start_task(self, task: WorkerTask) -> WorkerTask: + # """ + # Instruct the worker to start a stored task immediately - Args: - task: WorkerTask to start + # Args: + # task: WorkerTask to start - Returns: - WorkerTask: Acknowledgement of request - """ + # Returns: + # WorkerTask: Acknowledgement of request + # """ - return self._rest.update_worker_task(task) + # return self._rest.update_worker_task(task) @start_as_current_span(TRACER, "reason") def abort(self, reason: str | None = None) -> WorkerState: diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index d13ccce80d..b6e13dd04c 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -143,78 +143,10 @@ def test_get_state(client: BlueapiClient): assert client.get_state() == WorkerState.IDLE -def test_get_task(client: BlueapiClient): - assert client.get_task("foo") == TASK - - -def test_get_nonexistent_task( - client: BlueapiClient, - mock_rest: Mock, -): - mock_rest.get_task.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_task("baz") - - -def test_get_task_with_empty_id(client: BlueapiClient): - with pytest.raises(AssertionError) as exc: - client.get_task("") - assert str(exc) == "Task ID not provided!" - - -def test_get_all_tasks( - client: BlueapiClient, -): - assert client.get_all_tasks() == TASKS - - -def test_create_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) - mock_rest.create_task.assert_called_once_with( - TaskRequest(name="foo", instrument_session="cm12345-1") - ) - - -def test_create_task_does_not_start_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) - mock_rest.update_worker_task.assert_not_called() - - -def test_clear_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.clear_task(task_id="foo") - mock_rest.clear_task.assert_called_once_with("foo") - - def test_get_active_task(client: BlueapiClient): assert client.get_active_task() == ACTIVE_TASK -def test_start_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.start_task(task=WorkerTask(task_id="bar")) - mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="bar")) - - -def test_start_nonexistant_task( - client: BlueapiClient, - mock_rest: Mock, -): - mock_rest.update_worker_task.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.start_task(task=WorkerTask(task_id="bar")) - - def test_create_and_start_task_calls_both_creating_and_starting_endpoints( client: BlueapiClient, mock_rest: Mock, @@ -546,37 +478,6 @@ def test_get_state_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): client.get_state() -def test_get_task_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_task", "task_id"): - client.get_task("foo") - - -def test_get_all_tasks_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, -): - with asserting_span_exporter(exporter, "get_all_tasks"): - client.get_all_tasks() - - -def test_create_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "create_task", "task"): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) - - -def test_clear_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "clear_task"): - client.clear_task(task_id="foo") - - def test_get_active_task_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient ): @@ -584,15 +485,6 @@ def test_get_active_task_span_ok( client.get_active_task() -def test_start_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "start_task", "task"): - client.start_task(task=WorkerTask(task_id="bar")) - - def test_create_and_start_task_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient, From 69b58c3a85a8b2839709380ee6b7a5e24283ad7f Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 18 Dec 2025 12:36:33 +0000 Subject: [PATCH 02/24] Cache all devices when first is accessed Allows autocomplete to work when working in a REPL --- src/blueapi/client/client.py | 74 +++++++++++++++--------------------- 1 file changed, 31 insertions(+), 43 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 8152dca491..a99de35309 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -1,6 +1,7 @@ +import logging import time from concurrent.futures import Future -from functools import singledispatch, singledispatchmethod +from functools import cached_property, singledispatchmethod from pathlib import Path from typing import Any, Generic, ParamSpec, Self @@ -21,11 +22,9 @@ from blueapi.service.authentication import SessionManager from blueapi.service.model import ( DeviceModel, - DeviceResponse, EnvironmentResponse, OIDCConfig, PlanModel, - PlanResponse, ProtocolInfo, PythonEnvironmentResponse, SourceInfo, @@ -41,7 +40,6 @@ TRACER = get_tracer("client") -import logging log = logging.getLogger(__name__) @@ -51,52 +49,36 @@ class MissingInstrumentSessionError(Exception): class PlanCache: - def __init__(self, client: "BlueapiClient"): - self._client = client - self._cache = {} - - def _get_device(self, name: str) -> "Plan": - if name.startswith("_"): - raise AttributeError("No such plan") - # TODO: Catch 404 and return AttributeError - if not (plan := self._cache.get(name)): - model = self._client._rest.get_plan(name) - plan = Plan[int]( - name=model.name, - args=model.parameter_schema, - client=self._client, + def __init__(self, client: "BlueapiClient", plans: list[PlanModel]): + self._cache = { + model.name: Plan( + name=model.name, args=model.parameter_schema, client=client ) - self._cache[name] = plan - - return plan - - def __getitem__(self, name: str) -> "Plan": - return self._get_device(name) + for model in plans + } + for name, plan in self._cache.items(): + if name.startswith("_"): + continue + setattr(self, name, plan) def __getattr__(self, name: str) -> "Plan": - return self._get_device(name) + raise AttributeError(f"No plan named '{name}' available") class DeviceCache: - def __init__(self, client: "BlueapiClient"): + def __init__(self, client: "BlueapiClient", devices: list[DeviceModel]): self._client = client - self._cache = {} - - def _get_device(self, name: str) -> "Device": - if name.startswith("_"): - raise AttributeError("No such device") - if not (device := self._cache.get(name)): - model = self._client._rest.get_device(name) - device = Device(name=model.name, protocols=model.protocols) - self._cache[name] = device - # TODO: Catch 404 and return AttributeError - return device - - def __getitem__(self, name: str) -> "Device": - return self._get_device(name) + self._cache = { + model.name: Device(name=model.name, protocols=model.protocols) + for model in devices + } + for name, device in self._cache.items(): + if name.startswith("_"): + continue + setattr(self, name, device) def __getattr__(self, name: str) -> "Device": - return self._get_device(name) + raise AttributeError(f"No device named '{name}' available") class Device(BaseModel): @@ -156,8 +138,14 @@ def __init__( ): self._rest = rest self._events = events - self.plans = PlanCache(self) - self.devices = DeviceCache(self) + + @cached_property + def plans(self) -> PlanCache: + return PlanCache(self, self._rest.get_plans().plans) + + @cached_property + def devices(self) -> DeviceCache: + return DeviceCache(self, self._rest.get_devices().devices) @singledispatchmethod @classmethod From 7d03b9cc2197c4b64d7674f8e1e8f0c9a2c9f728 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 18 Dec 2025 14:57:38 +0000 Subject: [PATCH 03/24] More client refactoring * Support args as well as kwargs when running plans * Get child devices via attributes on parent devices * Make more methods into properties --- src/blueapi/client/client.py | 128 ++++++++++++++++++++++------------- 1 file changed, 82 insertions(+), 46 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index a99de35309..fbe2641288 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -3,7 +3,7 @@ from concurrent.futures import Future from functools import cached_property, singledispatchmethod from pathlib import Path -from typing import Any, Generic, ParamSpec, Self +from typing import Any, Self from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -11,7 +11,6 @@ get_tracer, start_as_current_span, ) -from pydantic import BaseModel, model_serializer from blueapi.config import ( ApplicationConfig, @@ -21,7 +20,6 @@ from blueapi.core.bluesky_types import DataEvent from blueapi.service.authentication import SessionManager from blueapi.service.model import ( - DeviceModel, EnvironmentResponse, OIDCConfig, PlanModel, @@ -66,59 +64,103 @@ def __getattr__(self, name: str) -> "Plan": class DeviceCache: - def __init__(self, client: "BlueapiClient", devices: list[DeviceModel]): - self._client = client + def __init__(self, rest: BlueapiRestClient): + self._rest = rest self._cache = { - model.name: Device(name=model.name, protocols=model.protocols) - for model in devices + model.name: DeviceRef( + name=model.name, cache=self, protocols=model.protocols + ) + for model in rest.get_devices().devices } for name, device in self._cache.items(): if name.startswith("_"): continue setattr(self, name, device) - def __getattr__(self, name: str) -> "Device": - raise AttributeError(f"No device named '{name}' available") - - -class Device(BaseModel): - name: str + def __getitem__(self, name: str) -> "DeviceRef": + if dev := self._cache.get(name): + return dev + try: + model = self._rest.get_device(name) + device = DeviceRef(name=name, cache=self, protocols=model.protocols) + self._cache[name] = device + setattr(self, model.name, device) + return device + except KeyError: + pass + raise AttributeError(f"No device name '{name}' available") + + def __getattr__(self, name: str) -> "DeviceRef": + if name.startswith("_"): + return super().__getattribute__(name) + return self[name] + + +class DeviceRef(str): protocols: list[ProtocolInfo] + _cache: DeviceCache - def __repr__(self): - return f"Device({self.name})" - - @model_serializer(mode="plain") - def _to_json(self): - return self.name + def __new__(cls, name, cache, protocols): + instance = super().__new__(cls, name) + instance.protocols = protocols + instance._cache = cache + return instance + def __getattr__(self, name) -> "DeviceRef": + if name.startswith("_"): + raise AttributeError(f"No child device named {name}") + return self._cache[f"{self}.{name}"] -PlanArgs = ParamSpec("PlanArgs") + def __repr__(self): + return f"Device({self})" -class Plan(Generic[PlanArgs]): +class Plan: def __init__(self, name, args: dict[str, Any], client: "BlueapiClient"): self._name = name self._args = args self._client = client - def __call__(self, *args: PlanArgs.args, **kwargs: PlanArgs.kwargs): + def __call__(self, *args, **kwargs): req = TaskRequest( name=self._name, - params=self._build_args(args, kwargs), + params=self._build_args(*args, **kwargs), instrument_session=self._client.instrument_session, ) self._client.run_task(req) - def _build_args(self, args, kwargs): + def _build_args(self, *args, **kwargs): + props = self._args["properties"] + required = self._args["required"] + log.info( "Building args for %s, using %s and %s", - self._args["properties"], + "[" + ",".join(props) + "]", args, kwargs, ) - log.info("Required: %s", self._args["required"]) - return kwargs + + if len(args) > len(props): + raise TypeError(f"{self._name} got too many arguments") + if extra := {k for k in kwargs if k not in props}: + raise TypeError(f"{self._name} got unexpected arguments: {extra}") + + params = {} + # Initially fill parameters using positional args assuming the order + # from the parameter_schema + for req, arg in zip(props, args, strict=False): + params[req] = arg + + # Then append any values given via kwargs + for key, value in kwargs.items(): + # If we've already assumed a positional arg was this value, bail out + if key in params: + raise TypeError(f"{self._name} got multiple values for {key}") + params[key] = value + + if missing := {k for k in required if k not in params}: + raise TypeError(f"Missing argument(s) for {missing}") + return params def __repr__(self): return f"{self._name}({', '.join(self._args['properties'].keys())})" @@ -145,7 +187,7 @@ def plans(self) -> PlanCache: @cached_property def devices(self) -> DeviceCache: - return DeviceCache(self, self._rest.get_devices().devices) + return DeviceCache(self._rest) @singledispatchmethod @classmethod @@ -198,11 +240,13 @@ def instrument_session(self) -> str: return self._instrument_session @instrument_session.setter - def instrument_session(self, session): + def instrument_session(self, session: str): + log.debug("Setting instrument_session to %s", session) self._instrument_session = session + @property @start_as_current_span(TRACER) - def get_state(self) -> WorkerState: + def state(self) -> WorkerState: """ Get current state of the blueapi worker @@ -240,8 +284,9 @@ def resume(self) -> WorkerState: return self._rest.set_state(WorkerState.RUNNING, defer=False) + @property @start_as_current_span(TRACER) - def get_active_task(self) -> WorkerTask: + def active_task(self) -> WorkerTask: """ Get the currently active task, if any @@ -417,15 +462,10 @@ def stop(self) -> WorkerState: return self._rest.cancel_current_task(WorkerState.STOPPING) + @property @start_as_current_span(TRACER) - def get_environment(self) -> EnvironmentResponse: - """ - Get details of the worker environment - - Returns: - EnvironmentResponse: Details of the worker - environment. - """ + def environment(self) -> EnvironmentResponse: + """Details of the worker environment""" return self._rest.get_environment() @@ -492,14 +532,10 @@ def _wait_for_reload( "seconds, a server restart is recommended" ) + @property @start_as_current_span(TRACER) - def get_oidc_config(self) -> OIDCConfig | None: - """ - Get oidc config from the server - - Returns: - OIDCConfig: Details of the oidc Config - """ + def oidc_config(self) -> OIDCConfig | None: + """OIDC config from the server""" return self._rest.get_oidc_config() From 897220caeeedcea6febf03b9756f6a94f7edbe54 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 18 Dec 2025 15:49:44 +0000 Subject: [PATCH 04/24] Update CLI to use update client --- src/blueapi/cli/cli.py | 8 ++-- src/blueapi/cli/format.py | 74 ++++++++++++++++++++++-------------- src/blueapi/client/client.py | 58 +++++++++++++++++----------- 3 files changed, 85 insertions(+), 55 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index c39ff13dff..995ad79f3a 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -204,7 +204,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def get_plans(obj: dict) -> None: """Get a list of plans available for the worker to use""" client: BlueapiClient = obj["client"] - obj["fmt"].display(client.get_plans()) + obj["fmt"].display([p.model for p in client.plans]) @controller.command(name="devices") @@ -213,7 +213,7 @@ def get_plans(obj: dict) -> None: def get_devices(obj: dict) -> None: """Get a list of devices available for the worker to use""" client: BlueapiClient = obj["client"] - obj["fmt"].display(client.get_devices()) + obj["fmt"].display([dev.model for dev in client.devices]) @controller.command(name="listen") @@ -345,7 +345,7 @@ def get_state(obj: dict) -> None: """Print the current state of the worker""" client: BlueapiClient = obj["client"] - print(client.get_state().name) + print(client.state.name) @controller.command(name="pause") @@ -428,7 +428,7 @@ def env( status = client.reload_environment(timeout=timeout) print("Environment is initialized") else: - status = client.get_environment() + status = client.environment print(status) diff --git a/src/blueapi/cli/format.py b/src/blueapi/cli/format.py index 490e57cfa5..5d14edf0f7 100644 --- a/src/blueapi/cli/format.py +++ b/src/blueapi/cli/format.py @@ -12,7 +12,9 @@ from blueapi.core.bluesky_types import DataEvent from blueapi.service.model import ( + DeviceModel, DeviceResponse, + PlanModel, PlanResponse, PythonEnvironmentResponse, SourceInfo, @@ -54,17 +56,21 @@ def display_full(obj: Any, stream: Stream): match obj: case PlanResponse(plans=plans): for plan in plans: - print(plan.name) - if desc := plan.description: - print(indent(dedent(desc).strip(), " ")) - if schema := plan.parameter_schema: - print(" Schema") - print(indent(json.dumps(schema, indent=2), " ")) + display_full(plan, stream) + case PlanModel(name=name, description=desc, parameter_schema=schema): + print(name) + if desc: + print(indent(dedent(desc).strip(), " ")) + if schema: + print(" Schema") + print(indent(json.dumps(schema, indent=2), " ")) case DeviceResponse(devices=devices): for dev in devices: - print(dev.name) - for proto in dev.protocols: - print(f" {proto}") + display_full(dev, stream) + case DeviceModel(name=name, protocols=protocols): + print(name) + for proto in protocols: + print(f" {proto}") case DataEvent(name=name, doc=doc): print(f"{name.title()}:{fmt_dict(doc)}") case WorkerEvent(state=st, task_status=task): @@ -92,6 +98,9 @@ def display_full(obj: Any, stream: Stream): case BaseModel(): print(obj.__class__.__name__, end="") print(fmt_dict(obj.model_dump())) + case list(): + for item in obj: + display_full(item, stream) case other: FALLBACK(other, stream=stream) @@ -100,11 +109,13 @@ def display_json(obj: Any, stream: Stream): print = partial(builtins.print, file=stream) match obj: case PlanResponse(plans=plans): - print(json.dumps([p.model_dump() for p in plans], indent=2)) + display_json(plans, stream) case DeviceResponse(devices=devices): - print(json.dumps([d.model_dump() for d in devices], indent=2)) + display_json(devices, stream) case BaseModel(): print(json.dumps(obj.model_dump())) + case list(): + print(json.dumps([it.model_dump() for it in obj], indent=2)) case _: print(json.dumps(obj)) @@ -114,26 +125,30 @@ def display_compact(obj: Any, stream: Stream): match obj: case PlanResponse(plans=plans): for plan in plans: - print(plan.name) - if desc := plan.description: - print(indent(dedent(desc.split("\n\n")[0].strip("\n")), " ")) - if schema := plan.parameter_schema: - print(" Args") - for arg, spec in schema.get("properties", {}).items(): - req = arg in schema.get("required", {}) - print(f" {arg}={_describe_type(spec, req)}") + display_compact(plan, stream) + case PlanModel(name=name, description=desc, parameter_schema=schema): + print(name) + if desc: + print(indent(dedent(desc.split("\n\n")[0].strip("\n")), " ")) + if schema: + print(" Args") + for arg, spec in schema.get("properties", {}).items(): + req = arg in schema.get("required", {}) + print(f" {arg}={_describe_type(spec, req)}") case DeviceResponse(devices=devices): for dev in devices: - print(dev.name) - print( - indent( - textwrap.fill( - ", ".join(str(proto) for proto in dev.protocols), - 80, - ), - " ", - ) + display_compact(dev, stream) + case DeviceModel(name=name, protocols=protocols): + print(name) + print( + indent( + textwrap.fill( + ", ".join(str(proto) for proto in protocols), + 80, + ), + " ", ) + ) case DataEvent(name=name): print(f"Data Event: {name}") case WorkerEvent(state=state): @@ -160,6 +175,9 @@ def display_compact(obj: Any, stream: Stream): extra += " (Scratch)" print(f"- {package.name} @ ({package.version}){extra}") + case list(): + for item in obj: + display_compact(item, stream) case other: FALLBACK(other, stream=stream) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index fbe2641288..ef26110d57 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -3,7 +3,7 @@ from concurrent.futures import Future from functools import cached_property, singledispatchmethod from pathlib import Path -from typing import Any, Self +from typing import Self from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -20,10 +20,10 @@ from blueapi.core.bluesky_types import DataEvent from blueapi.service.authentication import SessionManager from blueapi.service.model import ( + DeviceModel, EnvironmentResponse, OIDCConfig, PlanModel, - ProtocolInfo, PythonEnvironmentResponse, SourceInfo, TaskRequest, @@ -49,9 +49,7 @@ class MissingInstrumentSessionError(Exception): class PlanCache: def __init__(self, client: "BlueapiClient", plans: list[PlanModel]): self._cache = { - model.name: Plan( - name=model.name, args=model.parameter_schema, client=client - ) + model.name: Plan(name=model.name, model=model, client=client) for model in plans } for name, plan in self._cache.items(): @@ -62,14 +60,15 @@ def __init__(self, client: "BlueapiClient", plans: list[PlanModel]): def __getattr__(self, name: str) -> "Plan": raise AttributeError(f"No plan named '{name}' available") + def __iter__(self): + return iter(self._cache.values()) + class DeviceCache: def __init__(self, rest: BlueapiRestClient): self._rest = rest self._cache = { - model.name: DeviceRef( - name=model.name, cache=self, protocols=model.protocols - ) + model.name: DeviceRef(name=model.name, cache=self, model=model) for model in rest.get_devices().devices } for name, device in self._cache.items(): @@ -82,7 +81,7 @@ def __getitem__(self, name: str) -> "DeviceRef": return dev try: model = self._rest.get_device(name) - device = DeviceRef(name=name, cache=self, protocols=model.protocols) + device = DeviceRef(name=name, cache=self, model=model) self._cache[name] = device setattr(self, model.name, device) return device @@ -95,14 +94,17 @@ def __getattr__(self, name: str) -> "DeviceRef": return super().__getattribute__(name) return self[name] + def __iter__(self): + return iter(self._cache.values()) + class DeviceRef(str): - protocols: list[ProtocolInfo] + model: DeviceModel _cache: DeviceCache - def __new__(cls, name, cache, protocols): + def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel): instance = super().__new__(cls, name) - instance.protocols = protocols + instance.model = model instance._cache = cache return instance @@ -116,10 +118,11 @@ def __repr__(self): class Plan: - def __init__(self, name, args: dict[str, Any], client: "BlueapiClient"): + def __init__(self, name, model: PlanModel, client: "BlueapiClient"): self._name = name - self._args = args + self.model = model self._client = client + self.__doc__ = model.description def __call__(self, *args, **kwargs): req = TaskRequest( @@ -129,26 +132,35 @@ def __call__(self, *args, **kwargs): ) self._client.run_task(req) - def _build_args(self, *args, **kwargs): - props = self._args["properties"] - required = self._args["required"] + @property + def help_text(self) -> str: + return self.model.description or f"Plan {self!r}" + @property + def properties(self) -> set[str]: + return self.model.parameter_schema["properties"] + + @property + def required(self) -> list[str]: + return self.model.parameter_schema["required"] + + def _build_args(self, *args, **kwargs): log.info( "Building args for %s, using %s and %s", - "[" + ",".join(props) + "]", + "[" + ",".join(self.properties) + "]", args, kwargs, ) - if len(args) > len(props): + if len(args) > len(self.properties): raise TypeError(f"{self._name} got too many arguments") - if extra := {k for k in kwargs if k not in props}: + if extra := {k for k in kwargs if k not in self.properties}: raise TypeError(f"{self._name} got unexpected arguments: {extra}") params = {} # Initially fill parameters using positional args assuming the order # from the parameter_schema - for req, arg in zip(props, args, strict=False): + for req, arg in zip(self.properties, args, strict=False): params[req] = arg # Then append any values given via kwargs @@ -158,12 +170,12 @@ def _build_args(self, *args, **kwargs): raise TypeError(f"{self._name} got multiple values for {key}") params[key] = value - if missing := {k for k in required if k not in params}: + if missing := {k for k in self.required if k not in params}: raise TypeError(f"Missing argument(s) for {missing}") return params def __repr__(self): - return f"{self._name}({', '.join(self._args['properties'].keys())})" + return f"{self._name}({', '.join(self.properties)})" class BlueapiClient: From 01124e5756063cb8698e7c51cd6cbff2e4132c12 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 18 Dec 2025 15:55:22 +0000 Subject: [PATCH 05/24] Use "BlueapiClient" instead of Self for classmethods For some reason pyright can't figure out what the return type is otherwise and you don't get completion. --- src/blueapi/client/client.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index ef26110d57..328eb364e1 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -3,7 +3,6 @@ from concurrent.futures import Future from functools import cached_property, singledispatchmethod from pathlib import Path -from typing import Self from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -203,17 +202,17 @@ def devices(self) -> DeviceCache: @singledispatchmethod @classmethod - def from_config(cls, conf) -> Self: + def from_config(cls, conf) -> "BlueapiClient": raise ValueError("Unsupported construction arg") @from_config.register @classmethod - def _(cls, config_file: str) -> Self: + def _(cls, config_file: str) -> "BlueapiClient": return cls.from_config(Path(config_file)) @from_config.register @classmethod - def _(cls, config_file: Path) -> Self: + def _(cls, config_file: Path) -> "BlueapiClient": conf = ConfigLoader(ApplicationConfig) conf.use_values_from_yaml(config_file) return cls.from_config(conf.load()) @@ -223,7 +222,7 @@ def _(cls, config_file: Path) -> Self: def _( cls, config: ApplicationConfig, - ) -> Self: + ) -> "BlueapiClient": session_manager: SessionManager | None = None try: session_manager = SessionManager.from_cache(config.auth_token_path) From d9e2d3b2bc85cf23438fa3a1b104acc96702b289 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 19 Dec 2025 11:20:43 +0000 Subject: [PATCH 06/24] Remove single dispatch stuff It broke type checking --- src/blueapi/client/client.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 328eb364e1..354c5f6d0e 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -1,8 +1,9 @@ import logging import time from concurrent.futures import Future -from functools import cached_property, singledispatchmethod +from functools import cached_property from pathlib import Path +from typing import Self from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -200,29 +201,17 @@ def plans(self) -> PlanCache: def devices(self) -> DeviceCache: return DeviceCache(self._rest) - @singledispatchmethod @classmethod - def from_config(cls, conf) -> "BlueapiClient": - raise ValueError("Unsupported construction arg") - - @from_config.register - @classmethod - def _(cls, config_file: str) -> "BlueapiClient": - return cls.from_config(Path(config_file)) - - @from_config.register - @classmethod - def _(cls, config_file: Path) -> "BlueapiClient": + def from_config_file(cls, config_file: str) -> Self: conf = ConfigLoader(ApplicationConfig) - conf.use_values_from_yaml(config_file) + conf.use_values_from_yaml(Path(config_file)) return cls.from_config(conf.load()) - @from_config.register @classmethod - def _( + def from_config( cls, config: ApplicationConfig, - ) -> "BlueapiClient": + ) -> Self: session_manager: SessionManager | None = None try: session_manager = SessionManager.from_cache(config.auth_token_path) From 8dbee860477eea878bc1479f34c2e1dc3be4f2b9 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 19 Dec 2025 15:26:17 +0000 Subject: [PATCH 07/24] Make name a public attribute of Plan --- src/blueapi/client/client.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 354c5f6d0e..098112858f 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -119,14 +119,14 @@ def __repr__(self): class Plan: def __init__(self, name, model: PlanModel, client: "BlueapiClient"): - self._name = name + self.name = name self.model = model self._client = client self.__doc__ = model.description def __call__(self, *args, **kwargs): req = TaskRequest( - name=self._name, + name=self.name, params=self._build_args(*args, **kwargs), instrument_session=self._client.instrument_session, ) @@ -153,9 +153,9 @@ def _build_args(self, *args, **kwargs): ) if len(args) > len(self.properties): - raise TypeError(f"{self._name} got too many arguments") + raise TypeError(f"{self.name} got too many arguments") if extra := {k for k in kwargs if k not in self.properties}: - raise TypeError(f"{self._name} got unexpected arguments: {extra}") + raise TypeError(f"{self.name} got unexpected arguments: {extra}") params = {} # Initially fill parameters using positional args assuming the order @@ -167,7 +167,7 @@ def _build_args(self, *args, **kwargs): for key, value in kwargs.items(): # If we've already assumed a positional arg was this value, bail out if key in params: - raise TypeError(f"{self._name} got multiple values for {key}") + raise TypeError(f"{self.name} got multiple values for {key}") params[key] = value if missing := {k for k in self.required if k not in params}: @@ -175,7 +175,7 @@ def _build_args(self, *args, **kwargs): return params def __repr__(self): - return f"{self._name}({', '.join(self.properties)})" + return f"{self.name}({', '.join(self.properties)})" class BlueapiClient: From 740b41934bda18fecdf49c398ca2687fe36b5734 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 19 Dec 2025 15:26:35 +0000 Subject: [PATCH 08/24] Add spans to plans and devices properties --- src/blueapi/client/client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 098112858f..2210e65d25 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -194,10 +194,12 @@ def __init__( self._events = events @cached_property + @start_as_current_span(TRACER) def plans(self) -> PlanCache: return PlanCache(self, self._rest.get_plans().plans) @cached_property + @start_as_current_span(TRACER) def devices(self) -> DeviceCache: return DeviceCache(self._rest) From 0fdce2cbdac7e4c0117ae82019436c93e3972488 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 19 Dec 2025 15:26:45 +0000 Subject: [PATCH 09/24] Remove dead comments --- src/blueapi/client/client.py | 42 ------------------------------------ 1 file changed, 42 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 2210e65d25..955f3b99c8 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -387,48 +387,6 @@ def create_and_start_task(self, task: TaskRequest) -> TaskResponse: f"but {worker_response.task_id} was started instead" ) - # @start_as_current_span(TRACER, "task") - # def create_task(self, task: TaskRequest) -> TaskResponse: - # """ - # Create a new task, does not start execution - - # Args: - # task: Request object for task to create on the worker - - # Returns: - # TaskResponse: Acknowledgement of request - # """ - - # return self._rest.create_task(task) - - # @start_as_current_span(TRACER) - # def clear_task(self, task_id: str) -> TaskResponse: - # """ - # Delete a stored task on the worker - - # Args: - # task_id: ID for the task - - # Returns: - # TaskResponse: Acknowledgement of request - # """ - - # return self._rest.clear_task(task_id) - - # @start_as_current_span(TRACER, "task") - # def start_task(self, task: WorkerTask) -> WorkerTask: - # """ - # Instruct the worker to start a stored task immediately - - # Args: - # task: WorkerTask to start - - # Returns: - # WorkerTask: Acknowledgement of request - # """ - - # return self._rest.update_worker_task(task) - @start_as_current_span(TRACER, "reason") def abort(self, reason: str | None = None) -> WorkerState: """ From cc10b98923a0c445100e5ee7eafae3b2e9665fbf Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 19 Dec 2025 15:44:48 +0000 Subject: [PATCH 10/24] Make optional parameters clear in Plan repr --- src/blueapi/client/client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 955f3b99c8..84835a02d4 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -2,6 +2,7 @@ import time from concurrent.futures import Future from functools import cached_property +from itertools import chain from pathlib import Path from typing import Self @@ -175,7 +176,9 @@ def _build_args(self, *args, **kwargs): return params def __repr__(self): - return f"{self.name}({', '.join(self.properties)})" + opts = [p for p in self.properties if p not in self.required] + params = ", ".join(chain(self.required, (f"{opt}=None" for opt in opts))) + return f"{self.name}({params})" class BlueapiClient: From 2bf41a408daf47df89a66b8e36e05d6e1f726c56 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 19 Dec 2025 16:34:55 +0000 Subject: [PATCH 11/24] Update client and system tests to use new client changes --- src/blueapi/client/client.py | 2 +- tests/system_tests/test_blueapi_system.py | 148 ++++++++++++---------- tests/unit_tests/client/test_client.py | 75 ++++++----- 3 files changed, 126 insertions(+), 99 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 84835a02d4..585481d197 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -88,7 +88,7 @@ def __getitem__(self, name: str) -> "DeviceRef": return device except KeyError: pass - raise AttributeError(f"No device name '{name}' available") + raise AttributeError(f"No device named '{name}' available") def __getattr__(self, name: str) -> "DeviceRef": if name.startswith("_"): diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 0271331a6d..fdb5372112 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -18,6 +18,7 @@ ) from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError from blueapi.client.rest import BlueskyRequestError +from blueapi.client.rest import BlueapiRestClient, UnknownPlanError from blueapi.config import ( ApplicationConfig, ConfigLoader, @@ -130,7 +131,7 @@ def client_with_stomp() -> Generator[BlueapiClient]: def wait_for_server(client: BlueapiClient): for _ in range(20): try: - client.get_environment() + _ = client.environment return except ConnectionError: ... @@ -149,6 +150,11 @@ def client() -> Generator[BlueapiClient]: yield BlueapiClient.from_config(config=ApplicationConfig()) +@pytest.fixture +def rest_client(client: BlueapiClient) -> BlueapiRestClient: + return client._rest + + @pytest.fixture def expected_plans() -> PlanResponse: return TypeAdapter(PlanResponse).validate_json( @@ -180,9 +186,9 @@ def blueapi_client_get_methods() -> list[str]: @pytest.fixture(autouse=True) -def clean_existing_tasks(client: BlueapiClient): - for task in client.get_all_tasks().tasks: - client.clear_task(task.task_id) +def clean_existing_tasks(rest_client: BlueapiRestClient): + for task in rest_client.get_all_tasks().tasks: + rest_client.clear_task(task.task_id) yield @@ -225,15 +231,15 @@ def test_cannot_access_endpoints( def test_can_get_oidc_config_without_auth(client_without_auth: BlueapiClient): - assert client_without_auth.get_oidc_config() == OIDCConfig( + assert client_without_auth.oidc_config == OIDCConfig( well_known_url=KEYCLOAK_BASE_URL + "realms/master/.well-known/openid-configuration", client_id="ixx-cli-blueapi", ) -def test_get_plans(client: BlueapiClient, expected_plans: PlanResponse): - retrieved_plans = client.get_plans() +def test_get_plans(rest_client: BlueapiRestClient, expected_plans: PlanResponse): + retrieved_plans = rest_client.get_plans() retrieved_plans.plans.sort(key=lambda x: x.name) expected_plans.plans.sort(key=lambda x: x.name) @@ -242,40 +248,52 @@ def test_get_plans(client: BlueapiClient, expected_plans: PlanResponse): def test_get_plans_by_name(client: BlueapiClient, expected_plans: PlanResponse): for plan in expected_plans.plans: - assert client.get_plan(plan.name) == plan + assert getattr(client.plans, plan.name).model == plan -def test_get_non_existent_plan(client: BlueapiClient): +def test_get_non_existent_plan(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_plan("Not exists") + rest_client.get_plan("Not exists") + +def test_client_non_existent_plan(client: BlueapiClient): + with pytest.raises(AttributeError, match="No plan named 'missing' available"): + _ = client.plans.missing -def test_get_devices(client: BlueapiClient, expected_devices: DeviceResponse): - retrieved_devices = client.get_devices() + +def test_get_devices(rest_client: BlueapiRestClient, expected_devices: DeviceResponse): + retrieved_devices = rest_client.get_devices() retrieved_devices.devices.sort(key=lambda x: x.name) expected_devices.devices.sort(key=lambda x: x.name) assert retrieved_devices == expected_devices -def test_get_device_by_name(client: BlueapiClient, expected_devices: DeviceResponse): +def test_get_device_by_name( + rest_client: BlueapiRestClient, expected_devices: DeviceResponse +): for device in expected_devices.devices: - assert client.get_device(device.name) == device + assert rest_client.get_device(device.name) == device -def test_get_non_existent_device(client: BlueapiClient): +def test_get_non_existent_device(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_device("Not exists") + rest_client.get_device("Not exists") + + +def test_client_non_existent_device(client: BlueapiClient): + with pytest.raises(AttributeError, match="No device named 'missing' available"): + _ = client.devices.missing -def test_create_task_and_delete_task_by_id(client: BlueapiClient): - create_task = client.create_task(_SIMPLE_TASK) - client.clear_task(create_task.task_id) +def test_create_task_and_delete_task_by_id(rest_client: BlueapiRestClient): + create_task = rest_client.create_task(_SIMPLE_TASK) + rest_client.clear_task(create_task.task_id) -def test_instrument_session_propagated(client: BlueapiClient): - response = client.create_task(_SIMPLE_TASK) - trackable_task = client.get_task(response.task_id) +def test_instrument_session_propagated(rest_client: BlueapiRestClient): + response = rest_client.create_task(_SIMPLE_TASK) + trackable_task = rest_client.get_task(response.task_id) assert trackable_task.task.metadata == { "instrument_session": AUTHORIZED_INSTRUMENT_SESSION, "tiled_access_tags": [ @@ -284,9 +302,9 @@ def test_instrument_session_propagated(client: BlueapiClient): } -def test_create_task_validation_error(client: BlueapiClient): +def test_create_task_validation_error(rest_client: BlueapiRestClient): with pytest.raises(BlueskyRequestError, match="Internal Server Error"): - client.create_task( + rest_client.create_task( TaskRequest( name="Not-exists", params={"Not-exists": 0.0}, @@ -295,26 +313,26 @@ def test_create_task_validation_error(client: BlueapiClient): ) -def test_get_all_tasks(client: BlueapiClient): +def test_get_all_tasks(rest_client: BlueapiRestClient): created_tasks: list[TaskResponse] = [] for task in [_SIMPLE_TASK, _LONG_TASK]: - created_task = client.create_task(task) + created_task = rest_client.create_task(task) created_tasks.append(created_task) task_ids = [task.task_id for task in created_tasks] - task_list = client.get_all_tasks() + task_list = rest_client.get_all_tasks() for trackable_task in task_list.tasks: assert trackable_task.task_id in task_ids assert trackable_task.is_complete is False and trackable_task.is_pending is True for task_id in task_ids: - client.clear_task(task_id) + rest_client.clear_task(task_id) -def test_get_task_by_id(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) +def test_get_task_by_id(rest_client: BlueapiRestClient): + created_task = rest_client.create_task(_SIMPLE_TASK) - get_task = client.get_task(created_task.task_id) + get_task = rest_client.get_task(created_task.task_id) assert ( get_task.task_id == created_task.task_id and get_task.is_pending @@ -322,45 +340,45 @@ def test_get_task_by_id(client: BlueapiClient): and len(get_task.errors) == 0 ) - client.clear_task(created_task.task_id) + rest_client.clear_task(created_task.task_id) -def test_get_non_existent_task(client: BlueapiClient): +def test_get_non_existent_task(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_task("Not-exists") + rest_client.get_task("Not-exists") -def test_delete_non_existent_task(client: BlueapiClient): +def test_delete_non_existent_task(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.clear_task("Not-exists") + rest_client.clear_task("Not-exists") -def test_put_worker_task(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) - client.start_task(WorkerTask(task_id=created_task.task_id)) - active_task = client.get_active_task() +def test_put_worker_task(rest_client: BlueapiRestClient): + created_task = rest_client.create_task(_SIMPLE_TASK) + rest_client.update_worker_task(WorkerTask(task_id=created_task.task_id)) + active_task = rest_client.get_active_task() assert active_task.task_id == created_task.task_id - client.clear_task(created_task.task_id) + rest_client.clear_task(created_task.task_id) -def test_put_worker_task_fails_if_not_idle(client: BlueapiClient): - small_task = client.create_task(_SIMPLE_TASK) - long_task = client.create_task(_LONG_TASK) +def test_put_worker_task_fails_if_not_idle(rest_client: BlueapiRestClient): + small_task = rest_client.create_task(_SIMPLE_TASK) + long_task = rest_client.create_task(_LONG_TASK) - client.start_task(WorkerTask(task_id=long_task.task_id)) - active_task = client.get_active_task() + rest_client.update_worker_task(WorkerTask(task_id=long_task.task_id)) + active_task = rest_client.get_active_task() assert active_task.task_id == long_task.task_id with pytest.raises(BlueskyRemoteControlError) as exception: - client.start_task(WorkerTask(task_id=small_task.task_id)) + rest_client.update_worker_task(WorkerTask(task_id=small_task.task_id)) assert "" in str(exception) - client.abort() - client.clear_task(small_task.task_id) - client.clear_task(long_task.task_id) + rest_client.cancel_current_task(WorkerState.ABORTING) + rest_client.clear_task(small_task.task_id) + rest_client.clear_task(long_task.task_id) def test_get_worker_state(client: BlueapiClient): - assert client.get_state() == WorkerState.IDLE + assert client.state == WorkerState.IDLE def test_set_state_transition_error(client: BlueapiClient): @@ -372,10 +390,10 @@ def test_set_state_transition_error(client: BlueapiClient): assert "" in str(exception) -def test_get_task_by_status(client: BlueapiClient): - task_1 = client.create_task(_SIMPLE_TASK) - task_2 = client.create_task(_SIMPLE_TASK) - task_by_pending = client.get_all_tasks() +def test_get_task_by_status(rest_client: BlueapiRestClient): + task_1 = rest_client.create_task(_SIMPLE_TASK) + task_2 = rest_client.create_task(_SIMPLE_TASK) + task_by_pending = rest_client.get_all_tasks() # https://github.com/DiamondLightSource/blueapi/issues/680 # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.PENDING) assert len(task_by_pending.tasks) == 2 @@ -384,13 +402,13 @@ def test_get_task_by_status(client: BlueapiClient): trackable_task = TypeAdapter(TrackableTask).validate_python(task) assert trackable_task.is_complete is False and trackable_task.is_pending is True - client.start_task(WorkerTask(task_id=task_1.task_id)) - while not client.get_task(task_1.task_id).is_complete: + rest_client.update_worker_task(WorkerTask(task_id=task_1.task_id)) + while not rest_client.get_task(task_1.task_id).is_complete: time.sleep(0.1) - client.start_task(WorkerTask(task_id=task_2.task_id)) - while not client.get_task(task_2.task_id).is_complete: + rest_client.update_worker_task(WorkerTask(task_id=task_2.task_id)) + while not rest_client.get_task(task_2.task_id).is_complete: time.sleep(0.1) - task_by_completed = client.get_all_tasks() + task_by_completed = rest_client.get_all_tasks() # https://github.com/DiamondLightSource/blueapi/issues/680 # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.COMPLETE) assert len(task_by_completed.tasks) == 2 @@ -399,8 +417,8 @@ def test_get_task_by_status(client: BlueapiClient): trackable_task = TypeAdapter(TrackableTask).validate_python(task) assert trackable_task.is_complete is True and trackable_task.is_pending is False - client.clear_task(task_id=task_1.task_id) - client.clear_task(task_id=task_2.task_id) + rest_client.clear_task(task_id=task_1.task_id) + rest_client.clear_task(task_id=task_2.task_id) def test_progress_with_stomp(client_with_stomp: BlueapiClient): @@ -441,13 +459,13 @@ def on_event(event: AnyEvent): def test_get_current_state_of_environment(client: BlueapiClient): - assert client.get_environment().initialized + assert client.environment.initialized def test_delete_current_environment(client: BlueapiClient): - old_env = client.get_environment() + old_env = client.environment client.reload_environment() - new_env = client.get_environment() + new_env = client.environment assert new_env.initialized assert new_env.environment_id != old_env.environment_id assert new_env.error_message is None diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index b6e13dd04c..eac58cdf74 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -20,6 +20,7 @@ EnvironmentResponse, PlanModel, PlanResponse, + ProtocolInfo, TaskRequest, TaskResponse, TasksListResponse, @@ -72,9 +73,9 @@ def mock_rest() -> BlueapiRestClient: mock = Mock(spec=BlueapiRestClient) mock.get_plans.return_value = PLANS - mock.get_plan.return_value = PLAN + mock.get_plan.side_effect = lambda n: {p.name: p for p in PLANS.plans}[n] mock.get_devices.return_value = DEVICES - mock.get_device.return_value = DEVICE + mock.get_device.side_effect = lambda n: {d.name: d for d in DEVICES.devices}[n] mock.get_state.return_value = WorkerState.IDLE mock.get_task.return_value = TASK mock.get_all_tasks.return_value = TASKS @@ -106,45 +107,53 @@ def client_with_events(mock_rest: Mock, mock_events: MagicMock): def test_get_plans(client: BlueapiClient): - assert client.get_plans() == PLANS + assert PlanResponse(plans=[p.model for p in client.plans]) == PLANS def test_get_plan(client: BlueapiClient): - assert client.get_plan("foo") == PLAN + assert client.plans.foo.model == PLAN def test_get_nonexistant_plan( client: BlueapiClient, - mock_rest: Mock, ): - mock_rest.get_plan.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_plan("baz") + with pytest.raises(AttributeError): + _ = client.plans.fizz_buzz.model def test_get_devices(client: BlueapiClient): - assert client.get_devices() == DEVICES + assert DeviceResponse(devices=[d.model for d in client.devices]) == DEVICES def test_get_device(client: BlueapiClient): - assert client.get_device("foo") == DEVICE + assert client.devices.foo.model == DEVICE -def test_get_nonexistant_device( +def test_get_nonexistent_device( client: BlueapiClient, - mock_rest: Mock, ): - mock_rest.get_device.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_device("baz") + with pytest.raises(AttributeError): + _ = client.devices.baz + + +def test_get_child_device(mock_rest: Mock, client: BlueapiClient): + mock_rest.get_device.side_effect = ( + lambda name: DeviceModel(name="foo.x", protocols=[ProtocolInfo(name="One")]) + if name == "foo.x" + else None + ) + foo = client.devices.foo + assert foo == "foo" + x = client.devices.foo.x + assert x == "foo.x" def test_get_state(client: BlueapiClient): - assert client.get_state() == WorkerState.IDLE + assert client.state == WorkerState.IDLE def test_get_active_task(client: BlueapiClient): - assert client.get_active_task() == ACTIVE_TASK + assert client.active_task == ACTIVE_TASK def test_create_and_start_task_calls_both_creating_and_starting_endpoints( @@ -198,7 +207,7 @@ def test_create_and_start_task_fails_if_task_start_fails( def test_get_environment(client: BlueapiClient): - assert client.get_environment() == ENV + assert client.environment == ENV def test_reload_environment( @@ -454,35 +463,35 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): def test_get_plans_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_plans"): - client.get_plans() + with asserting_span_exporter(exporter, "plans"): + _ = client.plans def test_get_plan_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_plan", "name"): - client.get_plan("foo") + with asserting_span_exporter(exporter, "plans"): + _ = client.plans.foo def test_get_devices_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_devices"): - client.get_devices() + with asserting_span_exporter(exporter, "devices"): + _ = client.devices def test_get_device_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_device", "name"): - client.get_device("foo") + with asserting_span_exporter(exporter, "devices"): + _ = client.devices.foo -def test_get_state_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_state"): - client.get_state() +def test_get_state_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): + with asserting_span_exporter(exporter, "state"): + _ = client.state def test_get_active_task_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient ): - with asserting_span_exporter(exporter, "get_active_task"): - client.get_active_task() + with asserting_span_exporter(exporter, "active_task"): + _ = client.active_task def test_create_and_start_task_span_ok( @@ -501,8 +510,8 @@ def test_create_and_start_task_span_ok( def test_get_environment_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient ): - with asserting_span_exporter(exporter, "get_environment"): - client.get_environment() + with asserting_span_exporter(exporter, "environment"): + _ = client.environment def test_reload_environment_span_ok( From 9e87a7ddf7075ed871e07c59c13460744bd7f59b Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 5 Jan 2026 17:17:01 +0000 Subject: [PATCH 12/24] Get oidc config via property --- src/blueapi/cli/cli.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 995ad79f3a..600ce024ff 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -470,14 +470,13 @@ def login(obj: dict) -> None: print("Logged in") except Exception: client = BlueapiClient.from_config(config) - oidc_config = client.get_oidc_config() - if oidc_config is None: + if oidc := client.oidc_config: + auth = SessionManager( + oidc, cache_manager=SessionCacheManager(config.auth_token_path) + ) + auth.start_device_flow() + else: print("Server is not configured to use authentication!") - return - auth = SessionManager( - oidc_config, cache_manager=SessionCacheManager(config.auth_token_path) - ) - auth.start_device_flow() @main.command(name="logout") From af6eb2884e956457edfcc5a37506efae2ec344a6 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 5 Jan 2026 17:17:30 +0000 Subject: [PATCH 13/24] Add getitem support for plans --- src/blueapi/client/client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 585481d197..9c383d381d 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -58,6 +58,9 @@ def __init__(self, client: "BlueapiClient", plans: list[PlanModel]): continue setattr(self, name, plan) + def __getitem__(self, name: str) -> "Plan": + return self._cache[name] + def __getattr__(self, name: str) -> "Plan": raise AttributeError(f"No plan named '{name}' available") From 2adf0337b1f8d0d8b67898666648e9f0c2f5ed26 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 6 Jan 2026 10:19:16 +0000 Subject: [PATCH 14/24] Correct mocking in cli test --- tests/unit_tests/test_cli.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 2a49a1fe80..ab4eddf9ea 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -198,15 +198,15 @@ def test_invalid_config_path_handling(runner: CliRunner): assert result.exit_code == 1 -@patch("blueapi.cli.cli.BlueapiClient.get_plans") +@patch("blueapi.cli.cli.BlueapiClient.plans") @patch("blueapi.cli.cli.OutputFormat.FULL.display") def test_options_via_env(mock_display, mock_plans, runner: CliRunner): result = runner.invoke( main, args=["controller", "plans"], env={"BLUEAPI_CONTROLLER_OUTPUT": "full"} ) - mock_plans.assert_called_once_with() - mock_display.assert_called_once_with(mock_plans.return_value) + mock_plans.__iter__.assert_called_once_with() + mock_display.assert_called_once_with(list(mock_plans)) assert result.exit_code == 0 From 484aec73efb4df0474d93399123f621bc1273cec Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 8 Jan 2026 17:03:22 +0000 Subject: [PATCH 15/24] Up the coverage --- src/blueapi/client/client.py | 4 +- tests/unit_tests/client/test_client.py | 191 ++++++++++++++++++++++++- 2 files changed, 191 insertions(+), 4 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 9c383d381d..15ae4251e3 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -142,11 +142,11 @@ def help_text(self) -> str: @property def properties(self) -> set[str]: - return self.model.parameter_schema["properties"] + return self.model.parameter_schema.get("properties", {}).keys() @property def required(self) -> list[str]: - return self.model.parameter_schema["required"] + return self.model.parameter_schema.get("required", []) def _build_args(self, *args, **kwargs): log.info( diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index eac58cdf74..8d287d7e28 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -8,8 +8,16 @@ JsonObjectSpanExporter, asserting_span_exporter, ) - -from blueapi.client.client import BlueapiClient +from pydantic import HttpUrl + +from blueapi.client.client import ( + BlueapiClient, + DeviceCache, + DeviceRef, + MissingInstrumentSessionError, + Plan, + PlanCache, +) from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError from blueapi.config import MissingStompConfigurationError @@ -36,6 +44,19 @@ ] ) PLAN = PlanModel(name="foo") +FULL_PLAN = PlanModel( + name="foobar", + description="Description of plan foobar", + schema={ + "title": "foobar", + "description": "Model description of plan foobar", + "properties": { + "one": {}, + "two": {}, + }, + "required": ["one"], + }, +) DEVICES = DeviceResponse( devices=[ DeviceModel(name="foo", protocols=[]), @@ -106,12 +127,20 @@ def client_with_events(mock_rest: Mock, mock_events: MagicMock): return BlueapiClient(rest=mock_rest, events=mock_events) +def test_client_from_config(): + bc = BlueapiClient.from_config_file( + "tests/unit_tests/valid_example_config/client.yaml" + ) + assert bc._rest._config.url == HttpUrl("http://example.com:8082") + + def test_get_plans(client: BlueapiClient): assert PlanResponse(plans=[p.model for p in client.plans]) == PLANS def test_get_plan(client: BlueapiClient): assert client.plans.foo.model == PLAN + assert client.plans["foo"].model == PLAN def test_get_nonexistant_plan( @@ -462,6 +491,10 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): mock_on_event.assert_called_once_with(COMPLETE_EVENT) +def test_get_oidc_config(client, mock_rest): + assert client.oidc_config == mock_rest.get_oidc_config() + + def test_get_plans_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): with asserting_span_exporter(exporter, "plans"): _ = client.plans @@ -569,3 +602,157 @@ def test_cannot_run_task_span_ok( ): with asserting_span_exporter(exporter, "grun_task"): client.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + +def test_instrument_session_required(client): + with pytest.raises(MissingInstrumentSessionError): + _ = client.instrument_session + + +def test_setting_instrument_session(client): + # This looks like a completely pointless test but instrument_session is a + # property with some logic so it's not purely to get coverage up + client.instrument_session = "cm12345-4" + assert client.instrument_session == "cm12345-4" + + +def test_plan_cache_ignores_underscores(client): + cache = PlanCache(client, [PlanModel(name="_ignored"), PlanModel(name="used")]) + with pytest.raises(AttributeError, match="_ignored"): + _ = cache._ignored + + +def test_device_cache_ignores_underscores(): + rest = Mock() + rest.get_devices.return_value = DeviceResponse( + devices=[ + DeviceModel(name="_ignored", protocols=[]), + ] + ) + cache = DeviceCache(rest) + with pytest.raises(AttributeError, match="_ignored"): + _ = cache._ignored + + rest.get_devices.reset_mock() + with pytest.raises(AttributeError, match="_anything"): + _ = cache._anything + rest.get_device.assert_not_called() + + +def test_devices_are_cached(mock_rest): + cache = DeviceCache(mock_rest) + _ = cache.foo + mock_rest.get_device.assert_not_called() + _ = cache["foo"] + mock_rest.get_device.assert_not_called() + + +def test_device_repr(): + cache = Mock() + model = Mock() + dev = DeviceRef(name="foo", cache=cache, model=model) + assert repr(dev) == "Device(foo)" + + +def test_device_ignores_underscores(): + cache = MagicMock() + model = Mock() + dev = DeviceRef(name="foo", cache=cache, model=model) + with pytest.raises(AttributeError, match="_underscore"): + _ = dev._underscore + cache.__getitem__.assert_not_called() + + +def test_plan_help_text(client): + plan = Plan("foo", PlanModel(name="foo", description="help for foo"), client) + assert plan.help_text == "help for foo" + + +def test_plan_fallback_help_text(client): + plan = Plan( + "foo", + PlanModel( + name="foo", + schema={"properties": {"one": {}, "two": {}}, "required": ["one"]}, + ), + client, + ) + assert plan.help_text == "Plan foo(one, two=None)" + + +def test_plan_properties(client): + plan = Plan( + "foo", + PlanModel( + name="foo", + schema={"properties": {"one": {}, "two": {}}, "required": ["one"]}, + ), + client, + ) + + assert plan.properties == {"one", "two"} + assert plan.required == ["one"] + + +def test_plan_empty_fallback_help_text(client): + plan = Plan( + "foo", PlanModel(name="foo", schema={"properties": {}, "required": []}), client + ) + assert plan.help_text == "Plan foo()" + + +p = pytest.param + + +@pytest.mark.parametrize( + "args,kwargs,params", + [ + p((1,), {}, {"one": 1}, id="required_as_positional"), + p((), {"one": 7}, {"one": 7}, id="required_as_keyword"), + p((1,), {"two": 23}, {"one": 1, "two": 23}, id="all_as_mixed_args_kwargs"), + p((1, 2), {}, {"one": 1, "two": 2}, id="all_as_positional"), + p((), {"one": 21, "two": 42}, {"one": 21, "two": 42}, id="all_as_keyword"), + ], +) +def test_plan_param_mapping(args, kwargs, params): + client = Mock() + client.instrument_session = "cm12345-1" + plan = Plan( + FULL_PLAN.name, + FULL_PLAN, + client, + ) + + plan(*args, **kwargs) + client.run_task.assert_called_once_with( + TaskRequest(name="foobar", instrument_session="cm12345-1", params=params) + ) + + +@pytest.mark.parametrize( + "args,kwargs,msg", + [ + p((), {}, r"Missing argument\(s\) for \{'one'\}", id="missing_required"), + p((1,), {"one": 7}, "multiple values for one", id="duplicate_required"), + p((1, 2), {"two": 23}, "multiple values for two", id="duplicate_optional"), + p((1, 2, 3), {}, "too many arguments", id="too_many_args"), + p( + (), + {"unknown_key": 42}, + r"got unexpected arguments: \{'unknown_key'\}", + id="unknown_arg", + ), + ], +) +def test_plan_invalid_param_mapping(args, kwargs, msg): + client = Mock() + client.instrument_session = "cm12345-1" + plan = Plan( + FULL_PLAN.name, + FULL_PLAN, + client, + ) + + with pytest.raises(TypeError, match=msg): + plan(*args, **kwargs) + client.run_task.assert_not_called() From e964365909755bc6e9d73bf9ca726eb5dc22eb14 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 8 Jan 2026 17:30:05 +0000 Subject: [PATCH 16/24] Update system tests --- tests/system_tests/test_blueapi_system.py | 30 ++++++++++------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index fdb5372112..82e5862baf 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -17,8 +17,7 @@ BlueskyRemoteControlError, ) from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError -from blueapi.client.rest import BlueskyRequestError -from blueapi.client.rest import BlueapiRestClient, UnknownPlanError +from blueapi.client.rest import BlueapiRestClient, BlueskyRequestError from blueapi.config import ( ApplicationConfig, ConfigLoader, @@ -170,18 +169,15 @@ def expected_devices() -> DeviceResponse: @pytest.fixture -def blueapi_client_get_methods() -> list[str]: +def blueapi_rest_client_get_methods() -> list[str]: # Get a list of methods that take only one argument (self) - # This will currently return - # ['get_plans', 'get_devices', 'get_state', 'get_all_tasks', - # 'get_active_task','get_environment','resume', 'stop','get_oidc_config'] return [ - method - for method in BlueapiClient.__dict__ - if callable(getattr(BlueapiClient, method)) - and not method.startswith("__") - and len(inspect.signature(getattr(BlueapiClient, method)).parameters) == 1 - and "self" in inspect.signature(getattr(BlueapiClient, method)).parameters + name + for name, method in BlueapiRestClient.__dict__.items() + if not name.startswith("__") + and callable(method) + and len(params := inspect.signature(method).parameters) == 1 + and "self" in params ] @@ -220,14 +216,14 @@ def reset_numtracker(server_config: ApplicationConfig): def test_cannot_access_endpoints( - client_without_auth: BlueapiClient, blueapi_client_get_methods: list[str] + client_without_auth: BlueapiClient, blueapi_rest_client_get_methods: list[str] ): - blueapi_client_get_methods.remove( + blueapi_rest_client_get_methods.remove( "get_oidc_config" ) # get_oidc_config can be accessed without auth - for get_method in blueapi_client_get_methods: + for get_method in blueapi_rest_client_get_methods: with pytest.raises(BlueskyRemoteControlError, match=r""): - getattr(client_without_auth, get_method)() + getattr(client_without_auth._rest, get_method)() def test_can_get_oidc_config_without_auth(client_without_auth: BlueapiClient): @@ -248,7 +244,7 @@ def test_get_plans(rest_client: BlueapiRestClient, expected_plans: PlanResponse) def test_get_plans_by_name(client: BlueapiClient, expected_plans: PlanResponse): for plan in expected_plans.plans: - assert getattr(client.plans, plan.name).model == plan + assert client.plans[plan.name].model == plan def test_get_non_existent_plan(rest_client: BlueapiRestClient): From 544df0a8d993321f97447027cb4b0170a43b83a8 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 9 Jan 2026 16:29:13 +0000 Subject: [PATCH 17/24] Add repr for caches --- src/blueapi/client/client.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 15ae4251e3..e1fc656cdc 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -67,6 +67,9 @@ def __getattr__(self, name: str) -> "Plan": def __iter__(self): return iter(self._cache.values()) + def __repr__(self) -> str: + return f"PlanCache({len(self._cache)} plans)" + class DeviceCache: def __init__(self, rest: BlueapiRestClient): @@ -101,6 +104,9 @@ def __getattr__(self, name: str) -> "DeviceRef": def __iter__(self): return iter(self._cache.values()) + def __repr__(self) -> str: + return f"DeviceCache({len(self._cache)} devices)" + class DeviceRef(str): model: DeviceModel From 3147454a29d1574a97b152a48ee5b5b5e882e6f6 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 9 Jan 2026 16:44:36 +0000 Subject: [PATCH 18/24] Add ServiceUnavailableError to wrap requests errors --- src/blueapi/cli/cli.py | 4 ++-- src/blueapi/client/rest.py | 23 +++++++++++++++-------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 600ce024ff..77e63c2ef0 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -17,7 +17,6 @@ from click.exceptions import ClickException from observability_utils.tracing import setup_tracing from pydantic import ValidationError -from requests.exceptions import ConnectionError from blueapi import __version__, config from blueapi.cli.format import OutputFormat @@ -26,6 +25,7 @@ from blueapi.client.rest import ( BlueskyRemoteControlError, InvalidParametersError, + ServiceUnavailableError, UnauthorisedAccessError, UnknownPlanError, ) @@ -183,7 +183,7 @@ def check_connection(func: Callable[P, T]) -> Callable[P, T]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: try: return func(*args, **kwargs) - except ConnectionError as ce: + except ServiceUnavailableError as ce: raise ClickException( "Failed to establish connection to blueapi server." ) from ce diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 3ff119449e..25f24e3b2b 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -252,14 +252,17 @@ def _request_and_deserialize( url = self._config.url.unicode_string().removesuffix("/") + suffix # Get the trace context to propagate to the REST API carr = get_context_propagator() - response = requests.request( - method, - url, - json=data, - params=params, - headers=carr, - auth=JWTAuth(self._session_manager), - ) + try: + response = requests.request( + method, + url, + json=data, + params=params, + headers=carr, + auth=JWTAuth(self._session_manager), + ) + except requests.exceptions.ConnectionError as ce: + raise ServiceUnavailableError() from ce exception = get_exception(response) if exception is not None: raise exception @@ -289,3 +292,7 @@ def __getattr__(name: str): ) return rename raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +class ServiceUnavailableError(Exception): + pass From 51ed822f4f2c26e2c3acb619e1b521ed3776ed47 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 9 Jan 2026 16:45:28 +0000 Subject: [PATCH 19/24] Create re-usable session for rest calls --- src/blueapi/client/rest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 25f24e3b2b..b6cebcd099 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -136,6 +136,7 @@ def _create_task_exceptions(response: requests.Response) -> Exception | None: class BlueapiRestClient: _config: RestConfig + _pool: requests.Session def __init__( self, @@ -144,6 +145,7 @@ def __init__( ) -> None: self._config = config or RestConfig() self._session_manager = session_manager + self._pool = requests.Session() def get_plans(self) -> PlanResponse: return self._request_and_deserialize("/plans", PlanResponse) @@ -253,7 +255,7 @@ def _request_and_deserialize( # Get the trace context to propagate to the REST API carr = get_context_propagator() try: - response = requests.request( + response = self._pool.request( method, url, json=data, From 9d4ad7e98845898c8ac99ad30b836a37737e362f Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 13 Jan 2026 15:49:20 +0000 Subject: [PATCH 20/24] Change mock in tests --- tests/unit_tests/client/test_rest.py | 2 +- tests/unit_tests/test_cli.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit_tests/client/test_rest.py b/tests/unit_tests/client/test_rest.py index c8fce9d101..2ddcdd3800 100644 --- a/tests/unit_tests/client/test_rest.py +++ b/tests/unit_tests/client/test_rest.py @@ -45,7 +45,7 @@ def rest_with_auth(oidc_config: OIDCConfig, tmp_path) -> BlueapiRestClient: (500, BlueskyRemoteControlError), ], ) -@patch("blueapi.client.rest.requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_rest_error_code( mock_request: Mock, rest: BlueapiRestClient, diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index ab4eddf9ea..17539db600 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -110,7 +110,7 @@ def test_runs_with_umask_002( mock_umask.assert_called_once_with(0o002) -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_connection_error_caught_by_wrapper_func( mock_requests: Mock, runner: CliRunner ): @@ -120,7 +120,7 @@ def test_connection_error_caught_by_wrapper_func( assert result.output == "Error: Failed to establish connection to blueapi server.\n" -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_authentication_error_caught_by_wrapper_func( mock_requests: Mock, runner: CliRunner ): @@ -133,7 +133,7 @@ def test_authentication_error_caught_by_wrapper_func( ) -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_remote_error_raised_by_wrapper_func(mock_requests: Mock, runner: CliRunner): mock_requests.side_effect = BlueskyRemoteControlError("Response [450]") From 15c8ef050f20bc9b16c2257cca7b88331141ac0f Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 13 Jan 2026 16:34:28 +0000 Subject: [PATCH 21/24] Catch correct exception in system tests --- tests/system_tests/test_blueapi_system.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 82e5862baf..dc82c394ee 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -9,7 +9,6 @@ import requests from bluesky_stomp.models import BasicAuthentication from pydantic import TypeAdapter -from requests.exceptions import ConnectionError from scanspec.specs import Line from blueapi.client.client import ( @@ -17,7 +16,11 @@ BlueskyRemoteControlError, ) from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError -from blueapi.client.rest import BlueapiRestClient, BlueskyRequestError +from blueapi.client.rest import ( + BlueapiRestClient, + BlueskyRequestError, + ServiceUnavailableError, +) from blueapi.config import ( ApplicationConfig, ConfigLoader, @@ -132,7 +135,7 @@ def wait_for_server(client: BlueapiClient): try: _ = client.environment return - except ConnectionError: + except ServiceUnavailableError: ... time.sleep(0.5) raise TimeoutError("No connection to the blueapi server") From 0a3edcd708fcfd269427cbdbbb4dd89a1601b80e Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 13 Jan 2026 17:32:00 +0000 Subject: [PATCH 22/24] Up the coverage --- src/blueapi/cli/cli.py | 6 +++--- src/blueapi/cli/format.py | 6 ------ tests/unit_tests/client/test_client.py | 8 ++++++++ tests/unit_tests/test_cli.py | 17 +++++++++++++---- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 77e63c2ef0..2ae736cff1 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -36,7 +36,7 @@ from blueapi.core import OTLP_EXPORT_ENABLED, DataEvent from blueapi.log import set_up_logging from blueapi.service.authentication import SessionCacheManager, SessionManager -from blueapi.service.model import SourceInfo, TaskRequest +from blueapi.service.model import DeviceResponse, PlanResponse, SourceInfo, TaskRequest from blueapi.worker import ProgressEvent, WorkerEvent from .scratch import setup_scratch @@ -204,7 +204,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def get_plans(obj: dict) -> None: """Get a list of plans available for the worker to use""" client: BlueapiClient = obj["client"] - obj["fmt"].display([p.model for p in client.plans]) + obj["fmt"].display(PlanResponse(plans=[p.model for p in client.plans])) @controller.command(name="devices") @@ -213,7 +213,7 @@ def get_plans(obj: dict) -> None: def get_devices(obj: dict) -> None: """Get a list of devices available for the worker to use""" client: BlueapiClient = obj["client"] - obj["fmt"].display([dev.model for dev in client.devices]) + obj["fmt"].display(DeviceResponse(devices=[dev.model for dev in client.devices])) @controller.command(name="listen") diff --git a/src/blueapi/cli/format.py b/src/blueapi/cli/format.py index 5d14edf0f7..7ce4fffbb0 100644 --- a/src/blueapi/cli/format.py +++ b/src/blueapi/cli/format.py @@ -98,9 +98,6 @@ def display_full(obj: Any, stream: Stream): case BaseModel(): print(obj.__class__.__name__, end="") print(fmt_dict(obj.model_dump())) - case list(): - for item in obj: - display_full(item, stream) case other: FALLBACK(other, stream=stream) @@ -175,9 +172,6 @@ def display_compact(obj: Any, stream: Stream): extra += " (Scratch)" print(f"- {package.name} @ ({package.version}){extra}") - case list(): - for item in obj: - display_compact(item, stream) case other: FALLBACK(other, stream=stream) diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index 8d287d7e28..a9392dddbf 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -622,6 +622,10 @@ def test_plan_cache_ignores_underscores(client): _ = cache._ignored +def test_plan_cache_repr(client): + assert repr(client.plans) == "PlanCache(2 plans)" + + def test_device_cache_ignores_underscores(): rest = Mock() rest.get_devices.return_value = DeviceResponse( @@ -647,6 +651,10 @@ def test_devices_are_cached(mock_rest): mock_rest.get_device.assert_not_called() +def test_device_cache_repr(client): + assert repr(client.devices) == "DeviceCache(2 devices)" + + def test_device_repr(): cache = Mock() model = Mock() diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 17539db600..7210ec2bbc 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -206,7 +206,7 @@ def test_options_via_env(mock_display, mock_plans, runner: CliRunner): ) mock_plans.__iter__.assert_called_once_with() - mock_display.assert_called_once_with(list(mock_plans)) + mock_display.assert_called_once_with(PlanResponse(plans=list(mock_plans))) assert result.exit_code == 0 @@ -493,9 +493,7 @@ def test_valid_stomp_config_for_listener( @responses.activate -def test_get_env( - runner: CliRunner, -): +def test_get_env(runner: CliRunner): environment_id = uuid.uuid4() responses.add( responses.GET, @@ -514,6 +512,17 @@ def test_get_env( ) +@responses.activate +def test_get_state(runner: CliRunner): + responses.add( + responses.GET, "http://localhost:8000/worker/state", json="IDLE", status=200 + ) + state = runner.invoke(main, ["controller", "state"]) + print(state.stderr) + assert state.exit_code == 0 + assert state.output == "IDLE\n" + + @responses.activate(assert_all_requests_are_fired=True) @patch("blueapi.client.client.time.sleep", return_value=None) def test_reset_env_client_behavior( From 93d5d8514e5ebe28e58bf4598d6de4225bd5144b Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 14 Jan 2026 13:37:55 +0000 Subject: [PATCH 23/24] Add callback support --- src/blueapi/client/client.py | 29 ++++++++++ tests/unit_tests/client/test_client.py | 79 ++++++++++++++++++++++++++ 2 files changed, 108 insertions(+) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index e1fc656cdc..e6f1e83e39 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -1,5 +1,7 @@ +import itertools import logging import time +from collections.abc import Iterable from concurrent.futures import Future from functools import cached_property from itertools import chain @@ -196,6 +198,8 @@ class BlueapiClient: _rest: BlueapiRestClient _events: EventBusClient | None _instrument_session: str | None = None + _callbacks: dict[int, OnAnyEvent] + _callback_id: itertools.count def __init__( self, @@ -204,6 +208,8 @@ def __init__( ): self._rest = rest self._events = events + self._callbacks = {} + self._callback_id = itertools.count() @cached_property @start_as_current_span(TRACER) @@ -258,6 +264,22 @@ def instrument_session(self, session: str): log.debug("Setting instrument_session to %s", session) self._instrument_session = session + def with_instrument_session(self, session: str) -> Self: + self.instrument_session = session + return self + + def add_callback(self, callback: OnAnyEvent) -> int: + cb_id = next(self._callback_id) + self._callbacks[cb_id] = callback + return cb_id + + def remove_callback(self, id: int): + self._callbacks.pop(id) + + @property + def callbacks(self) -> Iterable[OnAnyEvent]: + return self._callbacks.values() + @property @start_as_current_span(TRACER) def state(self) -> WorkerState: @@ -355,6 +377,13 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: if relates_to_task: if on_event is not None: on_event(event) + for cb in self._callbacks.values(): + try: + cb(event) + except Exception as e: + log.error( + f"Callback ({cb}) failed for event: {event}", exc_info=e + ) if isinstance(event, WorkerEvent) and ( (event.is_complete()) and (ctx.correlation_id == task_id) ): diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index a9392dddbf..f4601a2fba 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -764,3 +764,82 @@ def test_plan_invalid_param_mapping(args, kwargs, msg): with pytest.raises(TypeError, match=msg): plan(*args, **kwargs) client.run_task.assert_not_called() + + +def test_adding_removing_callback(client): + def callback(*a, **kw): + pass + + cb_id = client.add_callback(callback) + assert len(client.callbacks) == 1 + client.remove_callback(cb_id) + assert len(client.callbacks) == 0 + + +@pytest.mark.parametrize( + "test_event", + [ + WorkerEvent( + state=WorkerState.RUNNING, + task_status=TaskStatus( + task_id="foo", + task_complete=False, + task_failed=False, + ), + ), + ProgressEvent(task_id="foo"), + DataEvent(name="start", doc={}, task_id="0000-1111"), + ], +) +def test_client_callbacks( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, + test_event: AnyEvent, +): + callback = Mock() + client_with_events.add_callback(callback) + mock_rest.create_task.return_value = TaskResponse(task_id="foo") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + + ctx = Mock() + ctx.correlation_id = "foo" + + def subscribe(on_event: Callable[[AnyEvent, MessageContext], None]): + on_event(test_event, ctx) + on_event(COMPLETE_EVENT, ctx) + + mock_events.subscribe_to_all_events = subscribe # type: ignore + + client_with_events.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + assert callback.mock_calls == [call(test_event), call(COMPLETE_EVENT)] + + +def test_client_callback_failures( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, +): + failing_callback = Mock(side_effect=ValueError("Broken callback")) + callback = Mock() + client_with_events.add_callback(failing_callback) + client_with_events.add_callback(callback) + mock_rest.create_task.return_value = TaskResponse(task_id="foo") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + + ctx = Mock() + ctx.correlation_id = "foo" + + evt = DataEvent(name="start", doc={}, task_id="foo") + + def subscribe(on_event: Callable[[AnyEvent, MessageContext], None]): + on_event(evt, ctx) + on_event(COMPLETE_EVENT, ctx) + + mock_events.subscribe_to_all_events = subscribe # type: ignore + + client_with_events.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + assert failing_callback.mock_calls == [call(evt), call(COMPLETE_EVENT)] + assert callback.mock_calls == [call(evt), call(COMPLETE_EVENT)] From ba2287b59dbc0106a908711a194e17e3099539a1 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 14 Jan 2026 13:46:42 +0000 Subject: [PATCH 24/24] Test fluent instrument_session setter --- tests/unit_tests/client/test_client.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index f4601a2fba..98aad78718 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -616,6 +616,12 @@ def test_setting_instrument_session(client): assert client.instrument_session == "cm12345-4" +def test_fluent_instrument_session_setter(client): + client2 = client.with_instrument_session("cm12345-3") + assert client is client2 + assert client.instrument_session == "cm12345-3" + + def test_plan_cache_ignores_underscores(client): cache = PlanCache(client, [PlanModel(name="_ignored"), PlanModel(name="used")]) with pytest.raises(AttributeError, match="_ignored"):