Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
2c0c1aa
[ENH] Refactor `Extension`
jgyasu Jan 2, 2026
2aab335
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2026
d368eab
Merge remote-tracking branch 'upstream/main' into refactor-extension
jgyasu Jan 2, 2026
6d3e0e9
Merge remote-tracking branch 'refs/remotes/origin/refactor-extension'…
jgyasu Jan 2, 2026
1365bf6
correct openml exception
jgyasu Jan 2, 2026
67c0efb
use __all__ for imports in __init__
jgyasu Jan 2, 2026
e5850ef
update registry
jgyasu Jan 2, 2026
00da7a9
update registry and file structure
jgyasu Jan 5, 2026
7d33463
Merge branch 'main' into refactor-extension
jgyasu Jan 5, 2026
373fa53
[DO NOT MERGE] Refactor openml-sklearn back into openml-python
jgyasu Jan 5, 2026
e86fab7
add public function for serialisation and deserialisation
jgyasu Jan 5, 2026
e92156a
move the flow utils to flows/functions.py
jgyasu Jan 5, 2026
1945c58
update flows
jgyasu Jan 5, 2026
5a1ccd6
expose parameters of flow_to_model
jgyasu Jan 5, 2026
c7e52e1
remove sklearn
jgyasu Jan 5, 2026
12df955
remove .DS_Store
jgyasu Jan 5, 2026
9e5e752
add flow functions to __init__.py
jgyasu Jan 5, 2026
bf9a0aa
add tests for extension base classes and registry
jgyasu Jan 6, 2026
e3ca07d
remove sklearn extension from registry temporarily
jgyasu Jan 7, 2026
a2aa2d0
Merge branch 'main' into pr/1590
fkiraly Jan 7, 2026
bc541bb
Merge branch 'main' into refactor-extension
jgyasu Jan 26, 2026
34448aa
Merge remote-tracking branch 'refs/remotes/origin/refactor-extension'…
jgyasu Jan 26, 2026
7bd15e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2026
3afc0d9
Merge branch 'main' into refactor-extension
jgyasu Jan 26, 2026
406a96f
move some methods between serializer and executor
jgyasu Jan 26, 2026
5030914
update tests
jgyasu Jan 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions openml/extensions/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# License: BSD 3-Clause

"""Base classes for OpenML extensions."""

from openml.extensions.base._connector import OpenMLAPIConnector
from openml.extensions.base._executor import ModelExecutor
from openml.extensions.base._serializer import ModelSerializer

__all__ = [
"ModelExecutor",
"ModelSerializer",
"OpenMLAPIConnector",
]
28 changes: 28 additions & 0 deletions openml/extensions/base/_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# License: BSD 3-Clause

"""Base class for OpenML API connectors."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from openml.extensions.base import ModelExecutor, ModelSerializer


class OpenMLAPIConnector(ABC):
"""Base class for OpenML API connectors."""

@abstractmethod
def serializer(self) -> ModelSerializer:
"""Return the serializer for this API."""

@abstractmethod
def executor(self) -> ModelExecutor:
"""Return the executor for this API."""

@classmethod
@abstractmethod
def supports(cls, estimator: Any) -> bool:
"""High-level check if this connector supports the estimator instance or flow."""
137 changes: 137 additions & 0 deletions openml/extensions/base/_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# License: BSD 3-Clause

"""Base class for estimator executors."""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
import numpy as np
import scipy.sparse

from openml.runs.trace import OpenMLRunTrace, OpenMLTraceIteration
from openml.tasks.task import OpenMLTask


class ModelExecutor(ABC):
"""Define runtime execution semantics for a specific API type."""

@classmethod
@abstractmethod
def can_handle_model(cls, model: Any) -> bool:
"""Check whether a model flow can be handled by this extension.

This is typically done by checking the type of the model, or the package it belongs to.

Parameters
----------
model : Any

Returns
-------
bool
"""

@abstractmethod
def seed_model(self, model: Any, seed: int | None) -> Any:
"""Set the seed of all the unseeded components of a model and return the seeded model.

Required so that all seed information can be uploaded to OpenML for reproducible results.

Parameters
----------
model : Any
The model to be seeded
seed : int

Returns
-------
model
"""

@abstractmethod
def _run_model_on_fold( # noqa: PLR0913
self,
model: Any,
task: OpenMLTask,
X_train: np.ndarray | scipy.sparse.spmatrix,
rep_no: int,
fold_no: int,
y_train: np.ndarray | None = None,
X_test: np.ndarray | scipy.sparse.spmatrix | None = None,
) -> tuple[np.ndarray, np.ndarray | None, OrderedDict[str, float], OpenMLRunTrace | None]:
"""Run a model on a repeat, fold, subsample triplet of the task.

Returns the data that is necessary to construct the OpenML Run object. Is used by
:func:`openml.runs.run_flow_on_task`.

Parameters
----------
model : Any
The UNTRAINED model to run. The model instance will be copied and not altered.
task : OpenMLTask
The task to run the model on.
X_train : array-like
Training data for the given repetition and fold.
rep_no : int
The repeat of the experiment (0-based; in case of 1 time CV, always 0)
fold_no : int
The fold nr of the experiment (0-based; in case of holdout, always 0)
y_train : Optional[np.ndarray] (default=None)
Target attributes for supervised tasks. In case of classification, these are integer
indices to the potential classes specified by dataset.
X_test : Optional, array-like (default=None)
Test attributes to test for generalization in supervised tasks.

Returns
-------
predictions : np.ndarray
Model predictions.
probabilities : Optional, np.ndarray
Predicted probabilities (only applicable for supervised classification tasks).
user_defined_measures : OrderedDict[str, float]
User defined measures that were generated on this fold
trace : Optional, OpenMLRunTrace
Hyperparameter optimization trace (only applicable for supervised tasks with
hyperparameter optimization).
"""

@abstractmethod
def check_if_model_fitted(self, model: Any) -> bool:
"""Returns True/False denoting if the model has already been fitted/trained.

Parameters
----------
model : Any

Returns
-------
bool
"""

# Abstract methods for hyperparameter optimization

@abstractmethod
def instantiate_model_from_hpo_class(
self,
model: Any,
trace_iteration: OpenMLTraceIteration,
) -> Any:
"""Instantiate a base model which can be searched over by the hyperparameter optimization
model.

Parameters
----------
model : Any
A hyperparameter optimization model which defines the model to be instantiated.
trace_iteration : OpenMLTraceIteration
Describing the hyperparameter settings to instantiate.

Returns
-------
Any
"""
# TODO a trace belongs to a run and therefore a flow -> simplify this part of the interface!
102 changes: 102 additions & 0 deletions openml/extensions/base/_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# License: BSD 3-Clause

"""Base class for estimator serializors."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from openml.flows import OpenMLFlow


class ModelSerializer(ABC):
"""Handle the conversion between estimator instances and OpenML Flows."""

@classmethod
@abstractmethod
def can_handle_model(cls, model: Any) -> bool:
"""Check whether a model flow can be handled by this extension.

This is typically done by checking the type of the model, or the package it belongs to.

Parameters
----------
model : Any

Returns
-------
bool
"""

@abstractmethod
def model_to_flow(self, model: Any) -> OpenMLFlow:
"""Transform a model to a flow for uploading it to OpenML.

Parameters
----------
model : Any

Returns
-------
OpenMLFlow
"""

@abstractmethod
def flow_to_model(
self,
flow: OpenMLFlow,
initialize_with_defaults: bool = False, # noqa: FBT002
strict_version: bool = True, # noqa: FBT002
) -> Any:
"""Instantiate a model from the flow representation.

Parameters
----------
flow : OpenMLFlow

initialize_with_defaults : bool, optional (default=False)
If this flag is set, the hyperparameter values of flows will be
ignored and a flow with its defaults is returned.

strict_version : bool, default=True
Whether to fail if version requirements are not fulfilled.

Returns
-------
Any
"""

@abstractmethod
def get_version_information(self) -> list[str]:
"""Return dependency and version information."""

@abstractmethod
def obtain_parameter_values(
self,
flow: OpenMLFlow,
model: Any = None,
) -> list[dict[str, Any]]:
"""Extracts all parameter settings required for the flow from the model.

If no explicit model is provided, the parameters will be extracted from `flow.model`
instead.

Parameters
----------
flow : OpenMLFlow
OpenMLFlow object (containing flow ids, i.e., it has to be downloaded from the server)

model: Any, optional (default=None)
The model from which to obtain the parameter values. Must match the flow signature.
If None, use the model specified in ``OpenMLFlow.model``.

Returns
-------
list
A list of dicts, where each dict has the following entries:
- ``oml:name`` : str: The OpenML parameter name
- ``oml:value`` : mixed: A representation of the parameter value
- ``oml:component`` : int: flow id to which the parameter belongs
"""
49 changes: 49 additions & 0 deletions openml/extensions/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# License: BSD 3-Clause

"""Extension registry."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from openml.exceptions import PyOpenMLError

if TYPE_CHECKING:
from openml.extensions.base import OpenMLAPIConnector

API_CONNECTOR_REGISTRY: list[type[OpenMLAPIConnector]] = [
# Add OpenMLAPIConnector subclasses here to register them
]
Comment on lines +14 to +16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the expected access pattern?
And why is this now at the APIConnector level instead of the Executor and Serializer levels? Wouldn't I expect a list of each now?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't I expect a list of each now?

No. There only needs to be a list of connectors. The resolve_api_connector resolves the connector by calling the supports method of the connectors and returns a compatible connector instance, from which the serializer and executor can be accessed. The supports method of the connectors will call the can_handle_model of serializer and executor.

If you look at the flow_to_estimator function that I have implemented, this is how it goes:

def flow_to_estimator(flow: OpenMLFlow, initialize_with_defaults: bool = False, strict_version: bool = True) -> Any:

    connector = resolve_api_connector(flow)

    return connector.serializer().flow_to_model(
        flow,
        initialize_with_defaults=initialize_with_defaults,
        strict_version=strict_version,
    )



def resolve_api_connector(estimator: Any) -> OpenMLAPIConnector:
"""
Identify and return the appropriate OpenML API connector for a given estimator.

This function iterates through the global ``API_CONNECTOR_REGISTRY`` to find
a connector class that supports the provided estimator instance or OpenML flow.
If a matching connector is found, it is instantiated and returned.

Parameters
----------
estimator : Any
The estimator instance (e.g., a scikit-learn estimator) or OpenML flow for
which an API connector is required.

Returns
-------
OpenMLAPIConnector
An instance of the matching API connector.

Raises
------
OpenMLException
If no connector is found in the registry that supports the provided
model, or if multiple connectors in the registry claim support for
the provided model.
"""
for connector_cls in API_CONNECTOR_REGISTRY:
if connector_cls.supports(estimator):
return connector_cls()

raise PyOpenMLError("No OpenML API connector supports this estimator.")
7 changes: 5 additions & 2 deletions openml/flows/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
# License: BSD 3-Clause

from .flow import OpenMLFlow
from .functions import (
from openml.flows.flow import OpenMLFlow
from openml.flows.functions import (
assert_flows_equal,
delete_flow,
flow_exists,
get_flow,
get_flow_id,
list_flows,
)
from openml.flows.utils import estimator_to_flow, flow_to_estimator

__all__ = [
"OpenMLFlow",
"assert_flows_equal",
"delete_flow",
"estimator_to_flow",
"flow_exists",
"flow_to_estimator",
"get_flow",
"get_flow_id",
"list_flows",
Expand Down
Loading
Loading