diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index c39ff13dff..2ae736cff1 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -17,7 +17,6 @@ from click.exceptions import ClickException from observability_utils.tracing import setup_tracing from pydantic import ValidationError -from requests.exceptions import ConnectionError from blueapi import __version__, config from blueapi.cli.format import OutputFormat @@ -26,6 +25,7 @@ from blueapi.client.rest import ( BlueskyRemoteControlError, InvalidParametersError, + ServiceUnavailableError, UnauthorisedAccessError, UnknownPlanError, ) @@ -36,7 +36,7 @@ from blueapi.core import OTLP_EXPORT_ENABLED, DataEvent from blueapi.log import set_up_logging from blueapi.service.authentication import SessionCacheManager, SessionManager -from blueapi.service.model import SourceInfo, TaskRequest +from blueapi.service.model import DeviceResponse, PlanResponse, SourceInfo, TaskRequest from blueapi.worker import ProgressEvent, WorkerEvent from .scratch import setup_scratch @@ -183,7 +183,7 @@ def check_connection(func: Callable[P, T]) -> Callable[P, T]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: try: return func(*args, **kwargs) - except ConnectionError as ce: + except ServiceUnavailableError as ce: raise ClickException( "Failed to establish connection to blueapi server." ) from ce @@ -204,7 +204,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def get_plans(obj: dict) -> None: """Get a list of plans available for the worker to use""" client: BlueapiClient = obj["client"] - obj["fmt"].display(client.get_plans()) + obj["fmt"].display(PlanResponse(plans=[p.model for p in client.plans])) @controller.command(name="devices") @@ -213,7 +213,7 @@ def get_plans(obj: dict) -> None: def get_devices(obj: dict) -> None: """Get a list of devices available for the worker to use""" client: BlueapiClient = obj["client"] - obj["fmt"].display(client.get_devices()) + obj["fmt"].display(DeviceResponse(devices=[dev.model for dev in client.devices])) @controller.command(name="listen") @@ -345,7 +345,7 @@ def get_state(obj: dict) -> None: """Print the current state of the worker""" client: BlueapiClient = obj["client"] - print(client.get_state().name) + print(client.state.name) @controller.command(name="pause") @@ -428,7 +428,7 @@ def env( status = client.reload_environment(timeout=timeout) print("Environment is initialized") else: - status = client.get_environment() + status = client.environment print(status) @@ -470,14 +470,13 @@ def login(obj: dict) -> None: print("Logged in") except Exception: client = BlueapiClient.from_config(config) - oidc_config = client.get_oidc_config() - if oidc_config is None: + if oidc := client.oidc_config: + auth = SessionManager( + oidc, cache_manager=SessionCacheManager(config.auth_token_path) + ) + auth.start_device_flow() + else: print("Server is not configured to use authentication!") - return - auth = SessionManager( - oidc_config, cache_manager=SessionCacheManager(config.auth_token_path) - ) - auth.start_device_flow() @main.command(name="logout") diff --git a/src/blueapi/cli/format.py b/src/blueapi/cli/format.py index 490e57cfa5..7ce4fffbb0 100644 --- a/src/blueapi/cli/format.py +++ b/src/blueapi/cli/format.py @@ -12,7 +12,9 @@ from blueapi.core.bluesky_types import DataEvent from blueapi.service.model import ( + DeviceModel, DeviceResponse, + PlanModel, PlanResponse, PythonEnvironmentResponse, SourceInfo, @@ -54,17 +56,21 @@ def display_full(obj: Any, stream: Stream): match obj: case PlanResponse(plans=plans): for plan in plans: - print(plan.name) - if desc := plan.description: - print(indent(dedent(desc).strip(), " ")) - if schema := plan.parameter_schema: - print(" Schema") - print(indent(json.dumps(schema, indent=2), " ")) + display_full(plan, stream) + case PlanModel(name=name, description=desc, parameter_schema=schema): + print(name) + if desc: + print(indent(dedent(desc).strip(), " ")) + if schema: + print(" Schema") + print(indent(json.dumps(schema, indent=2), " ")) case DeviceResponse(devices=devices): for dev in devices: - print(dev.name) - for proto in dev.protocols: - print(f" {proto}") + display_full(dev, stream) + case DeviceModel(name=name, protocols=protocols): + print(name) + for proto in protocols: + print(f" {proto}") case DataEvent(name=name, doc=doc): print(f"{name.title()}:{fmt_dict(doc)}") case WorkerEvent(state=st, task_status=task): @@ -100,11 +106,13 @@ def display_json(obj: Any, stream: Stream): print = partial(builtins.print, file=stream) match obj: case PlanResponse(plans=plans): - print(json.dumps([p.model_dump() for p in plans], indent=2)) + display_json(plans, stream) case DeviceResponse(devices=devices): - print(json.dumps([d.model_dump() for d in devices], indent=2)) + display_json(devices, stream) case BaseModel(): print(json.dumps(obj.model_dump())) + case list(): + print(json.dumps([it.model_dump() for it in obj], indent=2)) case _: print(json.dumps(obj)) @@ -114,26 +122,30 @@ def display_compact(obj: Any, stream: Stream): match obj: case PlanResponse(plans=plans): for plan in plans: - print(plan.name) - if desc := plan.description: - print(indent(dedent(desc.split("\n\n")[0].strip("\n")), " ")) - if schema := plan.parameter_schema: - print(" Args") - for arg, spec in schema.get("properties", {}).items(): - req = arg in schema.get("required", {}) - print(f" {arg}={_describe_type(spec, req)}") + display_compact(plan, stream) + case PlanModel(name=name, description=desc, parameter_schema=schema): + print(name) + if desc: + print(indent(dedent(desc.split("\n\n")[0].strip("\n")), " ")) + if schema: + print(" Args") + for arg, spec in schema.get("properties", {}).items(): + req = arg in schema.get("required", {}) + print(f" {arg}={_describe_type(spec, req)}") case DeviceResponse(devices=devices): for dev in devices: - print(dev.name) - print( - indent( - textwrap.fill( - ", ".join(str(proto) for proto in dev.protocols), - 80, - ), - " ", - ) + display_compact(dev, stream) + case DeviceModel(name=name, protocols=protocols): + print(name) + print( + indent( + textwrap.fill( + ", ".join(str(proto) for proto in protocols), + 80, + ), + " ", ) + ) case DataEvent(name=name): print(f"Data Event: {name}") case WorkerEvent(state=state): diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 0930e240a9..e6f1e83e39 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -1,5 +1,12 @@ +import itertools +import logging import time +from collections.abc import Iterable from concurrent.futures import Future +from functools import cached_property +from itertools import chain +from pathlib import Path +from typing import Self from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -8,24 +15,25 @@ start_as_current_span, ) -from blueapi.config import ApplicationConfig, MissingStompConfigurationError +from blueapi.config import ( + ApplicationConfig, + ConfigLoader, + MissingStompConfigurationError, +) from blueapi.core.bluesky_types import DataEvent from blueapi.service.authentication import SessionManager from blueapi.service.model import ( DeviceModel, - DeviceResponse, EnvironmentResponse, OIDCConfig, PlanModel, - PlanResponse, PythonEnvironmentResponse, SourceInfo, TaskRequest, TaskResponse, - TasksListResponse, WorkerTask, ) -from blueapi.worker import TrackableTask, WorkerEvent, WorkerState +from blueapi.worker import WorkerEvent, WorkerState from blueapi.worker.event import ProgressEvent, TaskStatus from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent @@ -34,11 +42,164 @@ TRACER = get_tracer("client") +log = logging.getLogger(__name__) + + +class MissingInstrumentSessionError(Exception): + pass + + +class PlanCache: + def __init__(self, client: "BlueapiClient", plans: list[PlanModel]): + self._cache = { + model.name: Plan(name=model.name, model=model, client=client) + for model in plans + } + for name, plan in self._cache.items(): + if name.startswith("_"): + continue + setattr(self, name, plan) + + def __getitem__(self, name: str) -> "Plan": + return self._cache[name] + + def __getattr__(self, name: str) -> "Plan": + raise AttributeError(f"No plan named '{name}' available") + + def __iter__(self): + return iter(self._cache.values()) + + def __repr__(self) -> str: + return f"PlanCache({len(self._cache)} plans)" + + +class DeviceCache: + def __init__(self, rest: BlueapiRestClient): + self._rest = rest + self._cache = { + model.name: DeviceRef(name=model.name, cache=self, model=model) + for model in rest.get_devices().devices + } + for name, device in self._cache.items(): + if name.startswith("_"): + continue + setattr(self, name, device) + + def __getitem__(self, name: str) -> "DeviceRef": + if dev := self._cache.get(name): + return dev + try: + model = self._rest.get_device(name) + device = DeviceRef(name=name, cache=self, model=model) + self._cache[name] = device + setattr(self, model.name, device) + return device + except KeyError: + pass + raise AttributeError(f"No device named '{name}' available") + + def __getattr__(self, name: str) -> "DeviceRef": + if name.startswith("_"): + return super().__getattribute__(name) + return self[name] + + def __iter__(self): + return iter(self._cache.values()) + + def __repr__(self) -> str: + return f"DeviceCache({len(self._cache)} devices)" + + +class DeviceRef(str): + model: DeviceModel + _cache: DeviceCache + + def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel): + instance = super().__new__(cls, name) + instance.model = model + instance._cache = cache + return instance + + def __getattr__(self, name) -> "DeviceRef": + if name.startswith("_"): + raise AttributeError(f"No child device named {name}") + return self._cache[f"{self}.{name}"] + + def __repr__(self): + return f"Device({self})" + + +class Plan: + def __init__(self, name, model: PlanModel, client: "BlueapiClient"): + self.name = name + self.model = model + self._client = client + self.__doc__ = model.description + + def __call__(self, *args, **kwargs): + req = TaskRequest( + name=self.name, + params=self._build_args(*args, **kwargs), + instrument_session=self._client.instrument_session, + ) + self._client.run_task(req) + + @property + def help_text(self) -> str: + return self.model.description or f"Plan {self!r}" + + @property + def properties(self) -> set[str]: + return self.model.parameter_schema.get("properties", {}).keys() + + @property + def required(self) -> list[str]: + return self.model.parameter_schema.get("required", []) + + def _build_args(self, *args, **kwargs): + log.info( + "Building args for %s, using %s and %s", + "[" + ",".join(self.properties) + "]", + args, + kwargs, + ) + + if len(args) > len(self.properties): + raise TypeError(f"{self.name} got too many arguments") + if extra := {k for k in kwargs if k not in self.properties}: + raise TypeError(f"{self.name} got unexpected arguments: {extra}") + + params = {} + # Initially fill parameters using positional args assuming the order + # from the parameter_schema + for req, arg in zip(self.properties, args, strict=False): + params[req] = arg + + # Then append any values given via kwargs + for key, value in kwargs.items(): + # If we've already assumed a positional arg was this value, bail out + if key in params: + raise TypeError(f"{self.name} got multiple values for {key}") + params[key] = value + + if missing := {k for k in self.required if k not in params}: + raise TypeError(f"Missing argument(s) for {missing}") + return params + + def __repr__(self): + opts = [p for p in self.properties if p not in self.required] + params = ", ".join(chain(self.required, (f"{opt}=None" for opt in opts))) + return f"{self.name}({params})" + + class BlueapiClient: """Unified client for controlling blueapi""" _rest: BlueapiRestClient _events: EventBusClient | None + _instrument_session: str | None = None + _callbacks: dict[int, OnAnyEvent] + _callback_id: itertools.count def __init__( self, @@ -47,9 +208,30 @@ def __init__( ): self._rest = rest self._events = events + self._callbacks = {} + self._callback_id = itertools.count() + + @cached_property + @start_as_current_span(TRACER) + def plans(self) -> PlanCache: + return PlanCache(self, self._rest.get_plans().plans) + + @cached_property + @start_as_current_span(TRACER) + def devices(self) -> DeviceCache: + return DeviceCache(self._rest) @classmethod - def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": + def from_config_file(cls, config_file: str) -> Self: + conf = ConfigLoader(ApplicationConfig) + conf.use_values_from_yaml(Path(config_file)) + return cls.from_config(conf.load()) + + @classmethod + def from_config( + cls, + config: ApplicationConfig, + ) -> Self: session_manager: SessionManager | None = None try: session_manager = SessionManager.from_cache(config.auth_token_path) @@ -71,56 +253,36 @@ def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": else: return cls(rest) - @start_as_current_span(TRACER) - def get_plans(self) -> PlanResponse: - """ - List plans available - - Returns: - PlanResponse: Plans that can be run - """ - return self._rest.get_plans() - - @start_as_current_span(TRACER, "name") - def get_plan(self, name: str) -> PlanModel: - """ - Get details of a single plan - - Args: - name: Plan name - - Returns: - PlanModel: Details of the plan if found - """ - return self._rest.get_plan(name) - - @start_as_current_span(TRACER) - def get_devices(self) -> DeviceResponse: - """ - List devices available + @property + def instrument_session(self) -> str: + if self._instrument_session is None: + raise MissingInstrumentSessionError() + return self._instrument_session - Returns: - DeviceResponse: Devices that can be used in plans - """ + @instrument_session.setter + def instrument_session(self, session: str): + log.debug("Setting instrument_session to %s", session) + self._instrument_session = session - return self._rest.get_devices() + def with_instrument_session(self, session: str) -> Self: + self.instrument_session = session + return self - @start_as_current_span(TRACER, "name") - def get_device(self, name: str) -> DeviceModel: - """ - Get details of a single device + def add_callback(self, callback: OnAnyEvent) -> int: + cb_id = next(self._callback_id) + self._callbacks[cb_id] = callback + return cb_id - Args: - name: Device name + def remove_callback(self, id: int): + self._callbacks.pop(id) - Returns: - DeviceModel: Details of the device if found - """ - - return self._rest.get_device(name) + @property + def callbacks(self) -> Iterable[OnAnyEvent]: + return self._callbacks.values() + @property @start_as_current_span(TRACER) - def get_state(self) -> WorkerState: + def state(self) -> WorkerState: """ Get current state of the blueapi worker @@ -158,33 +320,9 @@ def resume(self) -> WorkerState: return self._rest.set_state(WorkerState.RUNNING, defer=False) - @start_as_current_span(TRACER, "task_id") - def get_task(self, task_id: str) -> TrackableTask: - """ - Get a task stored by the worker - - Args: - task_id: Unique ID for the task - - Returns: - TrackableTask: Task details - """ - assert task_id, "Task ID not provided!" - return self._rest.get_task(task_id) - + @property @start_as_current_span(TRACER) - def get_all_tasks(self) -> TasksListResponse: - """ - Get a list of all task stored by the worker - - Returns: - TasksListResponse: List of all Trackable Task - """ - - return self._rest.get_all_tasks() - - @start_as_current_span(TRACER) - def get_active_task(self) -> WorkerTask: + def active_task(self) -> WorkerTask: """ Get the currently active task, if any @@ -221,7 +359,7 @@ def run_task( "Stomp configuration required to run plans is missing or disabled" ) - task_response = self.create_task(task) + task_response = self._rest.create_task(task) task_id = task_response.task_id complete: Future[WorkerEvent] = Future() @@ -239,6 +377,13 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: if relates_to_task: if on_event is not None: on_event(event) + for cb in self._callbacks.values(): + try: + cb(event) + except Exception as e: + log.error( + f"Callback ({cb}) failed for event: {event}", exc_info=e + ) if isinstance(event, WorkerEvent) and ( (event.is_complete()) and (ctx.correlation_id == task_id) ): @@ -255,7 +400,7 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: with self._events: self._events.subscribe_to_all_events(inner_on_event) - self.start_task(WorkerTask(task_id=task_id)) + self._rest.update_worker_task(WorkerTask(task_id=task_id)) return complete.result(timeout=timeout) @start_as_current_span(TRACER, "task") @@ -271,8 +416,10 @@ def create_and_start_task(self, task: TaskRequest) -> TaskResponse: TaskResponse: Acknowledgement of request """ - response = self.create_task(task) - worker_response = self.start_task(WorkerTask(task_id=response.task_id)) + response = self._rest.create_task(task) + worker_response = self._rest.update_worker_task( + WorkerTask(task_id=response.task_id) + ) if worker_response.task_id == response.task_id: return response else: @@ -281,48 +428,6 @@ def create_and_start_task(self, task: TaskRequest) -> TaskResponse: f"but {worker_response.task_id} was started instead" ) - @start_as_current_span(TRACER, "task") - def create_task(self, task: TaskRequest) -> TaskResponse: - """ - Create a new task, does not start execution - - Args: - task: Request object for task to create on the worker - - Returns: - TaskResponse: Acknowledgement of request - """ - - return self._rest.create_task(task) - - @start_as_current_span(TRACER) - def clear_task(self, task_id: str) -> TaskResponse: - """ - Delete a stored task on the worker - - Args: - task_id: ID for the task - - Returns: - TaskResponse: Acknowledgement of request - """ - - return self._rest.clear_task(task_id) - - @start_as_current_span(TRACER, "task") - def start_task(self, task: WorkerTask) -> WorkerTask: - """ - Instruct the worker to start a stored task immediately - - Args: - task: WorkerTask to start - - Returns: - WorkerTask: Acknowledgement of request - """ - - return self._rest.update_worker_task(task) - @start_as_current_span(TRACER, "reason") def abort(self, reason: str | None = None) -> WorkerState: """ @@ -358,15 +463,10 @@ def stop(self) -> WorkerState: return self._rest.cancel_current_task(WorkerState.STOPPING) + @property @start_as_current_span(TRACER) - def get_environment(self) -> EnvironmentResponse: - """ - Get details of the worker environment - - Returns: - EnvironmentResponse: Details of the worker - environment. - """ + def environment(self) -> EnvironmentResponse: + """Details of the worker environment""" return self._rest.get_environment() @@ -433,14 +533,10 @@ def _wait_for_reload( "seconds, a server restart is recommended" ) + @property @start_as_current_span(TRACER) - def get_oidc_config(self) -> OIDCConfig | None: - """ - Get oidc config from the server - - Returns: - OIDCConfig: Details of the oidc Config - """ + def oidc_config(self) -> OIDCConfig | None: + """OIDC config from the server""" return self._rest.get_oidc_config() diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 3ff119449e..b6cebcd099 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -136,6 +136,7 @@ def _create_task_exceptions(response: requests.Response) -> Exception | None: class BlueapiRestClient: _config: RestConfig + _pool: requests.Session def __init__( self, @@ -144,6 +145,7 @@ def __init__( ) -> None: self._config = config or RestConfig() self._session_manager = session_manager + self._pool = requests.Session() def get_plans(self) -> PlanResponse: return self._request_and_deserialize("/plans", PlanResponse) @@ -252,14 +254,17 @@ def _request_and_deserialize( url = self._config.url.unicode_string().removesuffix("/") + suffix # Get the trace context to propagate to the REST API carr = get_context_propagator() - response = requests.request( - method, - url, - json=data, - params=params, - headers=carr, - auth=JWTAuth(self._session_manager), - ) + try: + response = self._pool.request( + method, + url, + json=data, + params=params, + headers=carr, + auth=JWTAuth(self._session_manager), + ) + except requests.exceptions.ConnectionError as ce: + raise ServiceUnavailableError() from ce exception = get_exception(response) if exception is not None: raise exception @@ -289,3 +294,7 @@ def __getattr__(name: str): ) return rename raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +class ServiceUnavailableError(Exception): + pass diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 0271331a6d..dc82c394ee 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -9,7 +9,6 @@ import requests from bluesky_stomp.models import BasicAuthentication from pydantic import TypeAdapter -from requests.exceptions import ConnectionError from scanspec.specs import Line from blueapi.client.client import ( @@ -17,7 +16,11 @@ BlueskyRemoteControlError, ) from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError -from blueapi.client.rest import BlueskyRequestError +from blueapi.client.rest import ( + BlueapiRestClient, + BlueskyRequestError, + ServiceUnavailableError, +) from blueapi.config import ( ApplicationConfig, ConfigLoader, @@ -130,9 +133,9 @@ def client_with_stomp() -> Generator[BlueapiClient]: def wait_for_server(client: BlueapiClient): for _ in range(20): try: - client.get_environment() + _ = client.environment return - except ConnectionError: + except ServiceUnavailableError: ... time.sleep(0.5) raise TimeoutError("No connection to the blueapi server") @@ -149,6 +152,11 @@ def client() -> Generator[BlueapiClient]: yield BlueapiClient.from_config(config=ApplicationConfig()) +@pytest.fixture +def rest_client(client: BlueapiClient) -> BlueapiRestClient: + return client._rest + + @pytest.fixture def expected_plans() -> PlanResponse: return TypeAdapter(PlanResponse).validate_json( @@ -164,25 +172,22 @@ def expected_devices() -> DeviceResponse: @pytest.fixture -def blueapi_client_get_methods() -> list[str]: +def blueapi_rest_client_get_methods() -> list[str]: # Get a list of methods that take only one argument (self) - # This will currently return - # ['get_plans', 'get_devices', 'get_state', 'get_all_tasks', - # 'get_active_task','get_environment','resume', 'stop','get_oidc_config'] return [ - method - for method in BlueapiClient.__dict__ - if callable(getattr(BlueapiClient, method)) - and not method.startswith("__") - and len(inspect.signature(getattr(BlueapiClient, method)).parameters) == 1 - and "self" in inspect.signature(getattr(BlueapiClient, method)).parameters + name + for name, method in BlueapiRestClient.__dict__.items() + if not name.startswith("__") + and callable(method) + and len(params := inspect.signature(method).parameters) == 1 + and "self" in params ] @pytest.fixture(autouse=True) -def clean_existing_tasks(client: BlueapiClient): - for task in client.get_all_tasks().tasks: - client.clear_task(task.task_id) +def clean_existing_tasks(rest_client: BlueapiRestClient): + for task in rest_client.get_all_tasks().tasks: + rest_client.clear_task(task.task_id) yield @@ -214,26 +219,26 @@ def reset_numtracker(server_config: ApplicationConfig): def test_cannot_access_endpoints( - client_without_auth: BlueapiClient, blueapi_client_get_methods: list[str] + client_without_auth: BlueapiClient, blueapi_rest_client_get_methods: list[str] ): - blueapi_client_get_methods.remove( + blueapi_rest_client_get_methods.remove( "get_oidc_config" ) # get_oidc_config can be accessed without auth - for get_method in blueapi_client_get_methods: + for get_method in blueapi_rest_client_get_methods: with pytest.raises(BlueskyRemoteControlError, match=r""): - getattr(client_without_auth, get_method)() + getattr(client_without_auth._rest, get_method)() def test_can_get_oidc_config_without_auth(client_without_auth: BlueapiClient): - assert client_without_auth.get_oidc_config() == OIDCConfig( + assert client_without_auth.oidc_config == OIDCConfig( well_known_url=KEYCLOAK_BASE_URL + "realms/master/.well-known/openid-configuration", client_id="ixx-cli-blueapi", ) -def test_get_plans(client: BlueapiClient, expected_plans: PlanResponse): - retrieved_plans = client.get_plans() +def test_get_plans(rest_client: BlueapiRestClient, expected_plans: PlanResponse): + retrieved_plans = rest_client.get_plans() retrieved_plans.plans.sort(key=lambda x: x.name) expected_plans.plans.sort(key=lambda x: x.name) @@ -242,40 +247,52 @@ def test_get_plans(client: BlueapiClient, expected_plans: PlanResponse): def test_get_plans_by_name(client: BlueapiClient, expected_plans: PlanResponse): for plan in expected_plans.plans: - assert client.get_plan(plan.name) == plan + assert client.plans[plan.name].model == plan -def test_get_non_existent_plan(client: BlueapiClient): +def test_get_non_existent_plan(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_plan("Not exists") + rest_client.get_plan("Not exists") + +def test_client_non_existent_plan(client: BlueapiClient): + with pytest.raises(AttributeError, match="No plan named 'missing' available"): + _ = client.plans.missing -def test_get_devices(client: BlueapiClient, expected_devices: DeviceResponse): - retrieved_devices = client.get_devices() + +def test_get_devices(rest_client: BlueapiRestClient, expected_devices: DeviceResponse): + retrieved_devices = rest_client.get_devices() retrieved_devices.devices.sort(key=lambda x: x.name) expected_devices.devices.sort(key=lambda x: x.name) assert retrieved_devices == expected_devices -def test_get_device_by_name(client: BlueapiClient, expected_devices: DeviceResponse): +def test_get_device_by_name( + rest_client: BlueapiRestClient, expected_devices: DeviceResponse +): for device in expected_devices.devices: - assert client.get_device(device.name) == device + assert rest_client.get_device(device.name) == device -def test_get_non_existent_device(client: BlueapiClient): +def test_get_non_existent_device(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_device("Not exists") + rest_client.get_device("Not exists") + + +def test_client_non_existent_device(client: BlueapiClient): + with pytest.raises(AttributeError, match="No device named 'missing' available"): + _ = client.devices.missing -def test_create_task_and_delete_task_by_id(client: BlueapiClient): - create_task = client.create_task(_SIMPLE_TASK) - client.clear_task(create_task.task_id) +def test_create_task_and_delete_task_by_id(rest_client: BlueapiRestClient): + create_task = rest_client.create_task(_SIMPLE_TASK) + rest_client.clear_task(create_task.task_id) -def test_instrument_session_propagated(client: BlueapiClient): - response = client.create_task(_SIMPLE_TASK) - trackable_task = client.get_task(response.task_id) +def test_instrument_session_propagated(rest_client: BlueapiRestClient): + response = rest_client.create_task(_SIMPLE_TASK) + trackable_task = rest_client.get_task(response.task_id) assert trackable_task.task.metadata == { "instrument_session": AUTHORIZED_INSTRUMENT_SESSION, "tiled_access_tags": [ @@ -284,9 +301,9 @@ def test_instrument_session_propagated(client: BlueapiClient): } -def test_create_task_validation_error(client: BlueapiClient): +def test_create_task_validation_error(rest_client: BlueapiRestClient): with pytest.raises(BlueskyRequestError, match="Internal Server Error"): - client.create_task( + rest_client.create_task( TaskRequest( name="Not-exists", params={"Not-exists": 0.0}, @@ -295,26 +312,26 @@ def test_create_task_validation_error(client: BlueapiClient): ) -def test_get_all_tasks(client: BlueapiClient): +def test_get_all_tasks(rest_client: BlueapiRestClient): created_tasks: list[TaskResponse] = [] for task in [_SIMPLE_TASK, _LONG_TASK]: - created_task = client.create_task(task) + created_task = rest_client.create_task(task) created_tasks.append(created_task) task_ids = [task.task_id for task in created_tasks] - task_list = client.get_all_tasks() + task_list = rest_client.get_all_tasks() for trackable_task in task_list.tasks: assert trackable_task.task_id in task_ids assert trackable_task.is_complete is False and trackable_task.is_pending is True for task_id in task_ids: - client.clear_task(task_id) + rest_client.clear_task(task_id) -def test_get_task_by_id(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) +def test_get_task_by_id(rest_client: BlueapiRestClient): + created_task = rest_client.create_task(_SIMPLE_TASK) - get_task = client.get_task(created_task.task_id) + get_task = rest_client.get_task(created_task.task_id) assert ( get_task.task_id == created_task.task_id and get_task.is_pending @@ -322,45 +339,45 @@ def test_get_task_by_id(client: BlueapiClient): and len(get_task.errors) == 0 ) - client.clear_task(created_task.task_id) + rest_client.clear_task(created_task.task_id) -def test_get_non_existent_task(client: BlueapiClient): +def test_get_non_existent_task(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_task("Not-exists") + rest_client.get_task("Not-exists") -def test_delete_non_existent_task(client: BlueapiClient): +def test_delete_non_existent_task(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.clear_task("Not-exists") + rest_client.clear_task("Not-exists") -def test_put_worker_task(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) - client.start_task(WorkerTask(task_id=created_task.task_id)) - active_task = client.get_active_task() +def test_put_worker_task(rest_client: BlueapiRestClient): + created_task = rest_client.create_task(_SIMPLE_TASK) + rest_client.update_worker_task(WorkerTask(task_id=created_task.task_id)) + active_task = rest_client.get_active_task() assert active_task.task_id == created_task.task_id - client.clear_task(created_task.task_id) + rest_client.clear_task(created_task.task_id) -def test_put_worker_task_fails_if_not_idle(client: BlueapiClient): - small_task = client.create_task(_SIMPLE_TASK) - long_task = client.create_task(_LONG_TASK) +def test_put_worker_task_fails_if_not_idle(rest_client: BlueapiRestClient): + small_task = rest_client.create_task(_SIMPLE_TASK) + long_task = rest_client.create_task(_LONG_TASK) - client.start_task(WorkerTask(task_id=long_task.task_id)) - active_task = client.get_active_task() + rest_client.update_worker_task(WorkerTask(task_id=long_task.task_id)) + active_task = rest_client.get_active_task() assert active_task.task_id == long_task.task_id with pytest.raises(BlueskyRemoteControlError) as exception: - client.start_task(WorkerTask(task_id=small_task.task_id)) + rest_client.update_worker_task(WorkerTask(task_id=small_task.task_id)) assert "" in str(exception) - client.abort() - client.clear_task(small_task.task_id) - client.clear_task(long_task.task_id) + rest_client.cancel_current_task(WorkerState.ABORTING) + rest_client.clear_task(small_task.task_id) + rest_client.clear_task(long_task.task_id) def test_get_worker_state(client: BlueapiClient): - assert client.get_state() == WorkerState.IDLE + assert client.state == WorkerState.IDLE def test_set_state_transition_error(client: BlueapiClient): @@ -372,10 +389,10 @@ def test_set_state_transition_error(client: BlueapiClient): assert "" in str(exception) -def test_get_task_by_status(client: BlueapiClient): - task_1 = client.create_task(_SIMPLE_TASK) - task_2 = client.create_task(_SIMPLE_TASK) - task_by_pending = client.get_all_tasks() +def test_get_task_by_status(rest_client: BlueapiRestClient): + task_1 = rest_client.create_task(_SIMPLE_TASK) + task_2 = rest_client.create_task(_SIMPLE_TASK) + task_by_pending = rest_client.get_all_tasks() # https://github.com/DiamondLightSource/blueapi/issues/680 # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.PENDING) assert len(task_by_pending.tasks) == 2 @@ -384,13 +401,13 @@ def test_get_task_by_status(client: BlueapiClient): trackable_task = TypeAdapter(TrackableTask).validate_python(task) assert trackable_task.is_complete is False and trackable_task.is_pending is True - client.start_task(WorkerTask(task_id=task_1.task_id)) - while not client.get_task(task_1.task_id).is_complete: + rest_client.update_worker_task(WorkerTask(task_id=task_1.task_id)) + while not rest_client.get_task(task_1.task_id).is_complete: time.sleep(0.1) - client.start_task(WorkerTask(task_id=task_2.task_id)) - while not client.get_task(task_2.task_id).is_complete: + rest_client.update_worker_task(WorkerTask(task_id=task_2.task_id)) + while not rest_client.get_task(task_2.task_id).is_complete: time.sleep(0.1) - task_by_completed = client.get_all_tasks() + task_by_completed = rest_client.get_all_tasks() # https://github.com/DiamondLightSource/blueapi/issues/680 # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.COMPLETE) assert len(task_by_completed.tasks) == 2 @@ -399,8 +416,8 @@ def test_get_task_by_status(client: BlueapiClient): trackable_task = TypeAdapter(TrackableTask).validate_python(task) assert trackable_task.is_complete is True and trackable_task.is_pending is False - client.clear_task(task_id=task_1.task_id) - client.clear_task(task_id=task_2.task_id) + rest_client.clear_task(task_id=task_1.task_id) + rest_client.clear_task(task_id=task_2.task_id) def test_progress_with_stomp(client_with_stomp: BlueapiClient): @@ -441,13 +458,13 @@ def on_event(event: AnyEvent): def test_get_current_state_of_environment(client: BlueapiClient): - assert client.get_environment().initialized + assert client.environment.initialized def test_delete_current_environment(client: BlueapiClient): - old_env = client.get_environment() + old_env = client.environment client.reload_environment() - new_env = client.get_environment() + new_env = client.environment assert new_env.initialized assert new_env.environment_id != old_env.environment_id assert new_env.error_message is None diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index d13ccce80d..98aad78718 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -8,8 +8,16 @@ JsonObjectSpanExporter, asserting_span_exporter, ) - -from blueapi.client.client import BlueapiClient +from pydantic import HttpUrl + +from blueapi.client.client import ( + BlueapiClient, + DeviceCache, + DeviceRef, + MissingInstrumentSessionError, + Plan, + PlanCache, +) from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError from blueapi.config import MissingStompConfigurationError @@ -20,6 +28,7 @@ EnvironmentResponse, PlanModel, PlanResponse, + ProtocolInfo, TaskRequest, TaskResponse, TasksListResponse, @@ -35,6 +44,19 @@ ] ) PLAN = PlanModel(name="foo") +FULL_PLAN = PlanModel( + name="foobar", + description="Description of plan foobar", + schema={ + "title": "foobar", + "description": "Model description of plan foobar", + "properties": { + "one": {}, + "two": {}, + }, + "required": ["one"], + }, +) DEVICES = DeviceResponse( devices=[ DeviceModel(name="foo", protocols=[]), @@ -72,9 +94,9 @@ def mock_rest() -> BlueapiRestClient: mock = Mock(spec=BlueapiRestClient) mock.get_plans.return_value = PLANS - mock.get_plan.return_value = PLAN + mock.get_plan.side_effect = lambda n: {p.name: p for p in PLANS.plans}[n] mock.get_devices.return_value = DEVICES - mock.get_device.return_value = DEVICE + mock.get_device.side_effect = lambda n: {d.name: d for d in DEVICES.devices}[n] mock.get_state.return_value = WorkerState.IDLE mock.get_task.return_value = TASK mock.get_all_tasks.return_value = TASKS @@ -105,114 +127,62 @@ def client_with_events(mock_rest: Mock, mock_events: MagicMock): return BlueapiClient(rest=mock_rest, events=mock_events) +def test_client_from_config(): + bc = BlueapiClient.from_config_file( + "tests/unit_tests/valid_example_config/client.yaml" + ) + assert bc._rest._config.url == HttpUrl("http://example.com:8082") + + def test_get_plans(client: BlueapiClient): - assert client.get_plans() == PLANS + assert PlanResponse(plans=[p.model for p in client.plans]) == PLANS def test_get_plan(client: BlueapiClient): - assert client.get_plan("foo") == PLAN + assert client.plans.foo.model == PLAN + assert client.plans["foo"].model == PLAN def test_get_nonexistant_plan( client: BlueapiClient, - mock_rest: Mock, ): - mock_rest.get_plan.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_plan("baz") + with pytest.raises(AttributeError): + _ = client.plans.fizz_buzz.model def test_get_devices(client: BlueapiClient): - assert client.get_devices() == DEVICES + assert DeviceResponse(devices=[d.model for d in client.devices]) == DEVICES def test_get_device(client: BlueapiClient): - assert client.get_device("foo") == DEVICE - - -def test_get_nonexistant_device( - client: BlueapiClient, - mock_rest: Mock, -): - mock_rest.get_device.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_device("baz") - - -def test_get_state(client: BlueapiClient): - assert client.get_state() == WorkerState.IDLE - - -def test_get_task(client: BlueapiClient): - assert client.get_task("foo") == TASK - - -def test_get_nonexistent_task( - client: BlueapiClient, - mock_rest: Mock, -): - mock_rest.get_task.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_task("baz") - - -def test_get_task_with_empty_id(client: BlueapiClient): - with pytest.raises(AssertionError) as exc: - client.get_task("") - assert str(exc) == "Task ID not provided!" + assert client.devices.foo.model == DEVICE -def test_get_all_tasks( +def test_get_nonexistent_device( client: BlueapiClient, ): - assert client.get_all_tasks() == TASKS + with pytest.raises(AttributeError): + _ = client.devices.baz -def test_create_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) - mock_rest.create_task.assert_called_once_with( - TaskRequest(name="foo", instrument_session="cm12345-1") +def test_get_child_device(mock_rest: Mock, client: BlueapiClient): + mock_rest.get_device.side_effect = ( + lambda name: DeviceModel(name="foo.x", protocols=[ProtocolInfo(name="One")]) + if name == "foo.x" + else None ) + foo = client.devices.foo + assert foo == "foo" + x = client.devices.foo.x + assert x == "foo.x" -def test_create_task_does_not_start_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) - mock_rest.update_worker_task.assert_not_called() - - -def test_clear_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.clear_task(task_id="foo") - mock_rest.clear_task.assert_called_once_with("foo") +def test_get_state(client: BlueapiClient): + assert client.state == WorkerState.IDLE def test_get_active_task(client: BlueapiClient): - assert client.get_active_task() == ACTIVE_TASK - - -def test_start_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.start_task(task=WorkerTask(task_id="bar")) - mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="bar")) - - -def test_start_nonexistant_task( - client: BlueapiClient, - mock_rest: Mock, -): - mock_rest.update_worker_task.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.start_task(task=WorkerTask(task_id="bar")) + assert client.active_task == ACTIVE_TASK def test_create_and_start_task_calls_both_creating_and_starting_endpoints( @@ -266,7 +236,7 @@ def test_create_and_start_task_fails_if_task_start_fails( def test_get_environment(client: BlueapiClient): - assert client.get_environment() == ENV + assert client.environment == ENV def test_reload_environment( @@ -521,76 +491,40 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): mock_on_event.assert_called_once_with(COMPLETE_EVENT) +def test_get_oidc_config(client, mock_rest): + assert client.oidc_config == mock_rest.get_oidc_config() + + def test_get_plans_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_plans"): - client.get_plans() + with asserting_span_exporter(exporter, "plans"): + _ = client.plans def test_get_plan_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_plan", "name"): - client.get_plan("foo") + with asserting_span_exporter(exporter, "plans"): + _ = client.plans.foo def test_get_devices_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_devices"): - client.get_devices() + with asserting_span_exporter(exporter, "devices"): + _ = client.devices def test_get_device_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_device", "name"): - client.get_device("foo") - - -def test_get_state_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_state"): - client.get_state() - + with asserting_span_exporter(exporter, "devices"): + _ = client.devices.foo -def test_get_task_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_task", "task_id"): - client.get_task("foo") - -def test_get_all_tasks_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, -): - with asserting_span_exporter(exporter, "get_all_tasks"): - client.get_all_tasks() - - -def test_create_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "create_task", "task"): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) - - -def test_clear_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "clear_task"): - client.clear_task(task_id="foo") +def test_get_state_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): + with asserting_span_exporter(exporter, "state"): + _ = client.state def test_get_active_task_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient ): - with asserting_span_exporter(exporter, "get_active_task"): - client.get_active_task() - - -def test_start_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "start_task", "task"): - client.start_task(task=WorkerTask(task_id="bar")) + with asserting_span_exporter(exporter, "active_task"): + _ = client.active_task def test_create_and_start_task_span_ok( @@ -609,8 +543,8 @@ def test_create_and_start_task_span_ok( def test_get_environment_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient ): - with asserting_span_exporter(exporter, "get_environment"): - client.get_environment() + with asserting_span_exporter(exporter, "environment"): + _ = client.environment def test_reload_environment_span_ok( @@ -668,3 +602,250 @@ def test_cannot_run_task_span_ok( ): with asserting_span_exporter(exporter, "grun_task"): client.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + +def test_instrument_session_required(client): + with pytest.raises(MissingInstrumentSessionError): + _ = client.instrument_session + + +def test_setting_instrument_session(client): + # This looks like a completely pointless test but instrument_session is a + # property with some logic so it's not purely to get coverage up + client.instrument_session = "cm12345-4" + assert client.instrument_session == "cm12345-4" + + +def test_fluent_instrument_session_setter(client): + client2 = client.with_instrument_session("cm12345-3") + assert client is client2 + assert client.instrument_session == "cm12345-3" + + +def test_plan_cache_ignores_underscores(client): + cache = PlanCache(client, [PlanModel(name="_ignored"), PlanModel(name="used")]) + with pytest.raises(AttributeError, match="_ignored"): + _ = cache._ignored + + +def test_plan_cache_repr(client): + assert repr(client.plans) == "PlanCache(2 plans)" + + +def test_device_cache_ignores_underscores(): + rest = Mock() + rest.get_devices.return_value = DeviceResponse( + devices=[ + DeviceModel(name="_ignored", protocols=[]), + ] + ) + cache = DeviceCache(rest) + with pytest.raises(AttributeError, match="_ignored"): + _ = cache._ignored + + rest.get_devices.reset_mock() + with pytest.raises(AttributeError, match="_anything"): + _ = cache._anything + rest.get_device.assert_not_called() + + +def test_devices_are_cached(mock_rest): + cache = DeviceCache(mock_rest) + _ = cache.foo + mock_rest.get_device.assert_not_called() + _ = cache["foo"] + mock_rest.get_device.assert_not_called() + + +def test_device_cache_repr(client): + assert repr(client.devices) == "DeviceCache(2 devices)" + + +def test_device_repr(): + cache = Mock() + model = Mock() + dev = DeviceRef(name="foo", cache=cache, model=model) + assert repr(dev) == "Device(foo)" + + +def test_device_ignores_underscores(): + cache = MagicMock() + model = Mock() + dev = DeviceRef(name="foo", cache=cache, model=model) + with pytest.raises(AttributeError, match="_underscore"): + _ = dev._underscore + cache.__getitem__.assert_not_called() + + +def test_plan_help_text(client): + plan = Plan("foo", PlanModel(name="foo", description="help for foo"), client) + assert plan.help_text == "help for foo" + + +def test_plan_fallback_help_text(client): + plan = Plan( + "foo", + PlanModel( + name="foo", + schema={"properties": {"one": {}, "two": {}}, "required": ["one"]}, + ), + client, + ) + assert plan.help_text == "Plan foo(one, two=None)" + + +def test_plan_properties(client): + plan = Plan( + "foo", + PlanModel( + name="foo", + schema={"properties": {"one": {}, "two": {}}, "required": ["one"]}, + ), + client, + ) + + assert plan.properties == {"one", "two"} + assert plan.required == ["one"] + + +def test_plan_empty_fallback_help_text(client): + plan = Plan( + "foo", PlanModel(name="foo", schema={"properties": {}, "required": []}), client + ) + assert plan.help_text == "Plan foo()" + + +p = pytest.param + + +@pytest.mark.parametrize( + "args,kwargs,params", + [ + p((1,), {}, {"one": 1}, id="required_as_positional"), + p((), {"one": 7}, {"one": 7}, id="required_as_keyword"), + p((1,), {"two": 23}, {"one": 1, "two": 23}, id="all_as_mixed_args_kwargs"), + p((1, 2), {}, {"one": 1, "two": 2}, id="all_as_positional"), + p((), {"one": 21, "two": 42}, {"one": 21, "two": 42}, id="all_as_keyword"), + ], +) +def test_plan_param_mapping(args, kwargs, params): + client = Mock() + client.instrument_session = "cm12345-1" + plan = Plan( + FULL_PLAN.name, + FULL_PLAN, + client, + ) + + plan(*args, **kwargs) + client.run_task.assert_called_once_with( + TaskRequest(name="foobar", instrument_session="cm12345-1", params=params) + ) + + +@pytest.mark.parametrize( + "args,kwargs,msg", + [ + p((), {}, r"Missing argument\(s\) for \{'one'\}", id="missing_required"), + p((1,), {"one": 7}, "multiple values for one", id="duplicate_required"), + p((1, 2), {"two": 23}, "multiple values for two", id="duplicate_optional"), + p((1, 2, 3), {}, "too many arguments", id="too_many_args"), + p( + (), + {"unknown_key": 42}, + r"got unexpected arguments: \{'unknown_key'\}", + id="unknown_arg", + ), + ], +) +def test_plan_invalid_param_mapping(args, kwargs, msg): + client = Mock() + client.instrument_session = "cm12345-1" + plan = Plan( + FULL_PLAN.name, + FULL_PLAN, + client, + ) + + with pytest.raises(TypeError, match=msg): + plan(*args, **kwargs) + client.run_task.assert_not_called() + + +def test_adding_removing_callback(client): + def callback(*a, **kw): + pass + + cb_id = client.add_callback(callback) + assert len(client.callbacks) == 1 + client.remove_callback(cb_id) + assert len(client.callbacks) == 0 + + +@pytest.mark.parametrize( + "test_event", + [ + WorkerEvent( + state=WorkerState.RUNNING, + task_status=TaskStatus( + task_id="foo", + task_complete=False, + task_failed=False, + ), + ), + ProgressEvent(task_id="foo"), + DataEvent(name="start", doc={}, task_id="0000-1111"), + ], +) +def test_client_callbacks( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, + test_event: AnyEvent, +): + callback = Mock() + client_with_events.add_callback(callback) + mock_rest.create_task.return_value = TaskResponse(task_id="foo") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + + ctx = Mock() + ctx.correlation_id = "foo" + + def subscribe(on_event: Callable[[AnyEvent, MessageContext], None]): + on_event(test_event, ctx) + on_event(COMPLETE_EVENT, ctx) + + mock_events.subscribe_to_all_events = subscribe # type: ignore + + client_with_events.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + assert callback.mock_calls == [call(test_event), call(COMPLETE_EVENT)] + + +def test_client_callback_failures( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, +): + failing_callback = Mock(side_effect=ValueError("Broken callback")) + callback = Mock() + client_with_events.add_callback(failing_callback) + client_with_events.add_callback(callback) + mock_rest.create_task.return_value = TaskResponse(task_id="foo") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + + ctx = Mock() + ctx.correlation_id = "foo" + + evt = DataEvent(name="start", doc={}, task_id="foo") + + def subscribe(on_event: Callable[[AnyEvent, MessageContext], None]): + on_event(evt, ctx) + on_event(COMPLETE_EVENT, ctx) + + mock_events.subscribe_to_all_events = subscribe # type: ignore + + client_with_events.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + assert failing_callback.mock_calls == [call(evt), call(COMPLETE_EVENT)] + assert callback.mock_calls == [call(evt), call(COMPLETE_EVENT)] diff --git a/tests/unit_tests/client/test_rest.py b/tests/unit_tests/client/test_rest.py index c8fce9d101..2ddcdd3800 100644 --- a/tests/unit_tests/client/test_rest.py +++ b/tests/unit_tests/client/test_rest.py @@ -45,7 +45,7 @@ def rest_with_auth(oidc_config: OIDCConfig, tmp_path) -> BlueapiRestClient: (500, BlueskyRemoteControlError), ], ) -@patch("blueapi.client.rest.requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_rest_error_code( mock_request: Mock, rest: BlueapiRestClient, diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 2a49a1fe80..7210ec2bbc 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -110,7 +110,7 @@ def test_runs_with_umask_002( mock_umask.assert_called_once_with(0o002) -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_connection_error_caught_by_wrapper_func( mock_requests: Mock, runner: CliRunner ): @@ -120,7 +120,7 @@ def test_connection_error_caught_by_wrapper_func( assert result.output == "Error: Failed to establish connection to blueapi server.\n" -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_authentication_error_caught_by_wrapper_func( mock_requests: Mock, runner: CliRunner ): @@ -133,7 +133,7 @@ def test_authentication_error_caught_by_wrapper_func( ) -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_remote_error_raised_by_wrapper_func(mock_requests: Mock, runner: CliRunner): mock_requests.side_effect = BlueskyRemoteControlError("Response [450]") @@ -198,15 +198,15 @@ def test_invalid_config_path_handling(runner: CliRunner): assert result.exit_code == 1 -@patch("blueapi.cli.cli.BlueapiClient.get_plans") +@patch("blueapi.cli.cli.BlueapiClient.plans") @patch("blueapi.cli.cli.OutputFormat.FULL.display") def test_options_via_env(mock_display, mock_plans, runner: CliRunner): result = runner.invoke( main, args=["controller", "plans"], env={"BLUEAPI_CONTROLLER_OUTPUT": "full"} ) - mock_plans.assert_called_once_with() - mock_display.assert_called_once_with(mock_plans.return_value) + mock_plans.__iter__.assert_called_once_with() + mock_display.assert_called_once_with(PlanResponse(plans=list(mock_plans))) assert result.exit_code == 0 @@ -493,9 +493,7 @@ def test_valid_stomp_config_for_listener( @responses.activate -def test_get_env( - runner: CliRunner, -): +def test_get_env(runner: CliRunner): environment_id = uuid.uuid4() responses.add( responses.GET, @@ -514,6 +512,17 @@ def test_get_env( ) +@responses.activate +def test_get_state(runner: CliRunner): + responses.add( + responses.GET, "http://localhost:8000/worker/state", json="IDLE", status=200 + ) + state = runner.invoke(main, ["controller", "state"]) + print(state.stderr) + assert state.exit_code == 0 + assert state.output == "IDLE\n" + + @responses.activate(assert_all_requests_are_fired=True) @patch("blueapi.client.client.time.sleep", return_value=None) def test_reset_env_client_behavior(