Skip to content

Commit 51b336b

Browse files
committed
minor update
1 parent bb691b0 commit 51b336b

File tree

3 files changed

+39
-73
lines changed

3 files changed

+39
-73
lines changed

bigframes/ml/forecasting.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
from typing import List, Optional
20+
import warnings
2021

2122
from google.cloud import bigquery
2223

@@ -234,8 +235,6 @@ def _fit(
234235
if self.data_frequency in ["hourly", "per_minute"]:
235236
timestamp_col = X.columns[0]
236237
if "date" in X[timestamp_col].dtype.name:
237-
import warnings
238-
239238
warnings.warn(
240239
f"Converting Date column '{timestamp_col}' to datetime for "
241240
f"{self.data_frequency} frequency. This is required because "

notebooks/ml/timeseries_analysis.ipynb

Lines changed: 35 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"outputs": [],
1919
"source": [
2020
"import bigframes.pandas as bpd\n",
21+
"from bigframes.ml import forecasting\n",
2122
"bpd.options.display.repr_mode = \"anywidget\""
2223
]
2324
},
@@ -38,15 +39,10 @@
3839
"metadata": {},
3940
"outputs": [],
4041
"source": [
41-
"# Load the bikeshare dataset from the public BigQuery repository.\n",
4242
"df = bpd.read_gbq(\"bigquery-public-data.san_francisco_bikeshare.bikeshare_trips\")\n",
43-
"\n",
44-
"# Filter the data to focus on a specific time period and user type.\n",
4543
"df = df[df[\"start_date\"] >= \"2018-01-01\"]\n",
4644
"df = df[df[\"subscriber_type\"] == \"Subscriber\"]\n",
47-
"\n",
48-
"# Resample the data to an hourly frequency by counting the number of trips in each hour.\n",
49-
"df[\"trip_hour\"] = df[\"start_date\"] .dt.floor(\"h\")\n",
45+
"df[\"trip_hour\"] = df[\"start_date\"].dt.floor(\"h\")\n",
5046
"df_grouped = df[[\"trip_hour\", \"trip_id\"]].groupby(\"trip_hour\").count().reset_index()\n",
5147
"df_grouped = df_grouped.rename(columns={\"trip_id\": \"num_trips\"})"
5248
]
@@ -80,7 +76,7 @@
8076
"data": {
8177
"text/html": [
8278
"✅ Completed. \n",
83-
" Query processed 58.7 MB in 16 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:8904286c-644a-409d-9ba0-fe308a3382bf&page=queryresults\">Job bigframes-dev:US.8904286c-644a-409d-9ba0-fe308a3382bf details</a>]\n",
79+
" Query processed 58.7 MB in 16 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:b91a9e6f-d00a-44f6-afe1-255e25945a1d&page=queryresults\">Job bigframes-dev:US.b91a9e6f-d00a-44f6-afe1-255e25945a1d details</a>]\n",
8480
" "
8581
],
8682
"text/plain": [
@@ -94,7 +90,7 @@
9490
"data": {
9591
"text/html": [
9692
"✅ Completed. \n",
97-
" Query processed 7.1 kB in 16 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:f859016e-cf03-4581-b176-e50c559f9380&page=queryresults\">Job bigframes-dev:US.f859016e-cf03-4581-b176-e50c559f9380 details</a>]\n",
93+
" Query processed 7.1 kB in 9 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:1f3bfd8b-5740-4895-be14-6a8b92a4f3b1&page=queryresults\">Job bigframes-dev:US.1f3bfd8b-5740-4895-be14-6a8b92a4f3b1 details</a>]\n",
9894
" "
9995
],
10096
"text/plain": [
@@ -135,12 +131,12 @@
135131
{
136132
"data": {
137133
"application/vnd.jupyter.widget-view+json": {
138-
"model_id": "f53e2ae48801458ab42facd6ebe10728",
134+
"model_id": "8a4d64e6cf4844018c1b8593d1d99e05",
139135
"version_major": 2,
140136
"version_minor": 1
141137
},
142138
"text/plain": [
143-
"<bigframes.display.anywidget.TableWidget object at 0x7f62f2b13380>"
139+
"<bigframes.display.anywidget.TableWidget object at 0x7efcefb3bb60>"
144140
]
145141
},
146142
"metadata": {},
@@ -159,10 +155,6 @@
159155
}
160156
],
161157
"source": [
162-
"# Use the TimesFM model to forecast the last 168 hours (one week).\n",
163-
"# The `timestamp_column` specifies the time index of the series.\n",
164-
"# The `data_column` is the value we want to forecast.\n",
165-
"# The `horizon` defines how many steps into the future to predict.\n",
166158
"result = df_grouped.head(2842-168).ai.forecast(\n",
167159
" timestamp_column=\"trip_hour\",\n",
168160
" data_column=\"num_trips\",\n",
@@ -191,7 +183,7 @@
191183
"data": {
192184
"text/html": [
193185
"\n",
194-
" Query processed 1.8 MB in 40 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:02689ad8-e003-4911-acfc-be2e5c75652d&page=queryresults\">Job bigframes-dev:US.02689ad8-e003-4911-acfc-be2e5c75652d details</a>]\n",
186+
" Query processed 1.8 MB in 47 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:36efa98e-2843-4bc9-8225-06875236ef17&page=queryresults\">Job bigframes-dev:US.36efa98e-2843-4bc9-8225-06875236ef17 details</a>]\n",
195187
" "
196188
],
197189
"text/plain": [
@@ -205,7 +197,7 @@
205197
"data": {
206198
"text/html": [
207199
"✅ Completed. \n",
208-
" Query processed 92.2 kB in a moment of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:9e29715c-0d9e-40db-8ead-d063791e61a5&page=queryresults\">Job bigframes-dev:US.9e29715c-0d9e-40db-8ead-d063791e61a5 details</a>]\n",
200+
" Query processed 92.2 kB in a moment of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:18805b62-a8e2-4c69-a5bf-97aa19df8095&page=queryresults\">Job bigframes-dev:US.18805b62-a8e2-4c69-a5bf-97aa19df8095 details</a>]\n",
209201
" "
210202
],
211203
"text/plain": [
@@ -233,7 +225,7 @@
233225
"data": {
234226
"text/html": [
235227
"✅ Completed. \n",
236-
" Query processed 10.8 kB in 16 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:bf966f53-b73a-42af-8cb7-5d2384620e3d&page=queryresults\">Job bigframes-dev:US.bf966f53-b73a-42af-8cb7-5d2384620e3d details</a>]\n",
228+
" Query processed 10.8 kB in 11 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:a2b15286-6c7f-40a7-8009-045e5e4f3dbf&page=queryresults\">Job bigframes-dev:US.a2b15286-6c7f-40a7-8009-045e5e4f3dbf details</a>]\n",
237229
" "
238230
],
239231
"text/plain": [
@@ -274,12 +266,12 @@
274266
{
275267
"data": {
276268
"application/vnd.jupyter.widget-view+json": {
277-
"model_id": "a5360418703842cf9670f9d6ae77d8ee",
269+
"model_id": "b3816017ab4440c7bcf258df4a0ceff8",
278270
"version_major": 2,
279271
"version_minor": 1
280272
},
281273
"text/plain": [
282-
"<bigframes.display.anywidget.TableWidget object at 0x7f62f011fc50>"
274+
"<bigframes.display.anywidget.TableWidget object at 0x7efcec227c50>"
283275
]
284276
},
285277
"metadata": {},
@@ -298,27 +290,16 @@
298290
}
299291
],
300292
"source": [
301-
"from bigframes.ml import forecasting\n",
302-
"\n",
303-
"# Create and configure an ARIMAPlus model for hourly data.\n",
304-
"# `auto_arima_max_order` is set to a lower value to reduce the training time.\n",
305-
"# `data_frequency` is set to 'hourly' to match our aggregated data.\n",
306293
"model = forecasting.ARIMAPlus(\n",
307294
" auto_arima_max_order=5, # Reduce runtime for large datasets\n",
308295
" data_frequency=\"hourly\",\n",
309296
" horizon=168\n",
310297
")\n",
311-
"\n",
312-
"# Prepare the training data by excluding the last week.\n",
313-
"X = df_grouped.head(2842-168)[[\"trip_hour\"] ]\n",
314-
"y = df_grouped.head(2842-168)[[\"num_trips\"] ]\n",
315-
"\n",
316-
"# Fit the model to the training data.\n",
298+
"X = df_grouped.head(2842-168)[[\"trip_hour\"]]\n",
299+
"y = df_grouped.head(2842-168)[[\"num_trips\"]]\n",
317300
"model.fit(\n",
318301
" X, y\n",
319302
")\n",
320-
"\n",
321-
"# Generate predictions for the specified horizon.\n",
322303
"predictions = model.predict(horizon=168, confidence_level=0.95)\n",
323304
"predictions"
324305
]
@@ -343,7 +324,7 @@
343324
"data": {
344325
"text/html": [
345326
"✅ Completed. \n",
346-
" Query processed 31.7 MB in 9 seconds of slot time.\n",
327+
" Query processed 31.7 MB in 10 seconds of slot time.\n",
347328
" "
348329
],
349330
"text/plain": [
@@ -357,7 +338,7 @@
357338
"data": {
358339
"text/html": [
359340
"✅ Completed. \n",
360-
" Query processed 58.8 MB in 9 seconds of slot time.\n",
341+
" Query processed 58.8 MB in 7 seconds of slot time.\n",
361342
" "
362343
],
363344
"text/plain": [
@@ -389,30 +370,23 @@
389370
}
390371
],
391372
"source": [
392-
"# Prepare the TimesFM forecast data.\n",
393-
"timesfm_result = result.sort_values(\"forecast_timestamp\")[ [ \"forecast_timestamp\", \"forecast_value\" ] ]\n",
373+
"timesfm_result = result.sort_values(\"forecast_timestamp\")[[\"forecast_timestamp\", \"forecast_value\"]]\n",
394374
"timesfm_result = timesfm_result.rename(columns={\n",
395375
" \"forecast_timestamp\": \"trip_hour\",\n",
396376
" \"forecast_value\": \"timesfm_forecast\"\n",
397377
"})\n",
398-
"\n",
399-
"# Prepare the ARIMAPlus forecast data.\n",
400-
"arimaplus_result = predictions.sort_values(\"forecast_timestamp\")[ [ \"forecast_timestamp\", \"forecast_value\" ] ]\n",
378+
"arimaplus_result = predictions.sort_values(\"forecast_timestamp\")[[\"forecast_timestamp\", \"forecast_value\"]]\n",
401379
"arimaplus_result = arimaplus_result.rename(columns={\n",
402380
" \"forecast_timestamp\": \"trip_hour\",\n",
403381
" \"forecast_value\": \"arimaplus_forecast\"\n",
404382
"})\n",
405-
"\n",
406-
"# Merge the forecasts with the original data.\n",
407383
"df_all = df_grouped.merge(timesfm_result, on=\"trip_hour\", how=\"left\")\n",
408384
"df_all = df_all.merge(arimaplus_result, on=\"trip_hour\", how=\"left\")\n",
409-
"\n",
410-
"# Plot the last 4 weeks of data for comparison.\n",
411-
"df_all.tail(672).plot.line( \n",
412-
" x=\"trip_hour\", \n",
413-
" y=[\"num_trips\", \"timesfm_forecast\", \"arimaplus_forecast\"], \n",
414-
" rot=45, \n",
415-
" title=\"Trip Forecasts Comparison\" \n",
385+
"df_all.tail(672).plot.line(\n",
386+
" x=\"trip_hour\",\n",
387+
" y=[\"num_trips\", \"timesfm_forecast\", \"arimaplus_forecast\"],\n",
388+
" rot=45,\n",
389+
" title=\"Trip Forecasts Comparison\"\n",
416390
")"
417391
]
418392
},
@@ -436,20 +410,20 @@
436410
"name": "stderr",
437411
"output_type": "stream",
438412
"text": [
439-
"/usr/local/google/home/shuowei/src/python-bigquery-dataframes/bigframes/core/log_adapter.py:182: TimeTravelCacheWarning: Reading cached table from 2025-12-12 03:22:23.615364+00:00 to avoid\n",
413+
"/usr/local/google/home/shuowei/src/python-bigquery-dataframes/bigframes/core/log_adapter.py:182: TimeTravelCacheWarning: Reading cached table from 2025-12-12 03:47:11.144938+00:00 to avoid\n",
440414
"incompatibilies with previous reads of this table. To read the latest\n",
441415
"version, set `use_cache=False` or close the current session with\n",
442416
"Session.close() or bigframes.pandas.close_session().\n",
443417
" return method(*args, **kwargs)\n",
444-
"/usr/local/google/home/shuowei/src/python-bigquery-dataframes/bigframes/ml/forecasting.py:239: UserWarning: Converting Date column 'date' to datetime for hourly frequency. This is required because BigQuery ML doesn't support Date type with hourly frequency.\n",
418+
"/usr/local/google/home/shuowei/src/python-bigquery-dataframes/bigframes/ml/forecasting.py:238: UserWarning: Converting Date column 'date' to datetime for hourly frequency. This is required because BigQuery ML doesn't support Date type with hourly frequency.\n",
445419
" warnings.warn(\n"
446420
]
447421
},
448422
{
449423
"data": {
450424
"text/html": [
451425
"\n",
452-
" Query processed 39.4 MB in 2 hours of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:c6f5d199-d64a-495f-b9f7-d6eaef4eb4b6&page=queryresults\">Job bigframes-dev:US.c6f5d199-d64a-495f-b9f7-d6eaef4eb4b6 details</a>]\n",
426+
" Query processed 39.4 MB in 2 hours of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:16d735c6-c885-447f-b513-5249ee8cb48a&page=queryresults\">Job bigframes-dev:US.16d735c6-c885-447f-b513-5249ee8cb48a details</a>]\n",
453427
" "
454428
],
455429
"text/plain": [
@@ -463,7 +437,7 @@
463437
"data": {
464438
"text/html": [
465439
"✅ Completed. \n",
466-
" Query processed 32.0 MB in 5 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:83063588-a945-4405-812a-87305cf640fe&page=queryresults\">Job bigframes-dev:US.83063588-a945-4405-812a-87305cf640fe details</a>]\n",
440+
" Query processed 32.0 MB in 3 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:c4cf8019-bdf6-461a-8753-e7314b837c29&page=queryresults\">Job bigframes-dev:US.c4cf8019-bdf6-461a-8753-e7314b837c29 details</a>]\n",
467441
" "
468442
],
469443
"text/plain": [
@@ -491,7 +465,7 @@
491465
"data": {
492466
"text/html": [
493467
"✅ Completed. \n",
494-
" Query processed 11.5 kB in 11 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:2c711378-2060-46a4-90b3-183a855463d4&page=queryresults\">Job bigframes-dev:US.2c711378-2060-46a4-90b3-183a855463d4 details</a>]\n",
468+
" Query processed 11.5 kB in 8 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:13663bcd-b3e2-471c-b7e7-50c260c4cfdd&page=queryresults\">Job bigframes-dev:US.13663bcd-b3e2-471c-b7e7-50c260c4cfdd details</a>]\n",
495469
" "
496470
],
497471
"text/plain": [
@@ -532,12 +506,12 @@
532506
{
533507
"data": {
534508
"application/vnd.jupyter.widget-view+json": {
535-
"model_id": "fc8931fb4cd0464ea1ebd790eda7d909",
509+
"model_id": "5cd08895656f45f5bfe39fd6bb26855f",
536510
"version_major": 2,
537511
"version_minor": 1
538512
},
539513
"text/plain": [
540-
"<bigframes.display.anywidget.TableWidget object at 0x7f62f0068550>"
514+
"<bigframes.display.anywidget.TableWidget object at 0x7efcec160550>"
541515
]
542516
},
543517
"metadata": {},
@@ -556,28 +530,21 @@
556530
}
557531
],
558532
"source": [
559-
"# Filter for specific stations to create a dataset with multiple distinct time series.\n",
560533
"df_multi = bpd.read_gbq(\"bigquery-public-data.san_francisco_bikeshare.bikeshare_trips\")\n",
561-
"df_multi = df_multi[df_multi[\"start_station_name\"] .str.contains(\"Market|Powell|Embarcadero\")]\n",
562-
"\n",
563-
"# Group the data by station and date to create a time series for each station.\n",
534+
"df_multi = df_multi[df_multi[\"start_station_name\"].str.contains(\"Market|Powell|Embarcadero\")]\n",
564535
"features = bpd.DataFrame({\n",
565536
" \"start_station_name\": df_multi[\"start_station_name\"],\n",
566537
" \"num_trips\": df_multi[\"start_date\"],\n",
567-
" \"date\": df_multi[\"start_date\"] .dt.date,\n",
538+
" \"date\": df_multi[\"start_date\"].dt.date,\n",
568539
"})\n",
569540
"num_trips = features.groupby(\n",
570-
" [ \"start_station_name\", \"date\" ], as_index=False\n",
571-
" ).count()\n",
572-
"\n",
573-
"# Fit the model, using the 'start_station_name' column to identify each individual time series.\n",
574-
"model.fit (\n",
541+
" [\"start_station_name\", \"date\"], as_index=False\n",
542+
").count()\n",
543+
"model.fit(\n",
575544
" num_trips[[\"date\"]],\n",
576545
" num_trips[[\"num_trips\"]],\n",
577-
" id_col=num_trips[[\"start_station_name\"] ]\n",
546+
" id_col=num_trips[[\"start_station_name\"]]\n",
578547
")\n",
579-
"\n",
580-
"# Predict the future values for each time series.\n",
581548
"predictions_multi = model.predict()\n",
582549
"predictions_multi"
583550
]

tests/system/large/ml/test_forecasting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from bigframes.ml import forecasting
1818
from bigframes.testing import utils
1919

20-
ARIMA_EVALUATE_OUTPUT_COL = [
20+
ARIMA_EVALUATE_OUTPUT_COLUMNS = [
2121
"non_seasonal_p",
2222
"non_seasonal_d",
2323
"non_seasonal_q",
@@ -106,9 +106,9 @@ def test_arima_plus_model_fit_summary(
106106
curr_model = arima_model_w_id if id_col_name else arima_model
107107
result = curr_model.summary().to_pandas()
108108
expected_columns = (
109-
[id_col_name] + ARIMA_EVALUATE_OUTPUT_COL
109+
[id_col_name] + ARIMA_EVALUATE_OUTPUT_COLUMNS
110110
if id_col_name
111-
else ARIMA_EVALUATE_OUTPUT_COL
111+
else ARIMA_EVALUATE_OUTPUT_COLUMNS
112112
)
113113
utils.check_pandas_df_schema_and_index(
114114
result, columns=expected_columns, index=2 if id_col_name else 1

0 commit comments

Comments
 (0)