@@ -74,7 +74,7 @@ def get_estimator(library_name: str, estimator_name: str):
7474def get_estimator_methods (bench_case : BenchCase ) -> Dict [str , List [str ]]:
7575 # default estimator methods
7676 estimator_methods = {
77- "training" : ["fit" ],
77+ "training" : ["partial_fit" , " fit" ],
7878 "inference" : ["predict" , "predict_proba" , "transform" ],
7979 }
8080 for stage in estimator_methods .keys ():
@@ -337,34 +337,43 @@ def verify_patching(stream: io.StringIO, function_name) -> bool:
337337 return acceleration_lines > 0 and fallback_lines == 0
338338
339339
340- def create_online_function (method_instance , data_args , batch_size ):
341- n_batches = data_args [0 ].shape [0 ] // batch_size
340+ def create_online_function (
341+ estimator_instance , method_instance , data_args , num_batches , batch_size
342+ ):
342343
343344 if "y" in list (inspect .signature (method_instance ).parameters ):
344345
345346 def ndarray_function (x , y ):
346- for i in range (n_batches ):
347+ for i in range (num_batches ):
347348 method_instance (
348349 x [i * batch_size : (i + 1 ) * batch_size ],
349350 y [i * batch_size : (i + 1 ) * batch_size ],
350351 )
352+ if hasattr (estimator_instance , "_onedal_finalize_fit" ):
353+ estimator_instance ._onedal_finalize_fit ()
351354
352355 def dataframe_function (x , y ):
353- for i in range (n_batches ):
356+ for i in range (num_batches ):
354357 method_instance (
355358 x .iloc [i * batch_size : (i + 1 ) * batch_size ],
356359 y .iloc [i * batch_size : (i + 1 ) * batch_size ],
357360 )
361+ if hasattr (estimator_instance , "_onedal_finalize_fit" ):
362+ estimator_instance ._onedal_finalize_fit ()
358363
359364 else :
360365
361366 def ndarray_function (x ):
362- for i in range (n_batches ):
367+ for i in range (num_batches ):
363368 method_instance (x [i * batch_size : (i + 1 ) * batch_size ])
369+ if hasattr (estimator_instance , "_onedal_finalize_fit" ):
370+ estimator_instance ._onedal_finalize_fit ()
364371
365372 def dataframe_function (x ):
366- for i in range (n_batches ):
373+ for i in range (num_batches ):
367374 method_instance (x .iloc [i * batch_size : (i + 1 ) * batch_size ])
375+ if hasattr (estimator_instance , "_onedal_finalize_fit" ):
376+ estimator_instance ._onedal_finalize_fit ()
368377
369378 if "ndarray" in str (type (data_args [0 ])):
370379 return ndarray_function
@@ -417,12 +426,28 @@ def measure_sklearn_estimator(
417426 data_args = (x_train ,)
418427 else :
419428 data_args = (x_test ,)
420- batch_size = get_bench_case_value (
421- bench_case , f"algorithm:batch_size:{ stage } "
422- )
423- if batch_size is not None :
429+
430+ if method == "partial_fit" :
431+ num_batches = get_bench_case_value (bench_case , "data:num_batches" )
432+ batch_size = get_bench_case_value (bench_case , "data:batch_size" )
433+
434+ if batch_size is None :
435+ if num_batches is None :
436+ num_batches = 5
437+ batch_size = (
438+ data_args [0 ].shape [0 ] + num_batches - 1
439+ ) // num_batches
440+ if num_batches is None :
441+ num_batches = (
442+ data_args [0 ].shape [0 ] + batch_size - 1
443+ ) // batch_size
444+
424445 method_instance = create_online_function (
425- method_instance , data_args , batch_size
446+ estimator_instance ,
447+ method_instance ,
448+ data_args ,
449+ num_batches ,
450+ batch_size ,
426451 )
427452 # daal4py model builders enabling branch
428453 if enable_modelbuilders and stage == "inference" :
@@ -440,10 +465,6 @@ def measure_sklearn_estimator(
440465 metrics [method ]["box filter mean[ms]" ],
441466 metrics [method ]["box filter std[ms]" ],
442467 ) = measure_case (bench_case , method_instance , * data_args )
443- if batch_size is not None :
444- metrics [method ]["throughput[samples/ms]" ] = (
445- (data_args [0 ].shape [0 ] // batch_size ) * batch_size
446- ) / metrics [method ]["time[ms]" ]
447468 if ensure_sklearnex_patching :
448469 full_method_name = f"{ estimator_class .__name__ } .{ method } "
449470 sklearnex_logging_stream .seek (0 )
0 commit comments