Skip to content

Commit 6a1ec0a

Browse files
committed
refactor(core): privatize step param getter/setter
1 parent 8b771e6 commit 6a1ec0a

File tree

3 files changed

+26
-32
lines changed

3 files changed

+26
-32
lines changed

ibis_ml/core.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -347,13 +347,8 @@ def _get_param_names(cls) -> list[str]:
347347
# Extract and sort argument names excluding 'self'
348348
return sorted([p.name for p in parameters])
349349

350-
def get_params(self, deep=True) -> dict[str, Any]:
351-
"""Get parameters for this estimator.
352-
353-
Parameters
354-
----------
355-
deep : bool, default=True
356-
Has no effect, because steps cannot contain nested substeps.
350+
def _get_params(self) -> dict[str, Any]:
351+
"""Get parameters for this step.
357352
358353
Returns
359354
-------
@@ -370,8 +365,8 @@ def get_params(self, deep=True) -> dict[str, Any]:
370365
"""
371366
return {key: getattr(self, key) for key in self._get_param_names()}
372367

373-
def set_params(self, **params):
374-
"""Set the parameters of this estimator.
368+
def _set_params(self, **params):
369+
"""Set the parameters of this step.
375370
376371
Parameters
377372
----------
@@ -400,7 +395,7 @@ def set_params(self, **params):
400395
for key, value in params.items():
401396
if key not in valid_params:
402397
raise ValueError(
403-
f"Invalid parameter {key!r} for estimator {self}. "
398+
f"Invalid parameter {key!r} for step {self}. "
404399
f"Valid parameters are: {valid_params!r}."
405400
)
406401

@@ -503,16 +498,16 @@ def output_format(self) -> Literal["default", "pandas", "pyarrow", "polars"]:
503498
return self._output_format
504499

505500
def get_params(self, deep=True) -> dict[str, Any]:
506-
"""Get parameters for this estimator.
501+
"""Get parameters for this recipe.
507502
508503
Returns the parameters given in the constructor as well as the
509-
estimators contained within the `steps` of the `Recipe`.
504+
steps contained within the `steps` of the `Recipe`.
510505
511506
Parameters
512507
----------
513508
deep : bool, default=True
514-
If True, will return the parameters for this estimator and
515-
contained subobjects that are estimators.
509+
If True, will return the parameters for this recipe and
510+
contained steps.
516511
517512
Returns
518513
-------
@@ -531,26 +526,25 @@ def get_params(self, deep=True) -> dict[str, Any]:
531526
if not deep:
532527
return out
533528

534-
estimators = _name_estimators(self.steps)
535-
out.update(estimators)
529+
steps = _name_estimators(self.steps)
530+
out.update(steps)
536531

537-
for name, estimator in estimators:
538-
if hasattr(estimator, "get_params"):
539-
for key, value in estimator.get_params(deep=True).items():
540-
out[f"{name}__{key}"] = value
532+
for name, step in steps:
533+
for key, value in step._get_params().items(): # noqa: SLF001
534+
out[f"{name}__{key}"] = value
541535
return out
542536

543537
def set_params(self, **params):
544-
"""Set the parameters of this estimator.
538+
"""Set the parameters of this recipe.
545539
546540
Valid parameter keys can be listed with ``get_params()``. Note that
547-
you can directly set the parameters of the estimators contained in
541+
you can directly set the parameters of the steps contained in
548542
`steps`.
549543
550544
Parameters
551545
----------
552546
**params : dict
553-
Parameters of this estimator or parameters of estimators contained
547+
Parameters of this recipe or parameters of steps contained
554548
in `steps`. Parameters of the steps may be set using its name and
555549
the parameter name separated by a '__'.
556550
@@ -577,7 +571,7 @@ def set_params(self, **params):
577571
if "steps" in params:
578572
self.steps = params.pop("steps")
579573

580-
# 2. Replace items with estimators in params
574+
# 2. Replace steps with steps in params
581575
estimator_name_indexes = {
582576
x: i for i, x in enumerate(name for name, _ in _name_estimators(self.steps))
583577
}
@@ -593,14 +587,14 @@ def set_params(self, **params):
593587
key, sub_key = key.split("__", maxsplit=1)
594588
if key not in valid_params:
595589
raise ValueError(
596-
f"Invalid parameter {key!r} for estimator {self}. "
590+
f"Invalid parameter {key!r} for recipe {self}. "
597591
f"Valid parameters are: ['steps']."
598592
)
599593

600594
nested_params[key][sub_key] = value
601595

602596
for key, sub_params in nested_params.items():
603-
valid_params[key].set_params(**sub_params)
597+
valid_params[key]._set_params(**sub_params) # noqa: SLF001
604598

605599
return self
606600

tests/test_common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_mutate_at_expr():
3434
res = step.transform_table(t)
3535
sol = t.mutate(x=_.x.abs(), y=_.y.abs())
3636
assert res.equals(sol)
37-
assert list(step.get_params()) == ["expr", "inputs", "named_exprs"]
37+
assert list(step._get_params()) == ["expr", "inputs", "named_exprs"] # noqa: SLF001
3838

3939

4040
def test_mutate_at_named_exprs():
@@ -45,7 +45,7 @@ def test_mutate_at_named_exprs():
4545
res = step.transform_table(t)
4646
sol = t.mutate(x=_.x.abs(), y=_.y.abs(), x_log=_.x.log(), y_log=_.y.log())
4747
assert res.equals(sol)
48-
assert list(step.get_params()) == ["expr", "inputs", "named_exprs"]
48+
assert list(step._get_params()) == ["expr", "inputs", "named_exprs"] # noqa: SLF001
4949

5050

5151
def test_mutate():
@@ -56,4 +56,4 @@ def test_mutate():
5656
res = step.transform_table(t)
5757
sol = t.mutate(_.x.abs().name("x_abs"), y_log=lambda t: t.y.log())
5858
assert res.equals(sol)
59-
assert list(step.get_params()) == ["exprs", "named_exprs"]
59+
assert list(step._get_params()) == ["exprs", "named_exprs"] # noqa: SLF001

tests/test_core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,13 +380,13 @@ def test_set_params():
380380
# Nonexistent parameter in step
381381
with pytest.raises(
382382
ValueError,
383-
match="Invalid parameter 'nonexistent_param' for estimator ExpandTimestamp",
383+
match="Invalid parameter 'nonexistent_param' for step ExpandTimestamp",
384384
):
385385
rec.set_params(expandtimestamp__nonexistent_param=True)
386386

387387
# Nonexistent parameter of pipeline
388388
with pytest.raises(
389-
ValueError, match="Invalid parameter 'expanddatetime' for estimator Recipe"
389+
ValueError, match="Invalid parameter 'expanddatetime' for recipe Recipe"
390390
):
391391
rec.set_params(expanddatetime__nonexistent_param=True)
392392

@@ -395,7 +395,7 @@ def test_set_params_passes_all_parameters():
395395
# Make sure all parameters are passed together to set_params
396396
# of nested estimator.
397397
rec = ml.Recipe(ml.ExpandTimestamp(ml.timestamp()))
398-
with patch.object(ml.ExpandTimestamp, "set_params") as mock_set_params:
398+
with patch.object(ml.ExpandTimestamp, "_set_params") as mock_set_params:
399399
rec.set_params(
400400
expandtimestamp__inputs=["x", "y"],
401401
expandtimestamp__components=["day", "year", "hour"],

0 commit comments

Comments
 (0)