Skip to content

Commit 8b771e6

Browse files
committed
feat(core): implement a more robust set_params()
1 parent cfd6dc1 commit 8b771e6

File tree

2 files changed

+153
-4
lines changed

2 files changed

+153
-4
lines changed

ibis_ml/core.py

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,44 @@ def get_params(self, deep=True) -> dict[str, Any]:
370370
"""
371371
return {key: getattr(self, key) for key in self._get_param_names()}
372372

373+
def set_params(self, **params):
    """Set the parameters of this estimator.

    Every key is validated before any attribute is assigned, so a call
    that raises leaves the step exactly as it was.

    Parameters
    ----------
    **params : dict
        Step parameters.

    Returns
    -------
    self : object
        Step class instance.

    Raises
    ------
    ValueError
        If a key is not among the names returned by
        ``_get_param_names()``.

    Notes
    -----
    Derived from [1]_.

    References
    ----------
    .. [1] https://github.com/scikit-learn/scikit-learn/blob/74016ab/sklearn/base.py#L214-L256
    """
    if not params:
        # Simple optimization to gain speed (inspect is slow)
        return self

    valid_params = self._get_param_names()

    # Validate everything up front: interleaving the check with ``setattr``
    # would leave the step half-updated when a later key is invalid.
    for key in params:
        if key not in valid_params:
            raise ValueError(
                f"Invalid parameter {key!r} for estimator {self}. "
                f"Valid parameters are: {valid_params!r}."
            )

    for key, value in params.items():
        setattr(self, key, value)

    return self
410+
373411
def __repr__(self) -> str:
374412
return pprint.pformat(self)
375413

@@ -453,7 +491,7 @@ def _name_estimators(estimators):
453491

454492
class Recipe:
455493
def __init__(self, *steps: Step):
456-
self.steps = steps
494+
self.steps = list(steps)
457495
self._output_format = "default"
458496

459497
def __repr__(self):
@@ -502,9 +540,69 @@ def get_params(self, deep=True) -> dict[str, Any]:
502540
out[f"{name}__{key}"] = value
503541
return out
504542

505-
def set_params(self, **kwargs):
506-
if "steps" in kwargs:
507-
self.steps = kwargs.get("steps")
543+
def set_params(self, **params):
    """Set the parameters of this estimator.

    Valid parameter keys can be listed with ``get_params()``. Note that
    you can directly set the parameters of the estimators contained in
    `steps`.

    Parameters
    ----------
    **params : dict
        Parameters of this estimator or parameters of estimators contained
        in `steps`. Parameters of the steps may be set using its name and
        the parameter name separated by a '__'.

    Returns
    -------
    self : object
        Recipe class instance.

    Raises
    ------
    ValueError
        If a key (or the prefix before ``'__'``) is not a valid parameter
        of this recipe, including keys that are neither ``steps``, a step
        name, nor of the ``<step>__<param>`` form.

    Notes
    -----
    Derived from [1]_ and [2]_.

    References
    ----------
    .. [1] https://github.com/scikit-learn/scikit-learn/blob/ff1c6f3/sklearn/utils/metaestimators.py#L51-L70
    .. [2] https://github.com/scikit-learn/scikit-learn/blob/74016ab/sklearn/base.py#L214-L256
    """
    if not params:
        # Simple optimization to gain speed (inspect is slow)
        return self

    # Ensure strict ordering of parameter setting:
    # 1. All steps
    if "steps" in params:
        steps = params.pop("steps")
        # Step replacement below assigns by index, so the container must be
        # mutable; a list is kept as-is, anything else is copied into one.
        self.steps = steps if isinstance(steps, list) else list(steps)

    # 2. Replace items with estimators in params
    estimator_name_indexes = {
        name: index
        for index, (name, _) in enumerate(_name_estimators(self.steps))
    }
    for name in list(params):
        if "__" not in name and name in estimator_name_indexes:
            self.steps[estimator_name_indexes[name]] = params.pop(name)

    # 3. Step parameters and other initialisation arguments
    # Looked up only after the replacements above, so nested parameters are
    # applied to the replacement steps rather than the ones they displaced.
    valid_params = self.get_params(deep=True)

    nested_params = defaultdict(dict)  # grouped by prefix
    for name, value in params.items():
        # ``partition`` (unlike a two-way ``split`` unpack) also copes with
        # names that contain no '__' at all.
        key, delim, sub_key = name.partition("__")
        if not delim or key not in valid_params:
            raise ValueError(
                f"Invalid parameter {key!r} for estimator {self}. "
                "Valid parameters are: ['steps']."
            )

        nested_params[key][sub_key] = value

    # Hand each step all of its parameters in a single call.
    for key, sub_params in nested_params.items():
        valid_params[key].set_params(**sub_params)

    return self
508606

509607
def set_output(
510608
self,

tests/test_core.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import patch
2+
13
import ibis
24
import ibis.expr.types as ir
35
import numpy as np
@@ -372,6 +374,55 @@ def test_get_params():
372374
assert "expandtimestamp__components" not in rec.get_params(deep=False)
373375

374376

377+
def test_set_params():
    """``Recipe.set_params`` reports invalid nested keys with a clear error."""
    rec = ml.Recipe(ml.ExpandTimestamp(ml.timestamp()))

    cases = [
        # Nonexistent parameter in step
        (
            {"expandtimestamp__nonexistent_param": True},
            "Invalid parameter 'nonexistent_param' for estimator ExpandTimestamp",
        ),
        # Nonexistent parameter of pipeline
        (
            {"expanddatetime__nonexistent_param": True},
            "Invalid parameter 'expanddatetime' for estimator Recipe",
        ),
    ]
    for bad_kwargs, expected_message in cases:
        with pytest.raises(ValueError, match=expected_message):
            rec.set_params(**bad_kwargs)
392+
393+
394+
def test_set_params_passes_all_parameters():
    """All nested parameters of a step reach its ``set_params`` in one call."""
    rec = ml.Recipe(ml.ExpandTimestamp(ml.timestamp()))
    step_kwargs = {"inputs": ["x", "y"], "components": ["day", "year", "hour"]}
    prefixed_kwargs = {
        f"expandtimestamp__{name}": value for name, value in step_kwargs.items()
    }

    with patch.object(ml.ExpandTimestamp, "set_params") as mock_set_params:
        rec.set_params(**prefixed_kwargs)

    mock_set_params.assert_called_once_with(**step_kwargs)
407+
408+
409+
def test_set_params_updates_valid_params():
    """Nested parameters go to the replacement step, not the replaced one.

    ``mutateat-1`` is swapped out and ``mutateat-1__inputs`` is set in the
    same call; only the replacement step may see the new inputs.
    """
    old_step = ml.MutateAt("dep_time", ibis._.hour() * 60 + ibis._.minute())  # noqa: SLF001
    epoch_step = ml.MutateAt(ml.timestamp(), ibis._.epoch_seconds())  # noqa: SLF001
    rec = ml.Recipe(old_step, epoch_step)
    new_step = ml.MutateAt("arr_time", ibis._.hour() * 60 + ibis._.minute())  # noqa: SLF001

    updates = {"mutateat-1": new_step, "mutateat-1__inputs": ml.cols("arrival")}
    rec.set_params(**updates)

    assert rec.steps[0] is new_step
    assert new_step.inputs == ml.cols("arrival")
    assert old_step.inputs == ml.cols("dep_time")
424+
425+
375426
@pytest.mark.parametrize(
376427
("step", "url"),
377428
[

0 commit comments

Comments
 (0)