@@ -434,8 +434,8 @@ def test_arima_plus_detect_anomalies_params(
434434 pd .testing .assert_frame_equal (
435435 anomalies [["is_anomaly" , "lower_bound" , "upper_bound" , "anomaly_probability" ]]
436436 .sort_values ("anomaly_probability" )
437- .reset_index (),
438- expected .sort_values ("anomaly_probability" ).reset_index (),
437+ .reset_index (drop = True ),
438+ expected .sort_values ("anomaly_probability" ).reset_index (drop = True ),
439439 rtol = 0.1 ,
440440 check_index_type = False ,
441441 check_dtype = False ,
@@ -459,30 +459,28 @@ def test_arima_plus_score(
459459 )
460460 .to_pandas ()
461461 .sort_values ("id" )
462- .reset_index ()
462+ .reset_index (drop = True )
463463 )
464464 else :
465465 result = time_series_arima_plus_model .score (
466466 new_time_series_df [["parsed_date" ]], new_time_series_df [["total_visits" ]]
467467 ).to_pandas ()
468468 if id_col_name :
469- expected = (
470- pd .DataFrame (
471- {
472- "id" : ["2" , "1" ],
473- "mean_absolute_error" : [120.011007 , 120.011007 ],
474- "mean_squared_error" : [14562.562359 , 14562.562359 ],
475- "root_mean_squared_error" : [120.675442 , 120.675442 ],
476- "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
477- "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
478- },
479- dtype = "Float64" ,
480- )
481- .sort_values ("id" )
482- .reset_index ()
469+ expected = pd .DataFrame (
470+ {
471+ "id" : ["2" , "1" ],
472+ "mean_absolute_error" : [120.011007 , 120.011007 ],
473+ "mean_squared_error" : [14562.562359 , 14562.562359 ],
474+ "root_mean_squared_error" : [120.675442 , 120.675442 ],
475+ "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
476+ "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
477+ },
478+ dtype = "Float64" ,
483479 )
484480 expected ["id" ] = expected ["id" ].astype (str ).str .replace (r"\.0$" , "" , regex = True )
485481 expected ["id" ] = expected ["id" ].astype ("string[pyarrow]" )
482+ expected = expected .sort_values ("id" )
483+ expected = expected .reset_index (drop = True )
486484 else :
487485 expected = pd .DataFrame (
488486 {
@@ -562,30 +560,28 @@ def test_arima_plus_score_series(
562560 )
563561 .to_pandas ()
564562 .sort_values ("id" )
565- .reset_index ()
563+ .reset_index (drop = True )
566564 )
567565 else :
568566 result = time_series_arima_plus_model .score (
569567 new_time_series_df ["parsed_date" ], new_time_series_df ["total_visits" ]
570568 ).to_pandas ()
571569 if id_col_name :
572- expected = (
573- pd .DataFrame (
574- {
575- "id" : ["2" , "1" ],
576- "mean_absolute_error" : [120.011007 , 120.011007 ],
577- "mean_squared_error" : [14562.562359 , 14562.562359 ],
578- "root_mean_squared_error" : [120.675442 , 120.675442 ],
579- "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
580- "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
581- },
582- dtype = "Float64" ,
583- )
584- .sort_values ("id" )
585- .reset_index ()
570+ expected = pd .DataFrame (
571+ {
572+ "id" : ["2" , "1" ],
573+ "mean_absolute_error" : [120.011007 , 120.011007 ],
574+ "mean_squared_error" : [14562.562359 , 14562.562359 ],
575+ "root_mean_squared_error" : [120.675442 , 120.675442 ],
576+ "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
577+ "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
578+ },
579+ dtype = "Float64" ,
586580 )
587581 expected ["id" ] = expected ["id" ].astype (str ).str .replace (r"\.0$" , "" , regex = True )
588582 expected ["id" ] = expected ["id" ].astype ("string[pyarrow]" )
583+ expected = expected .sort_values ("id" )
584+ expected = expected .reset_index (drop = True )
589585 else :
590586 expected = pd .DataFrame (
591587 {
@@ -602,6 +598,7 @@ def test_arima_plus_score_series(
602598 expected ,
603599 rtol = 0.1 ,
604600 check_index_type = False ,
601+ check_dtype = False ,
605602 )
606603
607604
0 commit comments