Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ clean-dist: clean
rm -rf dist/

lint: venv
$(VENV_ACTIVATE); python -m ruff check .
$(VENV_ACTIVATE); python -m ruff check . && python -m mypy

format: venv
$(VENV_ACTIVATE); python -m ruff format . && python -m ruff check . --fix
Expand Down
11 changes: 11 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[mypy]
explicit_package_bases = true
files=plux/runtime/manager.py,tests/test_manager.py
ignore_missing_imports = False
follow_imports = silent
ignore_errors = False
disallow_untyped_defs = True
disallow_untyped_calls = True
disallow_any_generics = True
disallow_subclassing_any = True
warn_unused_ignores = True
2 changes: 1 addition & 1 deletion plux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@
"PluginSpecResolver",
"PluginType",
"plugin",
"__version__"
"__version__",
]
2 changes: 1 addition & 1 deletion plux/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class PluginDisabled(PluginException):

reason: str

def __init__(self, namespace: str, name: str, reason: str = None):
def __init__(self, namespace: str, name: str, reason: str | None = None):
message = f"plugin {namespace}:{name} is disabled"
if reason:
message = f"{message}, reason: {reason}"
Expand Down
78 changes: 44 additions & 34 deletions plux/runtime/manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import threading
import typing as t
from collections.abc import Iterable
from importlib.metadata import EntryPoint

from plux.core.plugin import (
Plugin,
Expand All @@ -18,9 +20,10 @@
LOG = logging.getLogger(__name__)

P = t.TypeVar("P", bound=Plugin)
PS = t.ParamSpec("PS")


def _call_safe(func: t.Callable, args: tuple, exception_message: str):
def _call_safe(func: t.Callable[PS, None], args: t.Any, exception_message: str) -> None:
"""
Call the given function with the given arguments, and if it fails, log the given exception_message. If
logging.DEBUG is set for the logger, then we also log the traceback. An exception is made for any
Expand All @@ -32,7 +35,7 @@ def _call_safe(func: t.Callable, args: tuple, exception_message: str):
:return: whatever the func returns
"""
try:
return func(*args)
func(*args, **{})
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None of the invocations of _call_safe do anything with the result, so I think this is a safe refactor/improvement

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, could we specify the Callable as [PS, None]? Or - to keep the return, use -> T | None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that makes sense - I changed it to [PS, None]. If we need the result in the future, we can always change it to -> T | None, considering it's a private method.

except PluginException:
# re-raise PluginExceptions, since they should be handled by the caller
raise
Expand All @@ -54,23 +57,25 @@ class PluginLifecycleNotifierMixin:

listeners: list[PluginLifecycleListener]

def _fire_on_resolve_after(self, plugin_spec):
def _fire_on_resolve_after(self, plugin_spec: PluginSpec) -> None:
for listener in self.listeners:
_call_safe(
listener.on_resolve_after,
(plugin_spec,), #
"error while calling on_resolve_after",
)

def _fire_on_resolve_exception(self, namespace, entrypoint, exception):
def _fire_on_resolve_exception(
self, namespace: str, entrypoint: EntryPoint, exception: Exception
) -> None:
for listener in self.listeners:
_call_safe(
listener.on_resolve_exception,
(namespace, entrypoint, exception),
"error while calling on_resolve_exception",
)

def _fire_on_init_after(self, plugin_spec, plugin):
def _fire_on_init_after(self, plugin_spec: PluginSpec, plugin: P) -> None:
for listener in self.listeners:
_call_safe(
listener.on_init_after,
Expand All @@ -81,31 +86,35 @@ def _fire_on_init_after(self, plugin_spec, plugin):
"error while calling on_init_after",
)

def _fire_on_init_exception(self, plugin_spec, exception):
def _fire_on_init_exception(self, plugin_spec: PluginSpec, exception: Exception) -> None:
for listener in self.listeners:
_call_safe(
listener.on_init_exception,
(plugin_spec, exception),
"error while calling on_init_exception",
)

def _fire_on_load_before(self, plugin_spec, plugin, load_args, load_kwargs):
def _fire_on_load_before(
self, plugin_spec: PluginSpec, plugin: P, load_args: t.Any, load_kwargs: t.Any
) -> None:
for listener in self.listeners:
_call_safe(
listener.on_load_before,
(plugin_spec, plugin, load_args, load_kwargs),
"error while calling on_load_before",
)

def _fire_on_load_after(self, plugin_spec, plugin, result):
def _fire_on_load_after(self, plugin_spec: PluginSpec, plugin: P | None, result: t.Any) -> None:
for listener in self.listeners:
_call_safe(
listener.on_load_after,
(plugin_spec, plugin, result),
"error while calling on_load_after",
)

def _fire_on_load_exception(self, plugin_spec, plugin, exception):
def _fire_on_load_exception(
self, plugin_spec: PluginSpec, plugin: P | None, exception: Exception
) -> None:
for listener in self.listeners:
_call_safe(
listener.on_load_exception,
Expand All @@ -123,20 +132,20 @@ class PluginContainer(t.Generic[P]):
lock: threading.RLock

plugin_spec: PluginSpec
plugin: P = None
load_value: t.Any | None = None
plugin: P | None = None
load_value: t.Any = None

is_init: bool = False
is_loaded: bool = False

init_error: Exception = None
load_error: Exception = None
init_error: Exception | None = None
load_error: Exception | None = None

is_disabled: bool = False
disabled_reason = str = None
disabled_reason: str | None = None

@property
def distribution(self) -> Distribution:
def distribution(self) -> Distribution | None:
"""
Uses metadata from importlib to resolve the distribution information for this plugin.

Expand All @@ -160,19 +169,19 @@ class PluginManager(PluginLifecycleNotifierMixin, t.Generic[P]):

namespace: str

load_args: list | tuple
load_args: list[t.Any] | tuple[t.Any, ...]
load_kwargs: dict[str, t.Any]
listeners: list[PluginLifecycleListener]
filters: list[PluginFilter]

def __init__(
self,
namespace: str,
load_args: list | tuple = None,
load_kwargs: dict = None,
listener: PluginLifecycleListener | t.Iterable[PluginLifecycleListener] = None,
finder: PluginFinder = None,
filters: list[PluginFilter] = None,
load_args: list[t.Any] | tuple[t.Any, ...] | None = None,
load_kwargs: dict[str, t.Any] | None = None,
listener: PluginLifecycleListener | t.Iterable[PluginLifecycleListener] | None = None,
finder: PluginFinder | None = None,
filters: list[PluginFilter] | None = None,
):
"""
Create a new ``PluginManager`` that can be used to load plugins. The simplest ``PluginManager`` only needs
Expand Down Expand Up @@ -231,7 +240,7 @@ def on_load_before(self, plugin_spec: PluginSpec, plugin: Plugin, load_result: t
self.load_kwargs = load_kwargs or dict()

if listener:
if isinstance(listener, (list, set, tuple)):
if isinstance(listener, Iterable):
self.listeners = list(listener)
else:
self.listeners = [listener]
Expand All @@ -243,10 +252,10 @@ def on_load_before(self, plugin_spec: PluginSpec, plugin: Plugin, load_result: t

self.finder = finder or MetadataPluginFinder(self.namespace, self._fire_on_resolve_exception)

self._plugin_index = None
self._plugin_index: dict[str, PluginContainer[P]] | None = None
self._init_mutex = threading.RLock()

def add_listener(self, listener: PluginLifecycleListener):
def add_listener(self, listener: PluginLifecycleListener) -> None:
"""
Adds a lifecycle listener to the plugin manager. The listener will be notified of plugin lifecycle events.

Expand Down Expand Up @@ -326,12 +335,12 @@ def load(self, name: str) -> P:
if container.load_error:
raise container.load_error

if not container.is_loaded:
if container.plugin is None or not container.is_loaded:
raise PluginException("plugin did not load correctly", namespace=self.namespace, name=name)

return container.plugin

def load_all(self, propagate_exceptions=False) -> list[P]:
def load_all(self, propagate_exceptions: bool = False) -> list[P]:
"""
Attempts to load all plugins found in the namespace and returns those that were loaded successfully.

Expand Down Expand Up @@ -364,10 +373,10 @@ def load_all(self, propagate_exceptions=False) -> list[P]:
:param propagate_exceptions: If True, re-raises any exceptions encountered during loading
:return: A list of successfully loaded plugin instances
"""
plugins = list()
plugins: list[P] = list()

for name, container in self._plugins.items():
if container.is_loaded:
if container.plugin is not None and container.is_loaded:
plugins.append(container.plugin)
continue

Expand Down Expand Up @@ -552,7 +561,7 @@ def _require_plugin(self, name: str) -> PluginContainer[P]:

return self._plugins[name]

def _load_plugin(self, container: PluginContainer) -> None:
def _load_plugin(self, container: PluginContainer[P]) -> None:
"""
Implements the core algorithm to load a plugin from a ``PluginSpec`` (contained in the ``PluginContainer``),
and stores all relevant results, such as the Plugin instance, load result, or any errors into the passed
Expand Down Expand Up @@ -602,6 +611,7 @@ def _load_plugin(self, container: PluginContainer) -> None:
return

plugin = container.plugin
assert plugin # Make MyPy happy - plugin should exist at this point

if not plugin.should_load():
raise PluginDisabled(
Expand Down Expand Up @@ -643,9 +653,9 @@ def _plugin_from_spec(self, plugin_spec: PluginSpec) -> P:
if spec:
factory = spec.factory

return factory()
return factory() # type: ignore[return-value]

def _init_plugin_index(self) -> dict[str, PluginContainer]:
def _init_plugin_index(self) -> dict[str, PluginContainer[P]]:
"""
Initializes the plugin index, which maps plugin names to plugin containers. This method will *resolve* plugins,
meaning it loads the entry point object reference, thereby importing all its code.
Expand All @@ -654,7 +664,7 @@ def _init_plugin_index(self) -> dict[str, PluginContainer]:
"""
return {plugin.name: plugin for plugin in self._import_plugins() if plugin}

def _import_plugins(self) -> t.Iterable[PluginContainer]:
def _import_plugins(self) -> t.Iterable[PluginContainer[P]]:
"""
Finds all ``PluginSpace`` instances in the namespace, creates a container for each spec, and yields them one
by one. The plugin finder will typically load the entry point which involves importing the module it lives in.
Expand All @@ -671,14 +681,14 @@ def _import_plugins(self) -> t.Iterable[PluginContainer]:

yield self._create_container(spec)

def _create_container(self, plugin_spec: PluginSpec) -> PluginContainer:
def _create_container(self, plugin_spec: PluginSpec) -> PluginContainer[P]:
"""
Factory method to create a ``PluginContainer`` for the given ``PluginSpec``.

:param plugin_spec: The ``PluginSpec`` to create a container for.
:return: A new ``PluginContainer`` with the basic information of the plugin spec.
"""
container = PluginContainer()
container = PluginContainer[P]()
container.lock = threading.RLock()
container.name = plugin_spec.name
container.plugin_spec = plugin_spec
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dev = [
"setuptools",
"pytest==8.4.1",
"ruff==0.9.1",
"mypy",
]

[tool.hatch.build.hooks.vcs]
Expand Down
Loading