diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 90726d1..824e20b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,6 @@ name: CI on: push: - pull_request: jobs: Ubuntu: diff --git a/Makefile b/Makefile index 4ce8734..34a0591 100644 --- a/Makefile +++ b/Makefile @@ -13,11 +13,24 @@ update-deps: check: lint test .PHONY: lint -lint: +lint: lint-black lint-isort lint-pyflakes lint-mypy + +.PHONY: lint-black +lint-black: pipenv run black --check . + +.PHONY: lint-isort +lint-isort: pipenv run isort --check . + +.PHONY: lint-pyflakes +lint-pyflakes: pipenv run pyflakes . +.PHONY: lint-mypy +lint-mypy: + pipenv run mypy enapter + .PHONY: test test: run-unit-tests run-integration-tests diff --git a/Pipfile b/Pipfile index 009d0a2..0145ddc 100644 --- a/Pipfile +++ b/Pipfile @@ -11,6 +11,7 @@ black = "*" docker = "*" faker = "*" isort = "*" +mypy = "*" pyflakes = "*" pytest = "*" pytest-asyncio = "*" diff --git a/enapter/__init__.py b/enapter/__init__.py index a7c2db4..7e6c304 100644 --- a/enapter/__init__.py +++ b/enapter/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.11.1" +__version__ = "0.11.3" from . import async_, log, mdns, mqtt, types, vucm diff --git a/enapter/async_/generator.py b/enapter/async_/generator.py index b9a8ff1..a8e45a6 100644 --- a/enapter/async_/generator.py +++ b/enapter/async_/generator.py @@ -1,11 +1,14 @@ import contextlib import functools +from typing import AsyncContextManager, AsyncGenerator, Callable -def generator(func): +def generator( + func: Callable[..., AsyncGenerator], +) -> Callable[..., AsyncContextManager[AsyncGenerator]]: @functools.wraps(func) @contextlib.asynccontextmanager - async def wrapper(*args, **kwargs): + async def wrapper(*args, **kwargs) -> AsyncGenerator[AsyncGenerator, None]: gen = func(*args, **kwargs) try: yield gen diff --git a/enapter/async_/routine.py b/enapter/async_/routine.py index 9c13607..4540596 100644 --- a/enapter/async_/routine.py +++ b/enapter/async_/routine.py @@ -5,20 +5,20 @@ class Routine(abc.ABC): @abc.abstractmethod - async def _run(self): + async def _run(self) -> None: raise NotImplementedError # pragma: no cover async def __aenter__(self): await self.start() return self - async def __aexit__(self, *_): + async def __aexit__(self, *_) -> None: await self.stop() - def task(self): + def task(self) -> asyncio.Task: return self._task - async def start(self, cancel_parent_task_on_exception=True): + async def start(self, cancel_parent_task_on_exception: bool = True) -> None: self._started = asyncio.Event() self._stack = contextlib.AsyncExitStack() @@ -43,26 +43,27 @@ async def start(self, cancel_parent_task_on_exception=True): if self._task in done: self._task.result() - async def stop(self): + async def stop(self) -> None: self.cancel() await self.join() - def cancel(self): + def cancel(self) -> None: self._task.cancel() - async def join(self): + async def join(self) -> None: if self._task.done(): self._task.result() else: await self._task - async def __run(self): + async def __run(self) -> None: try: await self._run() except asyncio.CancelledError: pass except: if self._started.is_set() and self._cancel_parent_task_on_exception: + assert self._parent_task is not None self._parent_task.cancel() raise finally: diff --git a/enapter/log/json_formatter.py b/enapter/log/json_formatter.py index ee8ce2d..46b351a 100644 --- a/enapter/log/json_formatter.py +++ b/enapter/log/json_formatter.py @@ -1,10 +1,17 @@ import datetime +import logging +from typing import Any, Dict -import json_log_formatter +import json_log_formatter # type: ignore class JSONFormatter(json_log_formatter.JSONFormatter): - def json_record(self, message, extra, record): + def json_record( + self, + message: str, + extra: Dict[str, Any], + record: logging.LogRecord, + ) -> Dict[str, Any]: try: del extra["taskName"] except KeyError: @@ -27,5 +34,5 @@ def json_record(self, message, extra, record): return json_record - def mutate_json_record(self, json_record): + def mutate_json_record(self, json_record: Dict[str, Any]) -> Dict[str, Any]: return json_record diff --git a/enapter/mdns/resolver.py b/enapter/mdns/resolver.py index a314e25..30e3488 100644 --- a/enapter/mdns/resolver.py +++ b/enapter/mdns/resolver.py @@ -1,17 +1,17 @@ import logging -import dns.asyncresolver +import dns.asyncresolver # type: ignore LOGGER = logging.getLogger(__name__) class Resolver: - def __init__(self): + def __init__(self) -> None: self._logger = LOGGER self._dns_resolver = self._new_dns_resolver() self._mdns_resolver = self._new_mdns_resolver() - async def resolve(self, host): + async def resolve(self, host: str) -> str: # TODO: Resolve concurrently. try: ip = await self._resolve(self._dns_resolver, host) @@ -25,17 +25,17 @@ async def resolve(self, host): self._logger.info("%r resolved using mDNS: %r", host, ip) return ip - async def _resolve(self, resolver, host): + async def _resolve(self, resolver: dns.asyncresolver.Resolver, host: str) -> str: answer = await resolver.resolve(host, "A") if not answer: raise ValueError(f"empty answer received: {host}") return answer[0].address - def _new_dns_resolver(self): + def _new_dns_resolver(self) -> dns.asyncresolver.Resolver: return dns.asyncresolver.Resolver(configure=True) - def _new_mdns_resolver(self): + def _new_mdns_resolver(self) -> dns.asyncresolver.Resolver: r = dns.asyncresolver.Resolver(configure=False) r.nameservers = ["224.0.0.251"] r.port = 5353 diff --git a/enapter/mqtt/__init__.py b/enapter/mqtt/__init__.py index cd7e9f9..878a858 100644 --- a/enapter/mqtt/__init__.py +++ b/enapter/mqtt/__init__.py @@ -1,9 +1,10 @@ from . import api from .client import Client -from .config import Config +from .config import Config, TLSConfig __all__ = [ - "api", "Client", "Config", + "TLSConfig", + "api", ] diff --git a/enapter/mqtt/api/command.py b/enapter/mqtt/api/command.py index 7aad360..f9a7de2 100644 --- a/enapter/mqtt/api/command.py +++ b/enapter/mqtt/api/command.py @@ -1,5 +1,6 @@ import enum import json +from typing import Any, Dict, Optional, Union class CommandState(enum.Enum): @@ -9,11 +10,11 @@ class CommandState(enum.Enum): class CommandRequest: @classmethod - def unmarshal_json(cls, data): + def unmarshal_json(cls, data: Union[str, bytes]) -> "CommandRequest": req = json.loads(data) return cls(id_=req["id"], name=req["name"], args=req.get("arguments")) - def __init__(self, id_, name, args=None): + def __init__(self, id_: str, name: str, args: Optional[Dict[str, Any]] = None): self.id = id_ self.name = name @@ -21,12 +22,17 @@ def __init__(self, id_, name, args=None): args = {} self.args = args - def new_response(self, *args, **kwargs): + def new_response(self, *args, **kwargs) -> "CommandResponse": return CommandResponse(self.id, *args, **kwargs) class CommandResponse: - def __init__(self, id_, state, payload=None): + def __init__( + self, + id_: str, + state: Union[str, CommandState], + payload: Optional[Union[Dict[str, Any], str]] = None, + ) -> None: self.id = id_ if not isinstance(state, CommandState): @@ -37,8 +43,8 @@ def __init__(self, id_, state, payload=None): payload = {"message": payload} self.payload = payload - def json(self): - json_object = {"id": self.id, "state": self.state.value} + def json(self) -> Dict[str, Any]: + json_object: Dict[str, Any] = {"id": self.id, "state": self.state.value} if self.payload is not None: json_object["payload"] = self.payload diff --git a/enapter/mqtt/api/device_channel.py b/enapter/mqtt/api/device_channel.py index a7a5a77..3c6fd27 100644 --- a/enapter/mqtt/api/device_channel.py +++ b/enapter/mqtt/api/device_channel.py @@ -1,72 +1,82 @@ import json import logging import time +from typing import Any, AsyncContextManager, AsyncGenerator, Dict + +import aiomqtt # type: ignore import enapter -from .command import CommandRequest +from ..client import Client +from .command import CommandRequest, CommandResponse +from .log_severity import LogSeverity LOGGER = logging.getLogger(__name__) class DeviceChannel: - def __init__(self, client, hardware_id, channel_id): + def __init__(self, client: Client, hardware_id: str, channel_id: str) -> None: self._client = client self._logger = self._new_logger(hardware_id, channel_id) self._hardware_id = hardware_id self._channel_id = channel_id @property - def hardware_id(self): + def hardware_id(self) -> str: return self._hardware_id @property - def channel_id(self): + def channel_id(self) -> str: return self._channel_id @staticmethod - def _new_logger(hardware_id, channel_id): + def _new_logger(hardware_id, channel_id) -> logging.LoggerAdapter: extra = {"hardware_id": hardware_id, "channel_id": channel_id} return logging.LoggerAdapter(LOGGER, extra=extra) @enapter.async_.generator - async def subscribe_to_command_requests(self): + async def subscribe_to_command_requests( + self, + ) -> AsyncGenerator[CommandRequest, None]: async with self._subscribe("v1/command/requests") as messages: async for msg in messages: + assert isinstance(msg.payload, str) or isinstance(msg.payload, bytes) yield CommandRequest.unmarshal_json(msg.payload) - async def publish_command_response(self, resp): + async def publish_command_response(self, resp: CommandResponse) -> None: await self._publish_json("v1/command/responses", resp.json()) - async def publish_telemetry(self, telemetry, **kwargs): + async def publish_telemetry(self, telemetry: Dict[str, Any], **kwargs) -> None: await self._publish_json("v1/telemetry", telemetry, **kwargs) - async def publish_properties(self, properties, **kwargs): + async def publish_properties(self, properties: Dict[str, Any], **kwargs) -> None: await self._publish_json("v1/register", properties, **kwargs) - async def publish_logs(self, msg, severity, persist=False, **kwargs): + async def publish_logs( + self, msg: str, severity: LogSeverity, persist: bool = False, **kwargs + ) -> None: logs = { "message": msg, "severity": severity.value, + "persist": persist, } - if persist: - logs["persist"] = True - await self._publish_json("v3/logs", logs, **kwargs) - def _subscribe(self, path): + def _subscribe( + self, path: str + ) -> AsyncContextManager[AsyncGenerator[aiomqtt.Message, None]]: topic = f"v1/to/{self._hardware_id}/{self._channel_id}/{path}" return self._client.subscribe(topic) - async def _publish_json(self, path, json_object, **kwargs): + async def _publish_json( + self, path: str, json_object: Dict[str, Any], **kwargs + ) -> None: if "timestamp" not in json_object: json_object["timestamp"] = int(time.time()) - payload = json.dumps(json_object) - await self._publish(path, payload, **kwargs) - async def _publish(self, path, payload, **kwargs): + async def _publish(self, path: str, payload: str, **kwargs) -> None: topic = f"v1/from/{self._hardware_id}/{self._channel_id}/{path}" try: await self._client.publish(topic, payload, **kwargs) diff --git a/enapter/mqtt/client.py b/enapter/mqtt/client.py index 165b17d..4fb65a0 100644 --- a/enapter/mqtt/client.py +++ b/enapter/mqtt/client.py @@ -1,145 +1,115 @@ import asyncio -import collections import contextlib import logging import ssl import tempfile -import aiomqtt +import aiomqtt # type: ignore import enapter LOGGER = logging.getLogger(__name__) +from typing import AsyncGenerator, Optional + +from .config import Config + class Client(enapter.async_.Routine): - def __init__(self, config): + def __init__(self, config: Config) -> None: self._logger = self._new_logger(config) self._config = config self._mdns_resolver = enapter.mdns.Resolver() self._tls_context = self._new_tls_context(config) - self._client = None - self._client_ready = asyncio.Event() - self._subscribers = collections.defaultdict(int) + self._publisher: Optional[aiomqtt.Client] = None + self._publisher_connected = asyncio.Event() @staticmethod - def _new_logger(config): + def _new_logger(config: Config) -> logging.LoggerAdapter: extra = {"host": config.host, "port": config.port} return logging.LoggerAdapter(LOGGER, extra=extra) - def config(self): + def config(self) -> Config: return self._config - async def publish(self, *args, **kwargs): - client = await self._wait_client() - await client.publish(*args, **kwargs) + async def publish(self, *args, **kwargs) -> None: + await self._publisher_connected.wait() + assert self._publisher is not None + await self._publisher.publish(*args, **kwargs) @enapter.async_.generator - async def subscribe(self, topic): + async def subscribe(self, *topics: str) -> AsyncGenerator[aiomqtt.Message, None]: while True: - client = await self._wait_client() - try: - async with client.messages() as messages: - async with self._subscribe(client, topic): - async for msg in messages: - if msg.topic.matches(topic): - yield msg - + async with self._connect() as subscriber: + for topic in topics: + await subscriber.subscribe(topic) + self._logger.info("subscriber [%s] connected", ",".join(topics)) + async for msg in subscriber.messages: + yield msg except aiomqtt.MqttError as e: self._logger.error(e) retry_interval = 5 await asyncio.sleep(retry_interval) - @contextlib.asynccontextmanager - async def _subscribe(self, client, topic): - first_subscriber = not self._subscribers[topic] - self._subscribers[topic] += 1 - try: - if first_subscriber: - await client.subscribe(topic) - yield - finally: - self._subscribers[topic] -= 1 - assert not self._subscribers[topic] < 0 - last_unsubscriber = not self._subscribers[topic] - if last_unsubscriber: - del self._subscribers[topic] - await client.unsubscribe(topic) - - async def _wait_client(self): - await self._client_ready.wait() - assert self._client_ready.is_set() - return self._client - - async def _run(self): + async def _run(self) -> None: self._logger.info("starting") - self._started.set() - while True: try: - async with self._connect() as client: - self._client = client - self._client_ready.set() - self._logger.info("client ready") - - # tracking disconnect - async with client.messages() as messages: - async for msg in messages: - pass + async with self._connect() as publisher: + self._logger.info("publisher connected") + self._publisher = publisher + self._publisher_connected.set() + async for msg in publisher.messages: + pass except aiomqtt.MqttError as e: self._logger.error(e) retry_interval = 5 await asyncio.sleep(retry_interval) finally: - self._client_ready.clear() - self._client = None - self._logger.info("client not ready") + self._publisher_connected.clear() + self._publisher = None + self._logger.info("publisher disconnected") @contextlib.asynccontextmanager - async def _connect(self): + async def _connect(self) -> AsyncGenerator[aiomqtt.Client, None]: host = await self._maybe_resolve_mdns(self._config.host) - - try: - async with aiomqtt.Client( - hostname=host, - port=self._config.port, - username=self._config.user, - password=self._config.password, - logger=self._logger, - tls_context=self._tls_context, - ) as client: - yield client - except asyncio.CancelledError: - # FIXME: A cancelled `aiomqtt.Client.connect` leaks resources. - raise + async with aiomqtt.Client( + hostname=host, + port=self._config.port, + username=self._config.user, + password=self._config.password, + logger=LOGGER, + tls_context=self._tls_context, + ) as client: + yield client @staticmethod - def _new_tls_context(config): - if not config.tls_enabled: + def _new_tls_context(config: Config) -> Optional[ssl.SSLContext]: + if config.tls is None: return None ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.verify_mode = ssl.CERT_REQUIRED ctx.check_hostname = False - ctx.load_verify_locations(None, None, config.tls_ca_cert) + ctx.load_verify_locations(None, None, config.tls.ca_cert) with contextlib.ExitStack() as stack: certfile = stack.enter_context(tempfile.NamedTemporaryFile()) - certfile.write(config.tls_cert.encode()) + certfile.write(config.tls.cert.encode()) certfile.flush() keyfile = stack.enter_context(tempfile.NamedTemporaryFile()) - keyfile.write(config.tls_secret_key.encode()) + keyfile.write(config.tls.secret_key.encode()) keyfile.flush() ctx.load_cert_chain(certfile.name, keyfile=keyfile.name) return ctx - async def _maybe_resolve_mdns(self, host): + async def _maybe_resolve_mdns(self, host: str) -> str: if not host.endswith(".local"): return host diff --git a/enapter/mqtt/config.py b/enapter/mqtt/config.py index 1692bad..d9eede1 100644 --- a/enapter/mqtt/config.py +++ b/enapter/mqtt/config.py @@ -1,48 +1,69 @@ import os +from typing import MutableMapping, Optional -class Config: +class TLSConfig: + @classmethod - def from_env(cls, prefix="ENAPTER_", env=os.environ): - def pem(value): - if value is None: - return value + def from_env( + cls, prefix: str = "ENAPTER_", env: MutableMapping[str, str] = os.environ + ) -> Optional["TLSConfig"]: + secret_key = env.get(prefix + "MQTT_TLS_SECRET_KEY") + cert = env.get(prefix + "MQTT_TLS_CERT") + ca_cert = env.get(prefix + "MQTT_TLS_CA_CERT") + + nothing_defined = {secret_key, cert, ca_cert} == {None} + if nothing_defined: + return None + + if secret_key is None: + raise KeyError(prefix + "MQTT_TLS_SECRET_KEY") + if cert is None: + raise KeyError(prefix + "MQTT_TLS_CERT") + if ca_cert is None: + raise KeyError(prefix + "MQTT_TLS_CA_CERT") + + def pem(value: str) -> str: return value.replace("\\n", "\n") + return cls(secret_key=pem(secret_key), cert=pem(cert), ca_cert=pem(ca_cert)) + + def __init__(self, secret_key: str, cert: str, ca_cert: str) -> None: + self.secret_key = secret_key + self.cert = cert + self.ca_cert = ca_cert + + +class Config: + @classmethod + def from_env( + cls, prefix: str = "ENAPTER_", env: MutableMapping[str, str] = os.environ + ) -> "Config": return cls( host=env[prefix + "MQTT_HOST"], port=int(env[prefix + "MQTT_PORT"]), user=env.get(prefix + "MQTT_USER", default=None), password=env.get(prefix + "MQTT_PASSWORD", default=None), - tls_secret_key=pem(env.get(prefix + "MQTT_TLS_SECRET_KEY", default=None)), - tls_cert=pem(env.get(prefix + "MQTT_TLS_CERT", default=None)), - tls_ca_cert=pem(env.get(prefix + "MQTT_TLS_CA_CERT", default=None)), + tls=TLSConfig.from_env(prefix=prefix, env=env), ) def __init__( self, - host, - port, - user=None, - password=None, - tls_secret_key=None, - tls_cert=None, - tls_ca_cert=None, - ): + host: str, + port: int, + user: Optional[str] = None, + password: Optional[str] = None, + tls: Optional[TLSConfig] = None, + ) -> None: self.host = host self.port = port self.user = user self.password = password + self.tls = tls - self.tls_secret_key = tls_secret_key - self.tls_cert = tls_cert - self.tls_ca_cert = tls_ca_cert - - self.tls_enabled = {tls_secret_key, tls_cert, tls_ca_cert} != {None} - - def __repr__(self): - return "mqtt.Config(host=%r, port=%r, tls_enabled=%r)" % ( + def __repr__(self) -> str: + return "mqtt.Config(host=%r, port=%r, tls=%r)" % ( self.host, self.port, - self.tls_enabled, + self.tls is not None, ) diff --git a/enapter/vucm/app.py b/enapter/vucm/app.py index b2f691d..99098e5 100644 --- a/enapter/vucm/app.py +++ b/enapter/vucm/app.py @@ -1,12 +1,22 @@ import asyncio +from typing import Optional, Protocol import enapter from .config import Config +from .device import Device from .ucm import UCM -async def run(device_factory, config_prefix=None): +class DeviceFactory(Protocol): + + def __call__(self, channel: enapter.mqtt.api.DeviceChannel, **kwargs) -> Device: + pass + + +async def run( + device_factory: DeviceFactory, config_prefix: Optional[str] = None +) -> None: enapter.log.configure(level=enapter.log.LEVEL or "info") config = Config.from_env(prefix=config_prefix) @@ -16,11 +26,11 @@ async def run(device_factory, config_prefix=None): class App(enapter.async_.Routine): - def __init__(self, config, device_factory): + def __init__(self, config: Config, device_factory: DeviceFactory) -> None: self._config = config self._device_factory = device_factory - async def _run(self): + async def _run(self) -> None: tasks = set() mqtt_client = await self._stack.enter_async_context( diff --git a/enapter/vucm/config.py b/enapter/vucm/config.py index bfb7e78..2726bb9 100644 --- a/enapter/vucm/config.py +++ b/enapter/vucm/config.py @@ -1,61 +1,72 @@ import base64 import json import os +from typing import MutableMapping, Optional import enapter class Config: @classmethod - def from_env(cls, prefix=None, env=os.environ): + def from_env( + cls, prefix: Optional[str] = None, env: MutableMapping[str, str] = os.environ + ) -> "Config": if prefix is None: prefix = "ENAPTER_VUCM_" try: - blob = os.environ[prefix + "BLOB"] + blob = env[prefix + "BLOB"] except KeyError: pass else: config = cls.from_blob(blob) try: - config.channel_id = os.environ[prefix + "CHANNEL_ID"] + config.channel_id = env[prefix + "CHANNEL_ID"] except KeyError: pass return config - hardware_id = os.environ[prefix + "HARDWARE_ID"] - channel_id = os.environ[prefix + "CHANNEL_ID"] + hardware_id = env[prefix + "HARDWARE_ID"] + channel_id = env[prefix + "CHANNEL_ID"] - mqtt_config = enapter.mqtt.Config.from_env(prefix=prefix, env=env) + mqtt = enapter.mqtt.Config.from_env(prefix=prefix, env=env) - start_ucm = os.environ.get(prefix + "START_UCM", "1") != "0" + start_ucm = env.get(prefix + "START_UCM", "1") != "0" return cls( hardware_id=hardware_id, channel_id=channel_id, - mqtt_config=mqtt_config, + mqtt=mqtt, start_ucm=start_ucm, ) @classmethod - def from_blob(cls, blob): + def from_blob(cls, blob: str) -> "Config": payload = json.loads(base64.b64decode(blob)) - mqtt_config = enapter.mqtt.Config( + mqtt = enapter.mqtt.Config( host=payload["mqtt_host"], port=int(payload["mqtt_port"]), - tls_ca_cert=payload["mqtt_ca"], - tls_cert=payload["mqtt_cert"], - tls_secret_key=payload["mqtt_private_key"], + tls=enapter.mqtt.TLSConfig( + ca_cert=payload["mqtt_ca"], + cert=payload["mqtt_cert"], + secret_key=payload["mqtt_private_key"], + ), ) return cls( hardware_id=payload["ucm_id"], channel_id=payload["channel_id"], - mqtt_config=mqtt_config, + mqtt=mqtt, ) - def __init__(self, hardware_id, channel_id, mqtt_config, start_ucm=True): + def __init__( + self, + hardware_id: str, + channel_id: str, + mqtt: enapter.mqtt.Config, + start_ucm: bool = True, + ) -> None: self.hardware_id = hardware_id self.channel_id = channel_id - self.mqtt = mqtt_config + self.mqtt = mqtt self.start_ucm = start_ucm diff --git a/enapter/vucm/device.py b/enapter/vucm/device.py index 5d102c3..2aa39bf 100644 --- a/enapter/vucm/device.py +++ b/enapter/vucm/device.py @@ -2,7 +2,7 @@ import concurrent import functools import traceback -from typing import Any, Callable, Coroutine, Optional, Set +from typing import Any, Callable, Coroutine, Dict, Optional, Set, Tuple import enapter @@ -34,7 +34,9 @@ def is_device_command(func: DeviceCommandFunc) -> bool: class Device(enapter.async_.Routine): - def __init__(self, channel, thread_pool_workers: int = 1) -> None: + def __init__( + self, channel: enapter.mqtt.api.DeviceChannel, thread_pool_workers: int = 1 + ) -> None: self.__channel = channel self.__tasks = {} @@ -57,7 +59,7 @@ def __init__(self, channel, thread_pool_workers: int = 1) -> None: self.alerts: Set[str] = set() async def send_telemetry( - self, telemetry: Optional[enapter.types.JSON] = None + self, telemetry: Optional[Dict[str, enapter.types.JSON]] = None ) -> None: if telemetry is None: telemetry = {} @@ -69,7 +71,7 @@ async def send_telemetry( await self.__channel.publish_telemetry(telemetry) async def send_properties( - self, properties: Optional[enapter.types.JSON] = None + self, properties: Optional[Dict[str, enapter.types.JSON]] = None ) -> None: if properties is None: properties = {} @@ -78,13 +80,13 @@ async def send_properties( await self.__channel.publish_properties(properties) - async def run_in_thread(self, func, *args, **kwargs): + async def run_in_thread(self, func, *args, **kwargs) -> Any: loop = asyncio.get_running_loop() return await loop.run_in_executor( self.__thread_pool_executor, functools.partial(func, *args, **kwargs) ) - async def _run(self): + async def _run(self) -> None: self._stack.enter_context(self.__thread_pool_executor) tasks = set() @@ -110,7 +112,7 @@ async def _run(self): task.cancel() self._stack.push_async_callback(self.__wait_task, task) - async def __wait_task(self, task): + async def __wait_task(self, task) -> None: try: await task except asyncio.CancelledError: @@ -122,14 +124,16 @@ async def __wait_task(self, task): pass raise - async def __process_command_requests(self): + async def __process_command_requests(self) -> None: async with self.__channel.subscribe_to_command_requests() as reqs: async for req in reqs: state, payload = await self.__execute_command(req) resp = req.new_response(state, payload) await self.__channel.publish_command_response(resp) - async def __execute_command(self, req): + async def __execute_command( + self, req + ) -> Tuple[enapter.mqtt.api.CommandState, enapter.types.JSON]: try: cmd = self.__commands[req.name] except KeyError: diff --git a/enapter/vucm/logger.py b/enapter/vucm/logger.py index 383928b..b1af474 100644 --- a/enapter/vucm/logger.py +++ b/enapter/vucm/logger.py @@ -6,32 +6,32 @@ class Logger: - def __init__(self, channel): + def __init__(self, channel) -> None: self._channel = channel self._logger = self._new_logger(channel.hardware_id, channel.channel_id) @staticmethod - def _new_logger(hardware_id, channel_id): + def _new_logger(hardware_id, channel_id) -> logging.LoggerAdapter: extra = {"hardware_id": hardware_id, "channel_id": channel_id} return logging.LoggerAdapter(LOGGER, extra=extra) - async def debug(self, msg: str, persist: bool = False): + async def debug(self, msg: str, persist: bool = False) -> None: self._logger.debug(msg) await self.log( msg, severity=enapter.mqtt.api.LogSeverity.DEBUG, persist=persist ) - async def info(self, msg: str, persist: bool = False): + async def info(self, msg: str, persist: bool = False) -> None: self._logger.info(msg) await self.log(msg, severity=enapter.mqtt.api.LogSeverity.INFO, persist=persist) - async def warning(self, msg: str, persist: bool = False): + async def warning(self, msg: str, persist: bool = False) -> None: self._logger.warning(msg) await self.log( msg, severity=enapter.mqtt.api.LogSeverity.WARNING, persist=persist ) - async def error(self, msg: str, persist: bool = False): + async def error(self, msg: str, persist: bool = False) -> None: self._logger.error(msg) await self.log( msg, severity=enapter.mqtt.api.LogSeverity.ERROR, persist=persist @@ -39,7 +39,7 @@ async def error(self, msg: str, persist: bool = False): async def log( self, msg: str, severity: enapter.mqtt.api.LogSeverity, persist: bool = False - ): + ) -> None: await self._channel.publish_logs(msg=msg, severity=severity, persist=persist) __call__ = log diff --git a/enapter/vucm/ucm.py b/enapter/vucm/ucm.py index 714c21b..9099e27 100644 --- a/enapter/vucm/ucm.py +++ b/enapter/vucm/ucm.py @@ -6,7 +6,7 @@ class UCM(Device): - def __init__(self, mqtt_client, hardware_id): + def __init__(self, mqtt_client, hardware_id) -> None: super().__init__( channel=enapter.mqtt.api.DeviceChannel( client=mqtt_client, hardware_id=hardware_id, channel_id="ucm" @@ -14,12 +14,12 @@ def __init__(self, mqtt_client, hardware_id): ) @device_command - async def reboot(self): + async def reboot(self) -> None: await asyncio.sleep(0) raise NotImplementedError @device_command - async def upload_lua_script(self, url, sha1, payload=None): + async def upload_lua_script(self, url, sha1, payload=None) -> None: await asyncio.sleep(0) raise NotImplementedError @@ -30,7 +30,7 @@ async def telemetry_publisher(self) -> None: await asyncio.sleep(1) @device_task - async def properties_publisher(self): + async def properties_publisher(self) -> None: while True: await self.send_properties({"virtual": True, "lua_api_ver": 1}) await asyncio.sleep(10) diff --git a/examples/mqtt/pub_sub.py b/examples/mqtt/pub_sub.py new file mode 100644 index 0000000..d6dbc5e --- /dev/null +++ b/examples/mqtt/pub_sub.py @@ -0,0 +1,28 @@ +import asyncio +import time + +import enapter + + +async def subscriber(client: enapter.mqtt.Client) -> None: + async with client.subscribe("/+") as messages: + async for msg in messages: + print(msg.topic, msg.payload.decode()) + + +async def publisher(client: enapter.mqtt.Client) -> None: + while True: + await client.publish(topic="/time", payload=str(time.time())) + await asyncio.sleep(1) + + +async def main() -> None: + config = enapter.mqtt.Config(host="127.0.0.1", port=1883) + async with enapter.mqtt.Client(config=config) as client: + async with asyncio.TaskGroup() as tg: + tg.create_task(subscriber(client)) + tg.create_task(publisher(client)) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/setup.py b/setup.py index e6a74f7..f7fde5d 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ def main(): author="Roman Novatorov", author_email="rnovatorov@enapter.com", install_requires=[ - "aiomqtt==1.0.*", + "aiomqtt==2.4.*", "dnspython==2.8.*", "json-log-formatter==1.1.*", ], diff --git a/tests/integration/test_mqtt.py b/tests/integration/test_mqtt.py index 7a554e0..ba1ea6c 100644 --- a/tests/integration/test_mqtt.py +++ b/tests/integration/test_mqtt.py @@ -92,6 +92,6 @@ async def _run(self): payload = str(int(time.time())) try: await self.enapter_mqtt_client.publish(self.topic, payload) - except aiomqtt.error.MqttError as e: + except aiomqtt.MqttError as e: print(f"failed to publish heartbit: {e}") await asyncio.sleep(self.interval)