Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ name: CI

on:
push:
pull_request:

jobs:
Ubuntu:
Expand Down
15 changes: 14 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,24 @@ update-deps:
check: lint test

.PHONY: lint
lint:
lint: lint-black lint-isort lint-pyflakes lint-mypy

.PHONY: lint-black
lint-black:
pipenv run black --check .

.PHONY: lint-isort
lint-isort:
pipenv run isort --check .

.PHONY: lint-pyflakes
lint-pyflakes:
pipenv run pyflakes .

.PHONY: lint-mypy
lint-mypy:
pipenv run mypy enapter

.PHONY: test
test: run-unit-tests run-integration-tests

Expand Down
1 change: 1 addition & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ black = "*"
docker = "*"
faker = "*"
isort = "*"
mypy = "*"
pyflakes = "*"
pytest = "*"
pytest-asyncio = "*"
Expand Down
2 changes: 1 addition & 1 deletion enapter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.11.1"
__version__ = "0.11.3"

from . import async_, log, mdns, mqtt, types, vucm

Expand Down
7 changes: 5 additions & 2 deletions enapter/async_/generator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import contextlib
import functools
from typing import AsyncContextManager, AsyncGenerator, Callable


def generator(func):
def generator(
func: Callable[..., AsyncGenerator],
) -> Callable[..., AsyncContextManager[AsyncGenerator]]:
@functools.wraps(func)
@contextlib.asynccontextmanager
async def wrapper(*args, **kwargs):
async def wrapper(*args, **kwargs) -> AsyncGenerator[AsyncGenerator, None]:
gen = func(*args, **kwargs)
try:
yield gen
Expand Down
17 changes: 9 additions & 8 deletions enapter/async_/routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@

class Routine(abc.ABC):
@abc.abstractmethod
async def _run(self):
async def _run(self) -> None:
raise NotImplementedError # pragma: no cover

async def __aenter__(self):
await self.start()
return self

async def __aexit__(self, *_):
async def __aexit__(self, *_) -> None:
await self.stop()

def task(self):
def task(self) -> asyncio.Task:
return self._task

async def start(self, cancel_parent_task_on_exception=True):
async def start(self, cancel_parent_task_on_exception: bool = True) -> None:
self._started = asyncio.Event()
self._stack = contextlib.AsyncExitStack()

Expand All @@ -43,26 +43,27 @@ async def start(self, cancel_parent_task_on_exception=True):
if self._task in done:
self._task.result()

async def stop(self):
async def stop(self) -> None:
self.cancel()
await self.join()

def cancel(self):
def cancel(self) -> None:
self._task.cancel()

async def join(self):
async def join(self) -> None:
if self._task.done():
self._task.result()
else:
await self._task

async def __run(self):
async def __run(self) -> None:
try:
await self._run()
except asyncio.CancelledError:
pass
except:
if self._started.is_set() and self._cancel_parent_task_on_exception:
assert self._parent_task is not None
self._parent_task.cancel()
raise
finally:
Expand Down
13 changes: 10 additions & 3 deletions enapter/log/json_formatter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import datetime
import logging
from typing import Any, Dict

import json_log_formatter
import json_log_formatter # type: ignore


class JSONFormatter(json_log_formatter.JSONFormatter):
def json_record(self, message, extra, record):
def json_record(
self,
message: str,
extra: Dict[str, Any],
record: logging.LogRecord,
) -> Dict[str, Any]:
try:
del extra["taskName"]
except KeyError:
Expand All @@ -27,5 +34,5 @@ def json_record(self, message, extra, record):

return json_record

def mutate_json_record(self, json_record):
def mutate_json_record(self, json_record: Dict[str, Any]) -> Dict[str, Any]:
return json_record
12 changes: 6 additions & 6 deletions enapter/mdns/resolver.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import logging

import dns.asyncresolver
import dns.asyncresolver # type: ignore

LOGGER = logging.getLogger(__name__)


class Resolver:
def __init__(self):
def __init__(self) -> None:
self._logger = LOGGER
self._dns_resolver = self._new_dns_resolver()
self._mdns_resolver = self._new_mdns_resolver()

async def resolve(self, host):
async def resolve(self, host: str) -> str:
# TODO: Resolve concurrently.
try:
ip = await self._resolve(self._dns_resolver, host)
Expand All @@ -25,17 +25,17 @@ async def resolve(self, host):
self._logger.info("%r resolved using mDNS: %r", host, ip)
return ip

async def _resolve(self, resolver, host):
async def _resolve(self, resolver: dns.asyncresolver.Resolver, host: str) -> str:
answer = await resolver.resolve(host, "A")
if not answer:
raise ValueError(f"empty answer received: {host}")

return answer[0].address

def _new_dns_resolver(self):
def _new_dns_resolver(self) -> dns.asyncresolver.Resolver:
return dns.asyncresolver.Resolver(configure=True)

def _new_mdns_resolver(self):
def _new_mdns_resolver(self) -> dns.asyncresolver.Resolver:
r = dns.asyncresolver.Resolver(configure=False)
r.nameservers = ["224.0.0.251"]
r.port = 5353
Expand Down
5 changes: 3 additions & 2 deletions enapter/mqtt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from . import api
from .client import Client
from .config import Config
from .config import Config, TLSConfig

__all__ = [
"api",
"Client",
"Config",
"TLSConfig",
"api",
]
18 changes: 12 additions & 6 deletions enapter/mqtt/api/command.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import enum
import json
from typing import Any, Dict, Optional, Union


class CommandState(enum.Enum):
Expand All @@ -9,24 +10,29 @@ class CommandState(enum.Enum):

class CommandRequest:
@classmethod
def unmarshal_json(cls, data):
def unmarshal_json(cls, data: Union[str, bytes]) -> "CommandRequest":
req = json.loads(data)
return cls(id_=req["id"], name=req["name"], args=req.get("arguments"))

def __init__(self, id_, name, args=None):
def __init__(self, id_: str, name: str, args: Optional[Dict[str, Any]] = None):
self.id = id_
self.name = name

if args is None:
args = {}
self.args = args

def new_response(self, *args, **kwargs):
def new_response(self, *args, **kwargs) -> "CommandResponse":
return CommandResponse(self.id, *args, **kwargs)


class CommandResponse:
def __init__(self, id_, state, payload=None):
def __init__(
self,
id_: str,
state: Union[str, CommandState],
payload: Optional[Union[Dict[str, Any], str]] = None,
) -> None:
self.id = id_

if not isinstance(state, CommandState):
Expand All @@ -37,8 +43,8 @@ def __init__(self, id_, state, payload=None):
payload = {"message": payload}
self.payload = payload

def json(self):
json_object = {"id": self.id, "state": self.state.value}
def json(self) -> Dict[str, Any]:
json_object: Dict[str, Any] = {"id": self.id, "state": self.state.value}
if self.payload is not None:
json_object["payload"] = self.payload

Expand Down
46 changes: 28 additions & 18 deletions enapter/mqtt/api/device_channel.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,82 @@
import json
import logging
import time
from typing import Any, AsyncContextManager, AsyncGenerator, Dict

import aiomqtt # type: ignore

import enapter

from .command import CommandRequest
from ..client import Client
from .command import CommandRequest, CommandResponse
from .log_severity import LogSeverity

LOGGER = logging.getLogger(__name__)


class DeviceChannel:
def __init__(self, client, hardware_id, channel_id):
def __init__(self, client: Client, hardware_id: str, channel_id: str) -> None:
self._client = client
self._logger = self._new_logger(hardware_id, channel_id)
self._hardware_id = hardware_id
self._channel_id = channel_id

@property
def hardware_id(self):
def hardware_id(self) -> str:
return self._hardware_id

@property
def channel_id(self):
def channel_id(self) -> str:
return self._channel_id

@staticmethod
def _new_logger(hardware_id, channel_id):
def _new_logger(hardware_id, channel_id) -> logging.LoggerAdapter:
extra = {"hardware_id": hardware_id, "channel_id": channel_id}
return logging.LoggerAdapter(LOGGER, extra=extra)

@enapter.async_.generator
async def subscribe_to_command_requests(self):
async def subscribe_to_command_requests(
self,
) -> AsyncGenerator[CommandRequest, None]:
async with self._subscribe("v1/command/requests") as messages:
async for msg in messages:
assert isinstance(msg.payload, str) or isinstance(msg.payload, bytes)
yield CommandRequest.unmarshal_json(msg.payload)

async def publish_command_response(self, resp):
async def publish_command_response(self, resp: CommandResponse) -> None:
await self._publish_json("v1/command/responses", resp.json())

async def publish_telemetry(self, telemetry, **kwargs):
async def publish_telemetry(self, telemetry: Dict[str, Any], **kwargs) -> None:
await self._publish_json("v1/telemetry", telemetry, **kwargs)

async def publish_properties(self, properties, **kwargs):
async def publish_properties(self, properties: Dict[str, Any], **kwargs) -> None:
await self._publish_json("v1/register", properties, **kwargs)

async def publish_logs(self, msg, severity, persist=False, **kwargs):
async def publish_logs(
self, msg: str, severity: LogSeverity, persist: bool = False, **kwargs
) -> None:
logs = {
"message": msg,
"severity": severity.value,
"persist": persist,
}
if persist:
logs["persist"] = True

await self._publish_json("v3/logs", logs, **kwargs)

def _subscribe(self, path):
def _subscribe(
self, path: str
) -> AsyncContextManager[AsyncGenerator[aiomqtt.Message, None]]:
topic = f"v1/to/{self._hardware_id}/{self._channel_id}/{path}"
return self._client.subscribe(topic)

async def _publish_json(self, path, json_object, **kwargs):
async def _publish_json(
self, path: str, json_object: Dict[str, Any], **kwargs
) -> None:
if "timestamp" not in json_object:
json_object["timestamp"] = int(time.time())

payload = json.dumps(json_object)

await self._publish(path, payload, **kwargs)

async def _publish(self, path, payload, **kwargs):
async def _publish(self, path: str, payload: str, **kwargs) -> None:
topic = f"v1/from/{self._hardware_id}/{self._channel_id}/{path}"
try:
await self._client.publish(topic, payload, **kwargs)
Expand Down
Loading