@@ -251,16 +251,15 @@ def use_feature(self, feature_name):
251251
def use_sample_weighting(self, feature_name):
    """
    Deprecated. Will be removed from DSSMLTaskSettings class

    This implementation never configures anything: it always raises.

    :param str feature_name: not used by this implementation
    :raises NotImplementedError: always
    """
    message = "use_sample_weighting() not available for class {}".format(self.__class__)
    raise NotImplementedError(message)
258257
def remove_sample_weighting(self):
    """
    Deprecated. Will be removed from DSSMLTaskSettings class

    This implementation never modifies anything: it always raises.

    :raises NotImplementedError: always
    """
    message = "remove_sample_weighting() not available for class {}".format(self.__class__)
    raise NotImplementedError(message)
264263
265264 def get_algorithm_settings (self , algorithm_name ):
266265 """
@@ -387,6 +386,22 @@ class DSSPredictionMLTaskSettings(DSSMLTaskSettings):
387386 "KERAS_CODE" : "keras"
388387 }
389388
class PredictionTypes:
    """
    Names of the prediction types known to the prediction ML task settings.
    """
    # These strings are compared against the 'predictionType' entry of the
    # raw ML task settings, so they must be kept verbatim.
    BINARY = "BINARY_CLASSIFICATION"
    REGRESSION = "REGRESSION"
    MULTICLASS = "MULTICLASS"
393+
def __init__(self, client, project_key, analysis_id, mltask_id, mltask_settings):
    """
    Build the settings object of a prediction ML task.

    All arguments are forwarded unchanged to DSSMLTaskSettings.__init__.
    The prediction type found in the settings is then checked against
    PredictionTypes.

    :param client: forwarded to DSSMLTaskSettings.__init__
    :param project_key: forwarded to DSSMLTaskSettings.__init__
    :param analysis_id: forwarded to DSSMLTaskSettings.__init__
    :param mltask_id: forwarded to DSSMLTaskSettings.__init__
    :param mltask_settings: forwarded to DSSMLTaskSettings.__init__; its 'predictionType' entry is read through get_prediction_type()
    :raises ValueError: if the prediction type is not one of the PredictionTypes values
    """
    DSSMLTaskSettings.__init__(self, client, project_key, analysis_id, mltask_id, mltask_settings)

    prediction_type = self.get_prediction_type()
    known_prediction_types = [
        self.PredictionTypes.BINARY,
        self.PredictionTypes.REGRESSION,
        self.PredictionTypes.MULTICLASS,
    ]
    if prediction_type not in known_prediction_types:
        # Report the value that was just read. The visible code never sets a
        # `prediction_type` attribute on the instance, so formatting
        # self.prediction_type here would fail before the ValueError is built.
        raise ValueError("Unknown prediction type: {}".format(prediction_type))

    # Prediction types for which class-based weighting is accepted by set_weighting()
    self.classification_prediction_types = [self.PredictionTypes.BINARY, self.PredictionTypes.MULTICLASS]
401+
def get_prediction_type(self):
    """
    Return the value stored under the 'predictionType' key of the ML task settings.
    """
    settings = self.mltask_settings
    return settings['predictionType']
404+
390405 @property
391406 def split_params (self ):
392407 """
@@ -416,7 +431,7 @@ def split_ordered_by(self, feature_name, ascending=True):
416431
417432 :rtype: self
418433 """
419- warnings .warn ("split_ordered_by is deprecated, please use split_params.set_order_by() instead" , DeprecationWarning )
434+ warnings .warn ("split_ordered_by() is deprecated, please use split_params.set_order_by() instead" , DeprecationWarning )
420435 self .split_params .set_order_by (feature_name , ascending = True )
421436
422437 return self
@@ -427,34 +442,77 @@ def remove_ordered_split(self):
427442
428443 :rtype: self
429444 """
430- warnings .warn ("remove_ordered_split is deprecated, please use split_params.unset_order_by() instead" , DeprecationWarning )
445+ warnings .warn ("remove_ordered_split() is deprecated, please use split_params.unset_order_by() instead" , DeprecationWarning )
431446 self .split_params .unset_order_by ()
432447
433448 return self
434449
def use_sample_weighting(self, feature_name):
    """
    Deprecated. use set_weighting()

    Equivalent to set_weighting(method='SAMPLE_WEIGHT', feature_name=feature_name).

    :param str feature_name: Name of the feature to use as sample weight
    :return: the value returned by set_weighting()
    """
    # stacklevel=2 attributes the warning to the code calling this method
    # rather than to this line, so the caller's warning filters apply.
    warnings.warn("use_sample_weighting() is deprecated, please use set_weighting() instead", DeprecationWarning, stacklevel=2)
    return self.set_weighting(method='SAMPLE_WEIGHT', feature_name=feature_name)
456+
def set_weighting(self, method, feature_name=None):
    """
    Set the weighting method of this ML task.

    The arguments are validated before anything is modified, so a rejected
    call leaves the current settings untouched. On success, the previous
    weighting is first removed with unset_weighting(), then the new one is
    applied.

    :param str method: One of "NO_WEIGHTING", "SAMPLE_WEIGHT", "CLASS_WEIGHT" or "CLASS_AND_SAMPLE_WEIGHT"
    :param str feature_name: Name of the feature to use as sample weight. Only used by "SAMPLE_WEIGHT" and "CLASS_AND_SAMPLE_WEIGHT"
    :raises ValueError: if the method is unknown, if a class-based method is requested
        for a prediction type outside classification_prediction_types, or if
        feature_name is not a feature of this ML task

    :rtype: self
    """
    uses_class_weights = method in ("CLASS_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT")
    uses_sample_weights = method in ("SAMPLE_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT")

    if not (uses_class_weights or uses_sample_weights or method == "NO_WEIGHTING"):
        raise ValueError("Unknown weighting method: {}".format(method))

    if uses_class_weights:
        prediction_type = self.get_prediction_type()
        if prediction_type not in self.classification_prediction_types:
            raise ValueError("Weighting method: {} not compatible with prediction type: {}, should be in {}".format(method, prediction_type, self.classification_prediction_types))

    per_feature = self.mltask_settings["preprocessing"]["per_feature"]
    if uses_sample_weights and feature_name not in per_feature:
        raise ValueError("Feature %s doesn't exist in this ML task, can't use as weight" % feature_name)

    # Only now that the request is known to be valid is the previous
    # weighting dropped (features with the WEIGHT role go back to INPUT).
    self.unset_weighting()

    weight_settings = self.mltask_settings['weight']
    weight_settings['weightMethod'] = method
    if uses_sample_weights:
        weight_settings['sampleWeightVariable'] = feature_name
        per_feature[feature_name]['role'] = 'WEIGHT'

    return self
448495
def remove_sample_weighting(self):
    """
    Deprecated. Use unset_weighting() instead

    :return: the value returned by unset_weighting()
    """
    # stacklevel=2 attributes the warning to the code calling this method
    # rather than to this line, so the caller's warning filters apply.
    warnings.warn("remove_sample_weighting() is deprecated, please use unset_weighting() instead", DeprecationWarning, stacklevel=2)
    return self.unset_weighting()
502+
def unset_weighting(self):
    """
    Remove weighting from this ML task.

    The weight method is reset to 'NO_WEIGHTING' and every feature whose
    role is 'WEIGHT' is set back to being an input feature.

    :rtype: self
    """
    self.mltask_settings['weight']['weightMethod'] = 'NO_WEIGHTING'

    per_feature = self.mltask_settings['preprocessing']['per_feature']
    for feature_settings in per_feature.values():
        if feature_settings['role'] == 'WEIGHT':
            feature_settings['role'] = 'INPUT'

    return self
515+
458516
459517class DSSClusteringMLTaskSettings (DSSMLTaskSettings ):
460518 __doc__ = []
0 commit comments