Skip to content

Commit c1cba63

Browse files
committed
Remove return statement to setters and document default values in docstring
1 parent 8fbb252 commit c1cba63

File tree

1 file changed

+8
-45
lines changed

1 file changed

+8
-45
lines changed

dataikuapi/dss/ml.py

Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -404,9 +404,6 @@ def set_grid_search(self, shuffle=True, seed=0):
404404
:type shuffle: bool
405405
:param seed:
406406
:type seed: int
407-
:return: current (mutated) settings
408-
:rtype: HyperparameterSearchSettings
409-
410407
"""
411408
self._raw_settings["strategy"] = "GRID"
412409
if shuffle is not None:
@@ -416,31 +413,24 @@ def set_grid_search(self, shuffle=True, seed=0):
416413
else:
417414
self._raw_settings["randomized"] = shuffle
418415
self._set_seed(seed)
419-
return self
420416

421417
def set_random_search(self, seed=0):
422418
"""
423419
Sets the search strategy to "RANDOM" to perform a random search on the hyperparameters.
424-
:param seed:
420+
:param seed: defaults to 0
425421
:type seed: int
426-
:return: current (mutated) settings
427-
:rtype: HyperparameterSearchSettings
428422
"""
429423
self._raw_settings["strategy"] = "RANDOM"
430424
self._set_seed(seed)
431-
return self
432425

433426
def set_bayesian_search(self, seed=0):
434427
"""
435428
Sets the search strategy to "BAYESIAN" to perform a Bayesian search on the hyperparameters.
436-
:param seed:
429+
:param seed: defaults to 0
437430
:type seed: int
438-
:return: current (mutated) settings
439-
:rtype: HyperparameterSearchSettings
440431
"""
441432
self._raw_settings["strategy"] = "BAYESIAN"
442433
self._set_seed(seed)
443-
return self
444434

445435
@property
446436
def validation_mode(self):
@@ -463,12 +453,10 @@ def set_kfold_validation(self, n_folds=5, stratified=True):
463453
"""
464454
Sets the validation mode to k-fold cross-validation (either "KFOLD" or "TIME_SERIES_KFOLD" if time-based ordering
465455
is enabled).
466-
:param n_folds: the number of folds used for the hyperparameter search
456+
:param n_folds: the number of folds used for the hyperparameter search, defaults to 5
467457
:type n_folds: int
468-
:param stratified: if True, keep the same proportion of each target classes in all folds
458+
:param stratified: if True, keep the same proportion of each target classes in all folds, defaults to True
469459
:type stratified: bool
470-
:return: current (mutated) settings
471-
:rtype: HyperparameterSearchSettings
472460
"""
473461
if self._raw_settings["mode"] == "TIME_SERIES_SINGLE_SPLIT":
474462
self._raw_settings["mode"] = "TIME_SERIES_KFOLD"
@@ -486,18 +474,15 @@ def set_kfold_validation(self, n_folds=5, stratified=True):
486474
warnings.warn("stratified must be a boolean")
487475
else:
488476
self._raw_settings["stratified"] = stratified
489-
return self
490477

491478
def set_single_split_validation(self, split_ratio=0.8, stratified=True):
492479
"""
493480
Sets the validation mode to single split (either "SHUFFLE" or "TIME_SERIES_SINGLE_SPLIT" if time-based ordering
494481
is enabled).
495-
:param split_ratio: ratio of the data used for the train during hyperparameter search
482+
:param split_ratio: ratio of the data used for the train during hyperparameter search, defaults to 0.8
496483
:type split_ratio: float
497-
:param stratified: if True, keep the same proportion of each target classes in both splits
484+
:param stratified: if True, keep the same proportion of each target classes in both splits, defaults to True
498485
:type stratified: bool
499-
:return: current (mutated) settings
500-
:rtype: HyperparameterSearchSettings
501486
"""
502487
if self._raw_settings["mode"] == "TIME_SERIES_KFOLD":
503488
self._raw_settings["mode"] = "TIME_SERIES_SINGLE_SPLIT"
@@ -515,15 +500,12 @@ def set_single_split_validation(self, split_ratio=0.8, stratified=True):
515500
warnings.warn("stratified must be a boolean")
516501
else:
517502
self._raw_settings["stratified"] = stratified
518-
return self
519503

520504
def set_custom_validation(self, code=None):
521505
"""
522506
Sets the validation mode to "CUSTOM".
523507
:param code: definition of the validation
524508
:type code: str
525-
:return: current (mutated) settings
526-
:rtype: HyperparameterSearchSettings
527509
"""
528510
self._raw_settings["mode"] = "CUSTOM"
529511
if code is not None:
@@ -532,25 +514,21 @@ def set_custom_validation(self, code=None):
532514
warnings.warn("code must be a Python interpretable string")
533515
else:
534516
self._raw_settings["code"] = code
535-
return self
536517

537518
def set_search_distribution(self, distributed=False, n_containers=4):
538519
"""
539520
Sets the distribution parameters for the hyperparameter search execution.
540521
:param distributed: if True, distribute search across n_containers containers in the Kubernetes
541-
cluster selected in containerized execution configuration of the runtime environment
522+
cluster selected in containerized execution configuration of the runtime environment, defaults to False
542523
:type distributed: bool
543-
:param n_containers: number of containers to use for the distributed search
524+
:param n_containers: number of containers to use for the distributed search, defaults to 4
544525
:type n_containers: int
545-
:return: current (mutated) settings
546-
:rtype: HyperparameterSearchSettings
547526
"""
548527
assert isinstance(distributed, bool)
549528
if n_containers is not None:
550529
assert isinstance(n_containers, int)
551530
self._raw_settings["nContainers"] = n_containers
552531
self._raw_settings["distributed"] = distributed
553-
return self
554532

555533
@property
556534
def distributed(self):
@@ -669,12 +647,9 @@ def set_explicit_values(self, values):
669647
- the definition mode of the current numerical hyperparameter to "EXPLICIT"
670648
:param values: the explicit list of numerical values considered for this hyperparameter in the search
671649
:type values: list of float | int
672-
:return: current (mutated) settings
673-
:rtype: NumericalHyperparameterSettings
674650
"""
675651
self.values = values
676652
self.definition_mode = "EXPLICIT"
677-
return self
678653

679654
@property
680655
def values(self):
@@ -746,12 +721,9 @@ def set_range(self, min=None, max=None, nb_values=None):
746721
:type max: float | int
747722
:param nb_values: for grid-search ("GRID" strategy) only, the number of values between min and max to consider
748723
:type nb_values: int
749-
:return: current (mutated) settings
750-
:rtype: NumericalHyperparameterSettings
751724
"""
752725
self._set_range(min=min, max=max, nb_values=nb_values)
753726
self.definition_mode = "RANGE"
754-
return self
755727

756728
@property
757729
def range(self):
@@ -831,8 +803,6 @@ def set_values(self, values):
831803
Enables the search over listed values (categories).
832804
:param values: values to enable, all other values will be disabled
833805
:type values: list of str
834-
:return: current (mutated) settings
835-
:rtype: CategoricalHyperparameterSettings
836806
"""
837807
all_possible_values = self.get_all_possible_values()
838808
for category in values:
@@ -846,7 +816,6 @@ def set_values(self, values):
846816
self._algo_settings[self.name]["values"][category] = {"enabled": True}
847817
else:
848818
self._algo_settings[self.name]["values"][category] = {"enabled": False}
849-
return self
850819

851820
def get_values(self):
852821
"""
@@ -881,13 +850,10 @@ def set_value(self, value):
881850
"""
882851
:param value:
883852
:type value: bool | int | float
884-
:return: current (mutated) settings
885-
:rtype: SingleValueHyperparameterSettings
886853
"""
887854
if self.accepted_types is not None:
888855
assert any(isinstance(value, accepted_type) for accepted_type in self.accepted_types), "Invalid type for hyperparameter {}. Type must be one of: {}".format(self.name, self.accepted_types)
889856
self._algo_settings[self.name] = value
890-
return self
891857

892858
def get_value(self):
893859
"""
@@ -925,13 +891,10 @@ def set_value(self, value):
925891
"""
926892
:param value:
927893
:type value: str
928-
:return: current (mutated) settings
929-
:rtype: SingleValueHyperparameterSettings
930894
"""
931895
if self.accepted_values is not None:
932896
assert value in self.accepted_values, "Invalid value for hyperparameter {}. Must be in {}".format(self.name, json.dumps(self.accepted_values))
933897
self._algo_settings[self.name] = value
934-
return self
935898

936899
def get_value(self):
937900
"""

0 commit comments

Comments
 (0)