Skip to content

Commit 0c8b88c

Browse files
committed
and again
1 parent 2ca9e9e commit 0c8b88c

File tree

1 file changed

+29
-32
lines changed

1 file changed

+29
-32
lines changed

tests/system/small/ml/test_forecasting.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,8 @@ def test_arima_plus_detect_anomalies_params(
434434
pd.testing.assert_frame_equal(
435435
anomalies[["is_anomaly", "lower_bound", "upper_bound", "anomaly_probability"]]
436436
.sort_values("anomaly_probability")
437-
.reset_index(),
438-
expected.sort_values("anomaly_probability").reset_index(),
437+
.reset_index(drop=True),
438+
expected.sort_values("anomaly_probability").reset_index(drop=True),
439439
rtol=0.1,
440440
check_index_type=False,
441441
check_dtype=False,
@@ -459,30 +459,28 @@ def test_arima_plus_score(
459459
)
460460
.to_pandas()
461461
.sort_values("id")
462-
.reset_index()
462+
.reset_index(drop=True)
463463
)
464464
else:
465465
result = time_series_arima_plus_model.score(
466466
new_time_series_df[["parsed_date"]], new_time_series_df[["total_visits"]]
467467
).to_pandas()
468468
if id_col_name:
469-
expected = (
470-
pd.DataFrame(
471-
{
472-
"id": ["2", "1"],
473-
"mean_absolute_error": [120.011007, 120.011007],
474-
"mean_squared_error": [14562.562359, 14562.562359],
475-
"root_mean_squared_error": [120.675442, 120.675442],
476-
"mean_absolute_percentage_error": [4.80044, 4.80044],
477-
"symmetric_mean_absolute_percentage_error": [4.744332, 4.744332],
478-
},
479-
dtype="Float64",
480-
)
481-
.sort_values("id")
482-
.reset_index()
469+
expected = pd.DataFrame(
470+
{
471+
"id": ["2", "1"],
472+
"mean_absolute_error": [120.011007, 120.011007],
473+
"mean_squared_error": [14562.562359, 14562.562359],
474+
"root_mean_squared_error": [120.675442, 120.675442],
475+
"mean_absolute_percentage_error": [4.80044, 4.80044],
476+
"symmetric_mean_absolute_percentage_error": [4.744332, 4.744332],
477+
},
478+
dtype="Float64",
483479
)
484480
expected["id"] = expected["id"].astype(str).str.replace(r"\.0$", "", regex=True)
485481
expected["id"] = expected["id"].astype("string[pyarrow]")
482+
expected = expected.sort_values("id")
483+
expected = expected.reset_index(drop=True)
486484
else:
487485
expected = pd.DataFrame(
488486
{
@@ -562,30 +560,28 @@ def test_arima_plus_score_series(
562560
)
563561
.to_pandas()
564562
.sort_values("id")
565-
.reset_index()
563+
.reset_index(drop=True)
566564
)
567565
else:
568566
result = time_series_arima_plus_model.score(
569567
new_time_series_df["parsed_date"], new_time_series_df["total_visits"]
570568
).to_pandas()
571569
if id_col_name:
572-
expected = (
573-
pd.DataFrame(
574-
{
575-
"id": ["2", "1"],
576-
"mean_absolute_error": [120.011007, 120.011007],
577-
"mean_squared_error": [14562.562359, 14562.562359],
578-
"root_mean_squared_error": [120.675442, 120.675442],
579-
"mean_absolute_percentage_error": [4.80044, 4.80044],
580-
"symmetric_mean_absolute_percentage_error": [4.744332, 4.744332],
581-
},
582-
dtype="Float64",
583-
)
584-
.sort_values("id")
585-
.reset_index()
570+
expected = pd.DataFrame(
571+
{
572+
"id": ["2", "1"],
573+
"mean_absolute_error": [120.011007, 120.011007],
574+
"mean_squared_error": [14562.562359, 14562.562359],
575+
"root_mean_squared_error": [120.675442, 120.675442],
576+
"mean_absolute_percentage_error": [4.80044, 4.80044],
577+
"symmetric_mean_absolute_percentage_error": [4.744332, 4.744332],
578+
},
579+
dtype="Float64",
586580
)
587581
expected["id"] = expected["id"].astype(str).str.replace(r"\.0$", "", regex=True)
588582
expected["id"] = expected["id"].astype("string[pyarrow]")
583+
expected = expected.sort_values("id")
584+
expected = expected.reset_index(drop=True)
589585
else:
590586
expected = pd.DataFrame(
591587
{
@@ -602,6 +598,7 @@ def test_arima_plus_score_series(
602598
expected,
603599
rtol=0.1,
604600
check_index_type=False,
601+
check_dtype=False,
605602
)
606603

607604

0 commit comments

Comments
 (0)