diff --git a/Makefile b/Makefile index 30a72ee..6b96ebb 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ clean-dist: clean rm -rf dist/ lint: venv - $(VENV_ACTIVATE); python -m ruff check . + $(VENV_ACTIVATE); python -m ruff check . && python -m mypy format: venv $(VENV_ACTIVATE); python -m ruff format . && python -m ruff check . --fix diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..e4111fc --- /dev/null +++ b/mypy.ini @@ -0,0 +1,11 @@ +[mypy] +explicit_package_bases = true +files=plux/runtime/manager.py,tests/test_manager.py +ignore_missing_imports = False +follow_imports = silent +ignore_errors = False +disallow_untyped_defs = True +disallow_untyped_calls = True +disallow_any_generics = True +disallow_subclassing_any = True +warn_unused_ignores = True \ No newline at end of file diff --git a/plux/__init__.py b/plux/__init__.py index c665258..7809a4a 100644 --- a/plux/__init__.py +++ b/plux/__init__.py @@ -33,5 +33,5 @@ "PluginSpecResolver", "PluginType", "plugin", - "__version__" + "__version__", ] diff --git a/plux/core/plugin.py b/plux/core/plugin.py index 9affc21..5f6adec 100644 --- a/plux/core/plugin.py +++ b/plux/core/plugin.py @@ -27,7 +27,7 @@ class PluginDisabled(PluginException): reason: str - def __init__(self, namespace: str, name: str, reason: str = None): + def __init__(self, namespace: str, name: str, reason: str | None = None): message = f"plugin {namespace}:{name} is disabled" if reason: message = f"{message}, reason: {reason}" diff --git a/plux/runtime/manager.py b/plux/runtime/manager.py index 3b855ea..7939fb6 100644 --- a/plux/runtime/manager.py +++ b/plux/runtime/manager.py @@ -1,6 +1,8 @@ import logging import threading import typing as t +from collections.abc import Iterable +from importlib.metadata import EntryPoint from plux.core.plugin import ( Plugin, @@ -18,9 +20,10 @@ LOG = logging.getLogger(__name__) P = t.TypeVar("P", bound=Plugin) +PS = t.ParamSpec("PS") -def _call_safe(func: t.Callable, args: tuple, exception_message: str): +def _call_safe(func: t.Callable[PS, None], args: t.Any, exception_message: str) -> None: """ Call the given function with the given arguments, and if it fails, log the given exception_message. If logging.DEBUG is set for the logger, then we also log the traceback. An exception is made for any @@ -32,7 +35,7 @@ def _call_safe(func: t.Callable, args: tuple, exception_message: str): :return: whatever the func returns """ try: - return func(*args) + func(*args, **{}) except PluginException: # re-raise PluginExceptions, since they should be handled by the caller raise @@ -54,7 +57,7 @@ class PluginLifecycleNotifierMixin: listeners: list[PluginLifecycleListener] - def _fire_on_resolve_after(self, plugin_spec): + def _fire_on_resolve_after(self, plugin_spec: PluginSpec) -> None: for listener in self.listeners: _call_safe( listener.on_resolve_after, @@ -62,7 +65,9 @@ def _fire_on_resolve_after(self, plugin_spec): "error while calling on_resolve_after", ) - def _fire_on_resolve_exception(self, namespace, entrypoint, exception): + def _fire_on_resolve_exception( + self, namespace: str, entrypoint: EntryPoint, exception: Exception + ) -> None: for listener in self.listeners: _call_safe( listener.on_resolve_exception, @@ -70,7 +75,7 @@ def _fire_on_resolve_exception(self, namespace, entrypoint, exception): "error while calling on_resolve_exception", ) - def _fire_on_init_after(self, plugin_spec, plugin): + def _fire_on_init_after(self, plugin_spec: PluginSpec, plugin: P) -> None: for listener in self.listeners: _call_safe( listener.on_init_after, @@ -81,7 +86,7 @@ def _fire_on_init_after(self, plugin_spec, plugin): "error while calling on_init_after", ) - def _fire_on_init_exception(self, plugin_spec, exception): + def _fire_on_init_exception(self, plugin_spec: PluginSpec, exception: Exception) -> None: for listener in self.listeners: _call_safe( listener.on_init_exception, @@ -89,7 +94,9 @@ def _fire_on_init_exception(self, plugin_spec, exception): "error while calling on_init_exception", ) - def _fire_on_load_before(self, plugin_spec, plugin, load_args, load_kwargs): + def _fire_on_load_before( + self, plugin_spec: PluginSpec, plugin: P, load_args: t.Any, load_kwargs: t.Any + ) -> None: for listener in self.listeners: _call_safe( listener.on_load_before, @@ -97,7 +104,7 @@ def _fire_on_load_before(self, plugin_spec, plugin, load_args, load_kwargs): "error while calling on_load_before", ) - def _fire_on_load_after(self, plugin_spec, plugin, result): + def _fire_on_load_after(self, plugin_spec: PluginSpec, plugin: P | None, result: t.Any) -> None: for listener in self.listeners: _call_safe( listener.on_load_after, @@ -105,7 +112,9 @@ def _fire_on_load_after(self, plugin_spec, plugin, result): "error while calling on_load_after", ) - def _fire_on_load_exception(self, plugin_spec, plugin, exception): + def _fire_on_load_exception( + self, plugin_spec: PluginSpec, plugin: P | None, exception: Exception + ) -> None: for listener in self.listeners: _call_safe( listener.on_load_exception, @@ -123,20 +132,20 @@ class PluginContainer(t.Generic[P]): lock: threading.RLock plugin_spec: PluginSpec - plugin: P = None - load_value: t.Any | None = None + plugin: P | None = None + load_value: t.Any = None is_init: bool = False is_loaded: bool = False - init_error: Exception = None - load_error: Exception = None + init_error: Exception | None = None + load_error: Exception | None = None is_disabled: bool = False - disabled_reason = str = None + disabled_reason: str | None = None @property - def distribution(self) -> Distribution: + def distribution(self) -> Distribution | None: """ Uses metadata from importlib to resolve the distribution information for this plugin. @@ -160,7 +169,7 @@ class PluginManager(PluginLifecycleNotifierMixin, t.Generic[P]): namespace: str - load_args: list | tuple + load_args: list[t.Any] | tuple[t.Any, ...] load_kwargs: dict[str, t.Any] listeners: list[PluginLifecycleListener] filters: list[PluginFilter] @@ -168,11 +177,11 @@ class PluginManager(PluginLifecycleNotifierMixin, t.Generic[P]): def __init__( self, namespace: str, - load_args: list | tuple = None, - load_kwargs: dict = None, - listener: PluginLifecycleListener | t.Iterable[PluginLifecycleListener] = None, - finder: PluginFinder = None, - filters: list[PluginFilter] = None, + load_args: list[t.Any] | tuple[t.Any, ...] | None = None, + load_kwargs: dict[str, t.Any] | None = None, + listener: PluginLifecycleListener | t.Iterable[PluginLifecycleListener] | None = None, + finder: PluginFinder | None = None, + filters: list[PluginFilter] | None = None, ): """ Create a new ``PluginManager`` that can be used to load plugins. The simplest ``PluginManager`` only needs @@ -231,7 +240,7 @@ def on_load_before(self, plugin_spec: PluginSpec, plugin: Plugin, load_result: t self.load_kwargs = load_kwargs or dict() if listener: - if isinstance(listener, (list, set, tuple)): + if isinstance(listener, Iterable): self.listeners = list(listener) else: self.listeners = [listener] @@ -243,10 +252,10 @@ def on_load_before(self, plugin_spec: PluginSpec, plugin: Plugin, load_result: t self.finder = finder or MetadataPluginFinder(self.namespace, self._fire_on_resolve_exception) - self._plugin_index = None + self._plugin_index: dict[str, PluginContainer[P]] | None = None self._init_mutex = threading.RLock() - def add_listener(self, listener: PluginLifecycleListener): + def add_listener(self, listener: PluginLifecycleListener) -> None: """ Adds a lifecycle listener to the plugin manager. The listener will be notified of plugin lifecycle events. @@ -326,12 +335,12 @@ def load(self, name: str) -> P: if container.load_error: raise container.load_error - if not container.is_loaded: + if container.plugin is None or not container.is_loaded: raise PluginException("plugin did not load correctly", namespace=self.namespace, name=name) return container.plugin - def load_all(self, propagate_exceptions=False) -> list[P]: + def load_all(self, propagate_exceptions: bool = False) -> list[P]: """ Attempts to load all plugins found in the namespace and returns those that were loaded successfully. @@ -364,10 +373,10 @@ def load_all(self, propagate_exceptions=False) -> list[P]: :param propagate_exceptions: If True, re-raises any exceptions encountered during loading :return: A list of successfully loaded plugin instances """ - plugins = list() + plugins: list[P] = list() for name, container in self._plugins.items(): - if container.is_loaded: + if container.plugin is not None and container.is_loaded: plugins.append(container.plugin) continue @@ -552,7 +561,7 @@ def _require_plugin(self, name: str) -> PluginContainer[P]: return self._plugins[name] - def _load_plugin(self, container: PluginContainer) -> None: + def _load_plugin(self, container: PluginContainer[P]) -> None: """ Implements the core algorithm to load a plugin from a ``PluginSpec`` (contained in the ``PluginContainer``), and stores all relevant results, such as the Plugin instance, load result, or any errors into the passed @@ -602,6 +611,7 @@ def _load_plugin(self, container: PluginContainer) -> None: return plugin = container.plugin + assert plugin # Make MyPy happy - plugin should exist at this point if not plugin.should_load(): raise PluginDisabled( @@ -643,9 +653,9 @@ def _plugin_from_spec(self, plugin_spec: PluginSpec) -> P: if spec: factory = spec.factory - return factory() + return factory() # type: ignore[return-value] - def _init_plugin_index(self) -> dict[str, PluginContainer]: + def _init_plugin_index(self) -> dict[str, PluginContainer[P]]: """ Initializes the plugin index, which maps plugin names to plugin containers. This method will *resolve* plugins, meaning it loads the entry point object reference, thereby importing all its code. @@ -654,7 +664,7 @@ def _init_plugin_index(self) -> dict[str, PluginContainer]: """ return {plugin.name: plugin for plugin in self._import_plugins() if plugin} - def _import_plugins(self) -> t.Iterable[PluginContainer]: + def _import_plugins(self) -> t.Iterable[PluginContainer[P]]: """ Finds all ``PluginSpace`` instances in the namespace, creates a container for each spec, and yields them one by one. The plugin finder will typically load the entry point which involves importing the module it lives in. @@ -671,14 +681,14 @@ def _import_plugins(self) -> t.Iterable[PluginContainer]: yield self._create_container(spec) - def _create_container(self, plugin_spec: PluginSpec) -> PluginContainer: + def _create_container(self, plugin_spec: PluginSpec) -> PluginContainer[P]: """ Factory method to create a ``PluginContainer`` for the given ``PluginSpec``. :param plugin_spec: The ``PluginSpec`` to create a container for. :return: A new ``PluginContainer`` with the basic information of the plugin spec. """ - container = PluginContainer() + container = PluginContainer[P]() container.lock = threading.RLock() container.name = plugin_spec.name container.plugin_spec = plugin_spec diff --git a/pyproject.toml b/pyproject.toml index 42cab61..02aecff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dev = [ "setuptools", "pytest==8.4.1", "ruff==0.9.1", + "mypy", ] [tool.hatch.build.hooks.vcs] diff --git a/tests/test_manager.py b/tests/test_manager.py index 0ea1436..c00fe22 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,5 +1,5 @@ -from typing import Dict, List, Tuple -from unittest.mock import MagicMock +from typing import Any +from unittest.mock import Mock import pytest @@ -8,13 +8,13 @@ class DummyPlugin(Plugin): - load_calls: List[Tuple[Tuple, Dict]] + load_calls: list[tuple[tuple[Any], dict[str, Any]]] def __init__(self) -> None: super().__init__() self.load_calls = list() - def load(self, *args, **kwargs): + def load(self, *args: Any, **kwargs: Any) -> None: self.load_calls.append((args, kwargs)) @@ -28,7 +28,7 @@ class GoodPlugin(DummyPlugin): class ThrowsExceptionOnLoadPlugin(DummyPlugin): - def load(self, *args, **kwargs): + def load(self, *args: Any, **kwargs: Any) -> None: super().load(*args, **kwargs) raise ValueError("controlled load fail") @@ -40,15 +40,15 @@ def __init__(self) -> None: class DummyPluginFinder(PluginFinder): - def __init__(self, specs: List[PluginSpec]): + def __init__(self, specs: list[PluginSpec]): self.specs = specs - def find_plugins(self) -> List[PluginSpec]: + def find_plugins(self) -> list[PluginSpec]: return self.specs @pytest.fixture -def dummy_plugin_finder(): +def dummy_plugin_finder() -> DummyPluginFinder: return DummyPluginFinder( [ PluginSpec("test.plugins.dummy", "shouldload", GoodPlugin), @@ -62,8 +62,8 @@ def dummy_plugin_finder(): class TestPluginManager: - def test_load_all(self, dummy_plugin_finder): - manager = PluginManager("test.plugins.dummy", finder=dummy_plugin_finder) + def test_load_all(self, dummy_plugin_finder: DummyPluginFinder) -> None: + manager = PluginManager[DummyPlugin]("test.plugins.dummy", finder=dummy_plugin_finder) assert manager.is_loaded("shouldload") is False assert manager.is_loaded("shouldalsoload") is False @@ -84,8 +84,8 @@ def test_load_all(self, dummy_plugin_finder): assert type(plugins[0]) is GoodPlugin assert type(plugins[1]) is GoodPlugin - def test_list_names(self, dummy_plugin_finder): - manager = PluginManager("test.plugins.dummy", finder=dummy_plugin_finder) + def test_list_names(self, dummy_plugin_finder: DummyPluginFinder) -> None: + manager = PluginManager[DummyPlugin]("test.plugins.dummy", finder=dummy_plugin_finder) names = manager.list_names() assert len(names) == 5 @@ -95,8 +95,8 @@ def test_list_names(self, dummy_plugin_finder): assert "init_errors" in names assert "shouldalsoload" in names - def test_exists(self, dummy_plugin_finder): - manager = PluginManager("test.plugins.dummy", finder=dummy_plugin_finder) + def test_exists(self, dummy_plugin_finder: DummyPluginFinder) -> None: + manager = PluginManager[DummyPlugin]("test.plugins.dummy", finder=dummy_plugin_finder) assert manager.exists("shouldload") assert manager.exists("shouldnotload") @@ -105,7 +105,7 @@ def test_exists(self, dummy_plugin_finder): assert manager.exists("shouldalsoload") assert not manager.exists("foobar") - def test_load_all_load_is_only_called_once(self): + def test_load_all_load_is_only_called_once(self) -> None: finder = DummyPluginFinder( [ PluginSpec("test.plugins.dummy", "shouldload", GoodPlugin), @@ -113,7 +113,7 @@ def test_load_all_load_is_only_called_once(self): ] ) - manager: PluginManager[DummyPlugin] = PluginManager( + manager: PluginManager[DummyPlugin] = PluginManager[DummyPlugin]( "test.plugins.dummy", finder=finder, load_kwargs={"foo": "bar"} ) @@ -126,16 +126,16 @@ def test_load_all_load_is_only_called_once(self): assert len(plugins[0].load_calls) == 1 assert len(plugins[1].load_calls) == 1 - def test_load_on_non_existing_plugin(self): - manager = PluginManager("test.plugins.dummy", finder=DummyPluginFinder([])) + def test_load_on_non_existing_plugin(self) -> None: + manager = PluginManager[DummyPlugin]("test.plugins.dummy", finder=DummyPluginFinder([])) with pytest.raises(ValueError) as ex: manager.load("foo") ex.match("no plugin named foo in namespace test.plugins.dummy") - def test_load_all_container_has_errors(self, dummy_plugin_finder): - manager = PluginManager("test.plugins.dummy", finder=dummy_plugin_finder) + def test_load_all_container_has_errors(self, dummy_plugin_finder: DummyPluginFinder) -> None: + manager = PluginManager[DummyPlugin]("test.plugins.dummy", finder=dummy_plugin_finder) c_shouldload = manager.get_container("shouldload") c_shouldnotload = manager.get_container("shouldnotload") @@ -157,8 +157,8 @@ def test_load_all_container_has_errors(self, dummy_plugin_finder): assert type(c_load_errors.load_error) is ValueError assert c_shouldalsoload.load_error is None - def test_load_all_propagate_exception(self): - manager = PluginManager( + def test_load_all_propagate_exception(self) -> None: + manager = PluginManager[DummyPlugin]( "test.plugins.dummy", finder=DummyPluginFinder( [ @@ -172,8 +172,8 @@ def test_load_all_propagate_exception(self): ex.match("controlled load fail") - def test_load_disabled_plugin(self, dummy_plugin_finder): - manager = PluginManager("test.plugins.dummy", finder=dummy_plugin_finder) + def test_load_disabled_plugin(self, dummy_plugin_finder: DummyPluginFinder) -> None: + manager = PluginManager[DummyPlugin]("test.plugins.dummy", finder=dummy_plugin_finder) with pytest.raises(PluginDisabled) as ex: manager.load("shouldnotload") @@ -181,10 +181,12 @@ def test_load_disabled_plugin(self, dummy_plugin_finder): assert ex.value.namespace == "test.plugins.dummy" assert ex.value.name == "shouldnotload" - def test_lifecycle_listener(self, dummy_plugin_finder): - listener = MagicMock() + def test_lifecycle_listener(self, dummy_plugin_finder: DummyPluginFinder) -> None: + listener = Mock() - manager = PluginManager("test.plugins.dummy", finder=dummy_plugin_finder, listener=listener) + manager = PluginManager[DummyPlugin]( + "test.plugins.dummy", finder=dummy_plugin_finder, listener=listener + ) manager.load_all() assert listener.on_init_after.call_count == 4 @@ -196,8 +198,8 @@ def test_lifecycle_listener(self, dummy_plugin_finder): class TestGlobalPluginFilter: - def test_disable_namespace(self, dummy_plugin_finder): - manager = PluginManager("test.plugins.dummy", finder=dummy_plugin_finder) + def test_disable_namespace(self, dummy_plugin_finder: DummyPluginFinder) -> None: + manager = PluginManager[DummyPlugin]("test.plugins.dummy", finder=dummy_plugin_finder) global_plugin_filter.add_exclusion(namespace="test.plugins.*") @@ -207,8 +209,8 @@ def test_disable_namespace(self, dummy_plugin_finder): global_plugin_filter.exclusions.clear() - def test_non_matching_namespace(self, dummy_plugin_finder): - manager = PluginManager("test.plugins.dummy", finder=dummy_plugin_finder) + def test_non_matching_namespace(self, dummy_plugin_finder: DummyPluginFinder) -> None: + manager = PluginManager[DummyPlugin]("test.plugins.dummy", finder=dummy_plugin_finder) global_plugin_filter.add_exclusion(namespace="test.plugins.dummy.*") @@ -217,8 +219,8 @@ def test_non_matching_namespace(self, dummy_plugin_finder): assert manager.is_loaded("shouldalsoload") is True global_plugin_filter.exclusions.clear() - def test_disable_name(self, dummy_plugin_finder): - manager = PluginManager("test.plugins.dummy", finder=dummy_plugin_finder) + def test_disable_name(self, dummy_plugin_finder: DummyPluginFinder) -> None: + manager = PluginManager[DummyPlugin]("test.plugins.dummy", finder=dummy_plugin_finder) global_plugin_filter.add_exclusion(name="*also*") @@ -227,8 +229,8 @@ def test_disable_name(self, dummy_plugin_finder): assert manager.is_loaded("shouldalsoload") is False global_plugin_filter.exclusions.clear() - def test_disable_value(self, dummy_plugin_finder): - manager = PluginManager("test.plugins.dummy", finder=dummy_plugin_finder) + def test_disable_value(self, dummy_plugin_finder: DummyPluginFinder) -> None: + manager = PluginManager[DummyPlugin]("test.plugins.dummy", finder=dummy_plugin_finder) global_plugin_filter.add_exclusion(value="tests.test_manager:*")