|
18 | 18 | "outputs": [], |
19 | 19 | "source": [ |
20 | 20 | "import bigframes.pandas as bpd\n", |
| 21 | + "from bigframes.ml import forecasting\n", |
21 | 22 | "bpd.options.display.repr_mode = \"anywidget\"" |
22 | 23 | ] |
23 | 24 | }, |
|
38 | 39 | "metadata": {}, |
39 | 40 | "outputs": [], |
40 | 41 | "source": [ |
41 | | - "# Load the bikeshare dataset from the public BigQuery repository.\n", |
42 | 42 | "df = bpd.read_gbq(\"bigquery-public-data.san_francisco_bikeshare.bikeshare_trips\")\n", |
43 | | - "\n", |
44 | | - "# Filter the data to focus on a specific time period and user type.\n", |
45 | 43 | "df = df[df[\"start_date\"] >= \"2018-01-01\"]\n", |
46 | 44 | "df = df[df[\"subscriber_type\"] == \"Subscriber\"]\n", |
47 | | - "\n", |
48 | | - "# Resample the data to an hourly frequency by counting the number of trips in each hour.\n", |
49 | | - "df[\"trip_hour\"] = df[\"start_date\"] .dt.floor(\"h\")\n", |
| 45 | + "df[\"trip_hour\"] = df[\"start_date\"].dt.floor(\"h\")\n", |
50 | 46 | "df_grouped = df[[\"trip_hour\", \"trip_id\"]].groupby(\"trip_hour\").count().reset_index()\n", |
51 | 47 | "df_grouped = df_grouped.rename(columns={\"trip_id\": \"num_trips\"})" |
52 | 48 | ] |
|
80 | 76 | "data": { |
81 | 77 | "text/html": [ |
82 | 78 | "✅ Completed. \n", |
83 | | - " Query processed 58.7 MB in 16 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:8904286c-644a-409d-9ba0-fe308a3382bf&page=queryresults\">Job bigframes-dev:US.8904286c-644a-409d-9ba0-fe308a3382bf details</a>]\n", |
| 79 | + " Query processed 58.7 MB in 16 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:b91a9e6f-d00a-44f6-afe1-255e25945a1d&page=queryresults\">Job bigframes-dev:US.b91a9e6f-d00a-44f6-afe1-255e25945a1d details</a>]\n", |
84 | 80 | " " |
85 | 81 | ], |
86 | 82 | "text/plain": [ |
|
94 | 90 | "data": { |
95 | 91 | "text/html": [ |
96 | 92 | "✅ Completed. \n", |
97 | | - " Query processed 7.1 kB in 16 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:f859016e-cf03-4581-b176-e50c559f9380&page=queryresults\">Job bigframes-dev:US.f859016e-cf03-4581-b176-e50c559f9380 details</a>]\n", |
| 93 | + " Query processed 7.1 kB in 9 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:1f3bfd8b-5740-4895-be14-6a8b92a4f3b1&page=queryresults\">Job bigframes-dev:US.1f3bfd8b-5740-4895-be14-6a8b92a4f3b1 details</a>]\n", |
98 | 94 | " " |
99 | 95 | ], |
100 | 96 | "text/plain": [ |
|
135 | 131 | { |
136 | 132 | "data": { |
137 | 133 | "application/vnd.jupyter.widget-view+json": { |
138 | | - "model_id": "f53e2ae48801458ab42facd6ebe10728", |
| 134 | + "model_id": "8a4d64e6cf4844018c1b8593d1d99e05", |
139 | 135 | "version_major": 2, |
140 | 136 | "version_minor": 1 |
141 | 137 | }, |
142 | 138 | "text/plain": [ |
143 | | - "<bigframes.display.anywidget.TableWidget object at 0x7f62f2b13380>" |
| 139 | + "<bigframes.display.anywidget.TableWidget object at 0x7efcefb3bb60>" |
144 | 140 | ] |
145 | 141 | }, |
146 | 142 | "metadata": {}, |
|
159 | 155 | } |
160 | 156 | ], |
161 | 157 | "source": [ |
162 | | - "# Use the TimesFM model to forecast the last 168 hours (one week).\n", |
163 | | - "# The `timestamp_column` specifies the time index of the series.\n", |
164 | | - "# The `data_column` is the value we want to forecast.\n", |
165 | | - "# The `horizon` defines how many steps into the future to predict.\n", |
166 | 158 | "result = df_grouped.head(2842-168).ai.forecast(\n", |
167 | 159 | " timestamp_column=\"trip_hour\",\n", |
168 | 160 | " data_column=\"num_trips\",\n", |
|
191 | 183 | "data": { |
192 | 184 | "text/html": [ |
193 | 185 | "\n", |
194 | | - " Query processed 1.8 MB in 40 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:02689ad8-e003-4911-acfc-be2e5c75652d&page=queryresults\">Job bigframes-dev:US.02689ad8-e003-4911-acfc-be2e5c75652d details</a>]\n", |
| 186 | + " Query processed 1.8 MB in 47 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:36efa98e-2843-4bc9-8225-06875236ef17&page=queryresults\">Job bigframes-dev:US.36efa98e-2843-4bc9-8225-06875236ef17 details</a>]\n", |
195 | 187 | " " |
196 | 188 | ], |
197 | 189 | "text/plain": [ |
|
205 | 197 | "data": { |
206 | 198 | "text/html": [ |
207 | 199 | "✅ Completed. \n", |
208 | | - " Query processed 92.2 kB in a moment of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:9e29715c-0d9e-40db-8ead-d063791e61a5&page=queryresults\">Job bigframes-dev:US.9e29715c-0d9e-40db-8ead-d063791e61a5 details</a>]\n", |
| 200 | + " Query processed 92.2 kB in a moment of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:18805b62-a8e2-4c69-a5bf-97aa19df8095&page=queryresults\">Job bigframes-dev:US.18805b62-a8e2-4c69-a5bf-97aa19df8095 details</a>]\n", |
209 | 201 | " " |
210 | 202 | ], |
211 | 203 | "text/plain": [ |
|
233 | 225 | "data": { |
234 | 226 | "text/html": [ |
235 | 227 | "✅ Completed. \n", |
236 | | - " Query processed 10.8 kB in 16 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:bf966f53-b73a-42af-8cb7-5d2384620e3d&page=queryresults\">Job bigframes-dev:US.bf966f53-b73a-42af-8cb7-5d2384620e3d details</a>]\n", |
| 228 | + " Query processed 10.8 kB in 11 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:a2b15286-6c7f-40a7-8009-045e5e4f3dbf&page=queryresults\">Job bigframes-dev:US.a2b15286-6c7f-40a7-8009-045e5e4f3dbf details</a>]\n", |
237 | 229 | " " |
238 | 230 | ], |
239 | 231 | "text/plain": [ |
|
274 | 266 | { |
275 | 267 | "data": { |
276 | 268 | "application/vnd.jupyter.widget-view+json": { |
277 | | - "model_id": "a5360418703842cf9670f9d6ae77d8ee", |
| 269 | + "model_id": "b3816017ab4440c7bcf258df4a0ceff8", |
278 | 270 | "version_major": 2, |
279 | 271 | "version_minor": 1 |
280 | 272 | }, |
281 | 273 | "text/plain": [ |
282 | | - "<bigframes.display.anywidget.TableWidget object at 0x7f62f011fc50>" |
| 274 | + "<bigframes.display.anywidget.TableWidget object at 0x7efcec227c50>" |
283 | 275 | ] |
284 | 276 | }, |
285 | 277 | "metadata": {}, |
|
298 | 290 | } |
299 | 291 | ], |
300 | 292 | "source": [ |
301 | | - "from bigframes.ml import forecasting\n", |
302 | | - "\n", |
303 | | - "# Create and configure an ARIMAPlus model for hourly data.\n", |
304 | | - "# `auto_arima_max_order` is set to a lower value to reduce the training time.\n", |
305 | | - "# `data_frequency` is set to 'hourly' to match our aggregated data.\n", |
306 | 293 | "model = forecasting.ARIMAPlus(\n", |
307 | 294 | " auto_arima_max_order=5, # Reduce runtime for large datasets\n", |
308 | 295 | " data_frequency=\"hourly\",\n", |
309 | 296 | " horizon=168\n", |
310 | 297 | ")\n", |
311 | | - "\n", |
312 | | - "# Prepare the training data by excluding the last week.\n", |
313 | | - "X = df_grouped.head(2842-168)[[\"trip_hour\"] ]\n", |
314 | | - "y = df_grouped.head(2842-168)[[\"num_trips\"] ]\n", |
315 | | - "\n", |
316 | | - "# Fit the model to the training data.\n", |
| 298 | + "X = df_grouped.head(2842-168)[[\"trip_hour\"]]\n", |
| 299 | + "y = df_grouped.head(2842-168)[[\"num_trips\"]]\n", |
317 | 300 | "model.fit(\n", |
318 | 301 | " X, y\n", |
319 | 302 | ")\n", |
320 | | - "\n", |
321 | | - "# Generate predictions for the specified horizon.\n", |
322 | 303 | "predictions = model.predict(horizon=168, confidence_level=0.95)\n", |
323 | 304 | "predictions" |
324 | 305 | ] |
|
343 | 324 | "data": { |
344 | 325 | "text/html": [ |
345 | 326 | "✅ Completed. \n", |
346 | | - " Query processed 31.7 MB in 9 seconds of slot time.\n", |
| 327 | + " Query processed 31.7 MB in 10 seconds of slot time.\n", |
347 | 328 | " " |
348 | 329 | ], |
349 | 330 | "text/plain": [ |
|
357 | 338 | "data": { |
358 | 339 | "text/html": [ |
359 | 340 | "✅ Completed. \n", |
360 | | - " Query processed 58.8 MB in 9 seconds of slot time.\n", |
| 341 | + " Query processed 58.8 MB in 7 seconds of slot time.\n", |
361 | 342 | " " |
362 | 343 | ], |
363 | 344 | "text/plain": [ |
|
389 | 370 | } |
390 | 371 | ], |
391 | 372 | "source": [ |
392 | | - "# Prepare the TimesFM forecast data.\n", |
393 | | - "timesfm_result = result.sort_values(\"forecast_timestamp\")[ [ \"forecast_timestamp\", \"forecast_value\" ] ]\n", |
| 373 | + "timesfm_result = result.sort_values(\"forecast_timestamp\")[[\"forecast_timestamp\", \"forecast_value\"]]\n", |
394 | 374 | "timesfm_result = timesfm_result.rename(columns={\n", |
395 | 375 | " \"forecast_timestamp\": \"trip_hour\",\n", |
396 | 376 | " \"forecast_value\": \"timesfm_forecast\"\n", |
397 | 377 | "})\n", |
398 | | - "\n", |
399 | | - "# Prepare the ARIMAPlus forecast data.\n", |
400 | | - "arimaplus_result = predictions.sort_values(\"forecast_timestamp\")[ [ \"forecast_timestamp\", \"forecast_value\" ] ]\n", |
| 378 | + "arimaplus_result = predictions.sort_values(\"forecast_timestamp\")[[\"forecast_timestamp\", \"forecast_value\"]]\n", |
401 | 379 | "arimaplus_result = arimaplus_result.rename(columns={\n", |
402 | 380 | " \"forecast_timestamp\": \"trip_hour\",\n", |
403 | 381 | " \"forecast_value\": \"arimaplus_forecast\"\n", |
404 | 382 | "})\n", |
405 | | - "\n", |
406 | | - "# Merge the forecasts with the original data.\n", |
407 | 383 | "df_all = df_grouped.merge(timesfm_result, on=\"trip_hour\", how=\"left\")\n", |
408 | 384 | "df_all = df_all.merge(arimaplus_result, on=\"trip_hour\", how=\"left\")\n", |
409 | | - "\n", |
410 | | - "# Plot the last 4 weeks of data for comparison.\n", |
411 | | - "df_all.tail(672).plot.line( \n", |
412 | | - " x=\"trip_hour\", \n", |
413 | | - " y=[\"num_trips\", \"timesfm_forecast\", \"arimaplus_forecast\"], \n", |
414 | | - " rot=45, \n", |
415 | | - " title=\"Trip Forecasts Comparison\" \n", |
| 385 | + "df_all.tail(672).plot.line(\n", |
| 386 | + " x=\"trip_hour\",\n", |
| 387 | + " y=[\"num_trips\", \"timesfm_forecast\", \"arimaplus_forecast\"],\n", |
| 388 | + " rot=45,\n", |
| 389 | + " title=\"Trip Forecasts Comparison\"\n", |
416 | 390 | ")" |
417 | 391 | ] |
418 | 392 | }, |
|
436 | 410 | "name": "stderr", |
437 | 411 | "output_type": "stream", |
438 | 412 | "text": [ |
439 | | - "/usr/local/google/home/shuowei/src/python-bigquery-dataframes/bigframes/core/log_adapter.py:182: TimeTravelCacheWarning: Reading cached table from 2025-12-12 03:22:23.615364+00:00 to avoid\n", |
| 413 | + "/usr/local/google/home/shuowei/src/python-bigquery-dataframes/bigframes/core/log_adapter.py:182: TimeTravelCacheWarning: Reading cached table from 2025-12-12 03:47:11.144938+00:00 to avoid\n", |
440 | 414 | "incompatibilies with previous reads of this table. To read the latest\n", |
441 | 415 | "version, set `use_cache=False` or close the current session with\n", |
442 | 416 | "Session.close() or bigframes.pandas.close_session().\n", |
443 | 417 | " return method(*args, **kwargs)\n", |
444 | | - "/usr/local/google/home/shuowei/src/python-bigquery-dataframes/bigframes/ml/forecasting.py:239: UserWarning: Converting Date column 'date' to datetime for hourly frequency. This is required because BigQuery ML doesn't support Date type with hourly frequency.\n", |
| 418 | + "/usr/local/google/home/shuowei/src/python-bigquery-dataframes/bigframes/ml/forecasting.py:238: UserWarning: Converting Date column 'date' to datetime for hourly frequency. This is required because BigQuery ML doesn't support Date type with hourly frequency.\n", |
445 | 419 | " warnings.warn(\n" |
446 | 420 | ] |
447 | 421 | }, |
448 | 422 | { |
449 | 423 | "data": { |
450 | 424 | "text/html": [ |
451 | 425 | "\n", |
452 | | - " Query processed 39.4 MB in 2 hours of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:c6f5d199-d64a-495f-b9f7-d6eaef4eb4b6&page=queryresults\">Job bigframes-dev:US.c6f5d199-d64a-495f-b9f7-d6eaef4eb4b6 details</a>]\n", |
| 426 | + " Query processed 39.4 MB in 2 hours of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:16d735c6-c885-447f-b513-5249ee8cb48a&page=queryresults\">Job bigframes-dev:US.16d735c6-c885-447f-b513-5249ee8cb48a details</a>]\n", |
453 | 427 | " " |
454 | 428 | ], |
455 | 429 | "text/plain": [ |
|
463 | 437 | "data": { |
464 | 438 | "text/html": [ |
465 | 439 | "✅ Completed. \n", |
466 | | - " Query processed 32.0 MB in 5 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:83063588-a945-4405-812a-87305cf640fe&page=queryresults\">Job bigframes-dev:US.83063588-a945-4405-812a-87305cf640fe details</a>]\n", |
| 440 | + " Query processed 32.0 MB in 3 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:c4cf8019-bdf6-461a-8753-e7314b837c29&page=queryresults\">Job bigframes-dev:US.c4cf8019-bdf6-461a-8753-e7314b837c29 details</a>]\n", |
467 | 441 | " " |
468 | 442 | ], |
469 | 443 | "text/plain": [ |
|
491 | 465 | "data": { |
492 | 466 | "text/html": [ |
493 | 467 | "✅ Completed. \n", |
494 | | - " Query processed 11.5 kB in 11 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:2c711378-2060-46a4-90b3-183a855463d4&page=queryresults\">Job bigframes-dev:US.2c711378-2060-46a4-90b3-183a855463d4 details</a>]\n", |
| 468 | + " Query processed 11.5 kB in 8 seconds of slot time. [<a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:13663bcd-b3e2-471c-b7e7-50c260c4cfdd&page=queryresults\">Job bigframes-dev:US.13663bcd-b3e2-471c-b7e7-50c260c4cfdd details</a>]\n", |
495 | 469 | " " |
496 | 470 | ], |
497 | 471 | "text/plain": [ |
|
532 | 506 | { |
533 | 507 | "data": { |
534 | 508 | "application/vnd.jupyter.widget-view+json": { |
535 | | - "model_id": "fc8931fb4cd0464ea1ebd790eda7d909", |
| 509 | + "model_id": "5cd08895656f45f5bfe39fd6bb26855f", |
536 | 510 | "version_major": 2, |
537 | 511 | "version_minor": 1 |
538 | 512 | }, |
539 | 513 | "text/plain": [ |
540 | | - "<bigframes.display.anywidget.TableWidget object at 0x7f62f0068550>" |
| 514 | + "<bigframes.display.anywidget.TableWidget object at 0x7efcec160550>" |
541 | 515 | ] |
542 | 516 | }, |
543 | 517 | "metadata": {}, |
|
556 | 530 | } |
557 | 531 | ], |
558 | 532 | "source": [ |
559 | | - "# Filter for specific stations to create a dataset with multiple distinct time series.\n", |
560 | 533 | "df_multi = bpd.read_gbq(\"bigquery-public-data.san_francisco_bikeshare.bikeshare_trips\")\n", |
561 | | - "df_multi = df_multi[df_multi[\"start_station_name\"] .str.contains(\"Market|Powell|Embarcadero\")]\n", |
562 | | - "\n", |
563 | | - "# Group the data by station and date to create a time series for each station.\n", |
| 534 | + "df_multi = df_multi[df_multi[\"start_station_name\"].str.contains(\"Market|Powell|Embarcadero\")]\n", |
564 | 535 | "features = bpd.DataFrame({\n", |
565 | 536 | " \"start_station_name\": df_multi[\"start_station_name\"],\n", |
566 | 537 | " \"num_trips\": df_multi[\"start_date\"],\n", |
567 | | - " \"date\": df_multi[\"start_date\"] .dt.date,\n", |
| 538 | + " \"date\": df_multi[\"start_date\"].dt.date,\n", |
568 | 539 | "})\n", |
569 | 540 | "num_trips = features.groupby(\n", |
570 | | - " [ \"start_station_name\", \"date\" ], as_index=False\n", |
571 | | - " ).count()\n", |
572 | | - "\n", |
573 | | - "# Fit the model, using the 'start_station_name' column to identify each individual time series.\n", |
574 | | - "model.fit (\n", |
| 541 | + " [\"start_station_name\", \"date\"], as_index=False\n", |
| 542 | + ").count()\n", |
| 543 | + "model.fit(\n", |
575 | 544 | " num_trips[[\"date\"]],\n", |
576 | 545 | " num_trips[[\"num_trips\"]],\n", |
577 | | - " id_col=num_trips[[\"start_station_name\"] ]\n", |
| 546 | + " id_col=num_trips[[\"start_station_name\"]]\n", |
578 | 547 | ")\n", |
579 | | - "\n", |
580 | | - "# Predict the future values for each time series.\n", |
581 | 548 | "predictions_multi = model.predict()\n", |
582 | 549 | "predictions_multi" |
583 | 550 | ] |
|
0 commit comments