-
-
Notifications
You must be signed in to change notification settings - Fork 213
[ENH] Refactor Extension
#1590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[ENH] Refactor Extension
#1590
Changes from all commits
2c0c1aa
2aab335
d368eab
6d3e0e9
1365bf6
67c0efb
e5850ef
00da7a9
7d33463
373fa53
e86fab7
e92156a
1945c58
5a1ccd6
c7e52e1
12df955
9e5e752
bf9a0aa
e3ca07d
a2aa2d0
bc541bb
34448aa
7bd15e5
3afc0d9
406a96f
5030914
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # License: BSD 3-Clause | ||
|
|
||
| """Base classes for OpenML extensions.""" | ||
|
|
||
| from openml.extensions.base._connector import OpenMLAPIConnector | ||
| from openml.extensions.base._executor import ModelExecutor | ||
| from openml.extensions.base._serializer import ModelSerializer | ||
|
|
||
| __all__ = [ | ||
| "ModelExecutor", | ||
| "ModelSerializer", | ||
| "OpenMLAPIConnector", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| # License: BSD 3-Clause | ||
|
|
||
| """Base class for OpenML API connectors.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| if TYPE_CHECKING: | ||
| from openml.extensions.base import ModelExecutor, ModelSerializer | ||
|
|
||
|
|
||
| class OpenMLAPIConnector(ABC): | ||
| """Base class for OpenML API connectors.""" | ||
|
|
||
| @abstractmethod | ||
| def serializer(self) -> ModelSerializer: | ||
| """Return the serializer for this API.""" | ||
|
|
||
| @abstractmethod | ||
| def executor(self) -> ModelExecutor: | ||
| """Return the executor for this API.""" | ||
|
|
||
| @classmethod | ||
| @abstractmethod | ||
| def supports(cls, estimator: Any) -> bool: | ||
| """High-level check if this connector supports the estimator instance or flow.""" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| # License: BSD 3-Clause | ||
|
|
||
| """Base class for estimator executors.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from collections import OrderedDict | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| if TYPE_CHECKING: | ||
| import numpy as np | ||
| import scipy.sparse | ||
|
|
||
| from openml.runs.trace import OpenMLRunTrace, OpenMLTraceIteration | ||
| from openml.tasks.task import OpenMLTask | ||
|
|
||
|
|
||
| class ModelExecutor(ABC): | ||
| """Define runtime execution semantics for a specific API type.""" | ||
|
|
||
| @classmethod | ||
| @abstractmethod | ||
| def can_handle_model(cls, model: Any) -> bool: | ||
| """Check whether a model can be handled by this extension. | ||
|
|
||
| This is typically done by checking the type of the model, or the package it belongs to. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model : Any | ||
|
|
||
| Returns | ||
| ------- | ||
| bool | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def seed_model(self, model: Any, seed: int | None) -> Any: | ||
| """Set the seed of all the unseeded components of a model and return the seeded model. | ||
|
|
||
| Required so that all seed information can be uploaded to OpenML for reproducible results. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model : Any | ||
| The model to be seeded | ||
| seed : int or None | ||
|
|
||
| Returns | ||
| ------- | ||
| model | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def _run_model_on_fold( # noqa: PLR0913 | ||
| self, | ||
| model: Any, | ||
| task: OpenMLTask, | ||
| X_train: np.ndarray | scipy.sparse.spmatrix, | ||
| rep_no: int, | ||
| fold_no: int, | ||
| y_train: np.ndarray | None = None, | ||
| X_test: np.ndarray | scipy.sparse.spmatrix | None = None, | ||
| ) -> tuple[np.ndarray, np.ndarray | None, OrderedDict[str, float], OpenMLRunTrace | None]: | ||
| """Run a model on a repeat, fold, subsample triplet of the task. | ||
|
|
||
| Returns the data that is necessary to construct the OpenML Run object. Is used by | ||
| :func:`openml.runs.run_flow_on_task`. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model : Any | ||
| The UNTRAINED model to run. The model instance will be copied and not altered. | ||
| task : OpenMLTask | ||
| The task to run the model on. | ||
| X_train : array-like | ||
| Training data for the given repetition and fold. | ||
| rep_no : int | ||
| The repeat of the experiment (0-based; in case of 1 time CV, always 0) | ||
| fold_no : int | ||
| The fold nr of the experiment (0-based; in case of holdout, always 0) | ||
| y_train : Optional[np.ndarray] (default=None) | ||
| Target attributes for supervised tasks. In case of classification, these are integer | ||
| indices to the potential classes specified by dataset. | ||
| X_test : Optional, array-like (default=None) | ||
| Test attributes to test for generalization in supervised tasks. | ||
|
|
||
| Returns | ||
| ------- | ||
| predictions : np.ndarray | ||
| Model predictions. | ||
| probabilities : Optional, np.ndarray | ||
| Predicted probabilities (only applicable for supervised classification tasks). | ||
| user_defined_measures : OrderedDict[str, float] | ||
| User defined measures that were generated on this fold | ||
| trace : Optional, OpenMLRunTrace | ||
| Hyperparameter optimization trace (only applicable for supervised tasks with | ||
| hyperparameter optimization). | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def check_if_model_fitted(self, model: Any) -> bool: | ||
| """Returns True/False denoting if the model has already been fitted/trained. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model : Any | ||
|
|
||
| Returns | ||
| ------- | ||
| bool | ||
| """ | ||
|
|
||
| # Abstract methods for hyperparameter optimization | ||
|
|
||
| @abstractmethod | ||
| def instantiate_model_from_hpo_class( | ||
| self, | ||
| model: Any, | ||
| trace_iteration: OpenMLTraceIteration, | ||
| ) -> Any: | ||
| """Instantiate a base model which can be searched over by the hyperparameter optimization | ||
| model. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model : Any | ||
| A hyperparameter optimization model which defines the model to be instantiated. | ||
| trace_iteration : OpenMLTraceIteration | ||
| Describing the hyperparameter settings to instantiate. | ||
|
|
||
| Returns | ||
| ------- | ||
| Any | ||
| """ | ||
| # TODO a trace belongs to a run and therefore a flow -> simplify this part of the interface! | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| # License: BSD 3-Clause | ||
|
|
||
| """Base class for estimator serializers.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| if TYPE_CHECKING: | ||
| from openml.flows import OpenMLFlow | ||
|
|
||
|
|
||
| class ModelSerializer(ABC): | ||
| """Handle the conversion between estimator instances and OpenML Flows.""" | ||
|
|
||
| @classmethod | ||
| @abstractmethod | ||
| def can_handle_model(cls, model: Any) -> bool: | ||
| """Check whether a model can be handled by this extension. | ||
|
|
||
| This is typically done by checking the type of the model, or the package it belongs to. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model : Any | ||
|
|
||
| Returns | ||
| ------- | ||
| bool | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def model_to_flow(self, model: Any) -> OpenMLFlow: | ||
| """Transform a model to a flow for uploading it to OpenML. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model : Any | ||
|
|
||
| Returns | ||
| ------- | ||
| OpenMLFlow | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def flow_to_model( | ||
| self, | ||
| flow: OpenMLFlow, | ||
| initialize_with_defaults: bool = False, # noqa: FBT002 | ||
| strict_version: bool = True, # noqa: FBT002 | ||
| ) -> Any: | ||
| """Instantiate a model from the flow representation. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| flow : OpenMLFlow | ||
|
|
||
| initialize_with_defaults : bool, optional (default=False) | ||
| If this flag is set, the hyperparameter values stored in the flow are | ||
| ignored and a model with its default hyperparameters is returned. | ||
|
|
||
| strict_version : bool, default=True | ||
| Whether to fail if version requirements are not fulfilled. | ||
|
|
||
| Returns | ||
| ------- | ||
| Any | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def get_version_information(self) -> list[str]: | ||
| """Return dependency and version information.""" | ||
|
|
||
| @abstractmethod | ||
| def obtain_parameter_values( | ||
| self, | ||
| flow: OpenMLFlow, | ||
| model: Any = None, | ||
| ) -> list[dict[str, Any]]: | ||
| """Extracts all parameter settings required for the flow from the model. | ||
|
|
||
| If no explicit model is provided, the parameters will be extracted from `flow.model` | ||
| instead. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| flow : OpenMLFlow | ||
| OpenMLFlow object (containing flow ids, i.e., it has to be downloaded from the server) | ||
|
|
||
| model: Any, optional (default=None) | ||
| The model from which to obtain the parameter values. Must match the flow signature. | ||
| If None, use the model specified in ``OpenMLFlow.model``. | ||
|
|
||
| Returns | ||
| ------- | ||
| list | ||
| A list of dicts, where each dict has the following entries: | ||
| - ``oml:name`` : str: The OpenML parameter name | ||
| - ``oml:value`` : mixed: A representation of the parameter value | ||
| - ``oml:component`` : int: flow id to which the parameter belongs | ||
| """ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| # License: BSD 3-Clause | ||
|
|
||
| """Extension registry.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from openml.exceptions import PyOpenMLError | ||
|
|
||
| if TYPE_CHECKING: | ||
| from openml.extensions.base import OpenMLAPIConnector | ||
|
|
||
| API_CONNECTOR_REGISTRY: list[type[OpenMLAPIConnector]] = [ | ||
| # Add OpenMLAPIConnector subclasses here to register them | ||
| ] | ||
|
Comment on lines
+14
to
+16
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the expected access pattern?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
No. There only needs to be a list of connectors. The If you look at the def flow_to_estimator(flow: OpenMLFlow, initialize_with_defaults: bool = False, strict_version: bool = True) -> Any:
connector = resolve_api_connector(flow)
return connector.serializer().flow_to_model(
flow,
initialize_with_defaults=initialize_with_defaults,
strict_version=strict_version,
) |
||
|
|
||
|
|
||
| def resolve_api_connector(estimator: Any) -> OpenMLAPIConnector: | ||
| """ | ||
| Identify and return the appropriate OpenML API connector for a given estimator. | ||
|
|
||
| This function iterates through the global ``API_CONNECTOR_REGISTRY`` to find | ||
| a connector class that supports the provided estimator instance or OpenML flow. | ||
| If a matching connector is found, it is instantiated and returned. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| estimator : Any | ||
| The estimator instance (e.g., a scikit-learn estimator) or OpenML flow for | ||
| which an API connector is required. | ||
|
|
||
| Returns | ||
| ------- | ||
| OpenMLAPIConnector | ||
| An instance of the matching API connector. | ||
|
|
||
| Raises | ||
| ------ | ||
| PyOpenMLError | ||
| If no connector in the registry supports the provided estimator or | ||
| flow. When several connectors claim support, no error is raised: the | ||
| first match in registry order is returned. | ||
| """ | ||
| for connector_cls in API_CONNECTOR_REGISTRY: | ||
| if connector_cls.supports(estimator): | ||
| return connector_cls() | ||
|
|
||
| raise PyOpenMLError("No OpenML API connector supports this estimator.") | ||
Uh oh!
There was an error while loading. Please reload this page.