@@ -370,6 +370,44 @@ def get_params(self, deep=True) -> dict[str, Any]:
370370 """
371371 return {key : getattr (self , key ) for key in self ._get_param_names ()}
372372
373+ def set_params (self , ** params ):
374+ """Set the parameters of this estimator.
375+
376+ Parameters
377+ ----------
378+ **params : dict
379+ Step parameters.
380+
381+ Returns
382+ -------
383+ self : object
384+ Step class instance.
385+
386+ Notes
387+ -----
388+ Derived from [1]_.
389+
390+ References
391+ ----------
392+ .. [1] https://github.com/scikit-learn/scikit-learn/blob/74016ab/sklearn/base.py#L214-L256
393+ """
394+ if not params :
395+ # Simple optimization to gain speed (inspect is slow)
396+ return self
397+
398+ valid_params = self ._get_param_names ()
399+
400+ for key , value in params .items ():
401+ if key not in valid_params :
402+ raise ValueError (
403+ f"Invalid parameter { key !r} for estimator { self } . "
404+ f"Valid parameters are: { valid_params !r} ."
405+ )
406+
407+ setattr (self , key , value )
408+
409+ return self
410+
373411 def __repr__ (self ) -> str :
374412 return pprint .pformat (self )
375413
@@ -453,7 +491,7 @@ def _name_estimators(estimators):
453491
454492class Recipe :
455493 def __init__ (self , * steps : Step ):
456- self .steps = steps
494+ self .steps = list ( steps )
457495 self ._output_format = "default"
458496
459497 def __repr__ (self ):
@@ -502,9 +540,69 @@ def get_params(self, deep=True) -> dict[str, Any]:
502540 out [f"{ name } __{ key } " ] = value
503541 return out
504542
505- def set_params (self , ** kwargs ):
506- if "steps" in kwargs :
507- self .steps = kwargs .get ("steps" )
543+ def set_params (self , ** params ):
544+ """Set the parameters of this estimator.
545+
546+ Valid parameter keys can be listed with ``get_params()``. Note that
547+ you can directly set the parameters of the estimators contained in
548+ `steps`.
549+
550+ Parameters
551+ ----------
552+ **params : dict
553+ Parameters of this estimator or parameters of estimators contained
554+ in `steps`. Parameters of the steps may be set using its name and
555+ the parameter name separated by a '__'.
556+
557+ Returns
558+ -------
559+ self : object
560+ Recipe class instance.
561+
562+ Notes
563+ -----
564+ Derived from [1]_ and [2]_.
565+
566+ References
567+ ----------
568+ .. [1] https://github.com/scikit-learn/scikit-learn/blob/ff1c6f3/sklearn/utils/metaestimators.py#L51-L70
569+ .. [2] https://github.com/scikit-learn/scikit-learn/blob/74016ab/sklearn/base.py#L214-L256
570+ """
571+ if not params :
572+ # Simple optimization to gain speed (inspect is slow)
573+ return self
574+
575+ # Ensure strict ordering of parameter setting:
576+ # 1. All steps
577+ if "steps" in params :
578+ self .steps = params .pop ("steps" )
579+
580+ # 2. Replace items with estimators in params
581+ estimator_name_indexes = {
582+ x : i for i , x in enumerate (name for name , _ in _name_estimators (self .steps ))
583+ }
584+ for name in list (params ):
585+ if "__" not in name and name in estimator_name_indexes :
586+ self .steps [estimator_name_indexes [name ]] = params .pop (name )
587+
588+ # 3. Step parameters and other initialisation arguments
589+ valid_params = self .get_params (deep = True )
590+
591+ nested_params = defaultdict (dict ) # grouped by prefix
592+ for key , value in params .items ():
593+ key , sub_key = key .split ("__" , maxsplit = 1 )
594+ if key not in valid_params :
595+ raise ValueError (
596+ f"Invalid parameter { key !r} for estimator { self } . "
597+ f"Valid parameters are: ['steps']."
598+ )
599+
600+ nested_params [key ][sub_key ] = value
601+
602+ for key , sub_params in nested_params .items ():
603+ valid_params [key ].set_params (** sub_params )
604+
605+ return self
508606
509607 def set_output (
510608 self ,
0 commit comments