@@ -451,26 +451,35 @@ def test_arima_plus_score(
451451 id_col_name ,
452452):
453453 if id_col_name :
454- result = time_series_arima_plus_model_w_id .score (
455- new_time_series_df_w_id [["parsed_date" ]],
456- new_time_series_df_w_id [["total_visits" ]],
457- new_time_series_df_w_id [["id" ]],
458- ).to_pandas ()
454+ result = (
455+ time_series_arima_plus_model_w_id .score (
456+ new_time_series_df_w_id [["parsed_date" ]],
457+ new_time_series_df_w_id [["total_visits" ]],
458+ new_time_series_df_w_id [["id" ]],
459+ )
460+ .to_pandas ()
461+ .sort_values ("id" )
462+ .reset_index ()
463+ )
459464 else :
460465 result = time_series_arima_plus_model .score (
461466 new_time_series_df [["parsed_date" ]], new_time_series_df [["total_visits" ]]
462467 ).to_pandas ()
463468 if id_col_name :
464- expected = pd .DataFrame (
465- {
466- "id" : ["2" , "1" ],
467- "mean_absolute_error" : [120.011007 , 120.011007 ],
468- "mean_squared_error" : [14562.562359 , 14562.562359 ],
469- "root_mean_squared_error" : [120.675442 , 120.675442 ],
470- "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
471- "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
472- },
473- dtype = "Float64" ,
469+ expected = (
470+ pd .DataFrame (
471+ {
472+ "id" : ["2" , "1" ],
473+ "mean_absolute_error" : [120.011007 , 120.011007 ],
474+ "mean_squared_error" : [14562.562359 , 14562.562359 ],
475+ "root_mean_squared_error" : [120.675442 , 120.675442 ],
476+ "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
477+ "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
478+ },
479+ dtype = "Float64" ,
480+ )
481+ .sort_values ("id" )
482+ .reset_index ()
474483 )
475484 expected ["id" ] = expected ["id" ].astype (str ).str .replace (r"\.0$" , "" , regex = True )
476485 expected ["id" ] = expected ["id" ].astype ("string[pyarrow]" )
@@ -486,8 +495,8 @@ def test_arima_plus_score(
486495 dtype = "Float64" ,
487496 )
488497 pd .testing .assert_frame_equal (
489- result . sort_values ( "id" ). reset_index () ,
490- expected . sort_values ( "id" ). reset_index () ,
498+ result ,
499+ expected ,
491500 rtol = 0.1 ,
492501 check_index_type = False ,
493502 check_dtype = False ,
@@ -545,26 +554,35 @@ def test_arima_plus_score_series(
545554 id_col_name ,
546555):
547556 if id_col_name :
548- result = time_series_arima_plus_model_w_id .score (
549- new_time_series_df_w_id ["parsed_date" ],
550- new_time_series_df_w_id ["total_visits" ],
551- new_time_series_df_w_id ["id" ],
552- ).to_pandas ()
557+ result = (
558+ time_series_arima_plus_model_w_id .score (
559+ new_time_series_df_w_id ["parsed_date" ],
560+ new_time_series_df_w_id ["total_visits" ],
561+ new_time_series_df_w_id ["id" ],
562+ )
563+ .to_pandas ()
564+ .sort_values ("id" )
565+ .reset_index ()
566+ )
553567 else :
554568 result = time_series_arima_plus_model .score (
555569 new_time_series_df ["parsed_date" ], new_time_series_df ["total_visits" ]
556570 ).to_pandas ()
557571 if id_col_name :
558- expected = pd .DataFrame (
559- {
560- "id" : ["2" , "1" ],
561- "mean_absolute_error" : [120.011007 , 120.011007 ],
562- "mean_squared_error" : [14562.562359 , 14562.562359 ],
563- "root_mean_squared_error" : [120.675442 , 120.675442 ],
564- "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
565- "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
566- },
567- dtype = "Float64" ,
572+ expected = (
573+ pd .DataFrame (
574+ {
575+ "id" : ["2" , "1" ],
576+ "mean_absolute_error" : [120.011007 , 120.011007 ],
577+ "mean_squared_error" : [14562.562359 , 14562.562359 ],
578+ "root_mean_squared_error" : [120.675442 , 120.675442 ],
579+ "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
580+ "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
581+ },
582+ dtype = "Float64" ,
583+ )
584+ .sort_values ("id" )
585+ .reset_index ()
568586 )
569587 expected ["id" ] = expected ["id" ].astype (str ).str .replace (r"\.0$" , "" , regex = True )
570588 expected ["id" ] = expected ["id" ].astype ("string[pyarrow]" )
@@ -580,11 +598,10 @@ def test_arima_plus_score_series(
580598 dtype = "Float64" ,
581599 )
582600 pd .testing .assert_frame_equal (
583- result . sort_values ( "id" ). reset_index () ,
584- expected . sort_values ( "id" ). reset_index () ,
601+ result ,
602+ expected ,
585603 rtol = 0.1 ,
586604 check_index_type = False ,
587- check_dtype = False ,
588605 )
589606
590607
0 commit comments