|
19 | 19 | from __future__ import annotations |
20 | 20 |
|
21 | 21 | import json |
22 | | -from typing import Any, List, Literal, Mapping, Tuple, Union |
| 22 | +from typing import Any, Iterable, List, Literal, Mapping, Tuple, Union |
23 | 23 |
|
24 | 24 | import pandas as pd |
25 | 25 |
|
26 | | -from bigframes import clients, dtypes, series, session |
| 26 | +from bigframes import clients, dataframe, dtypes |
| 27 | +from bigframes import pandas as bpd |
| 28 | +from bigframes import series, session |
27 | 29 | from bigframes.core import convert, log_adapter |
| 30 | +from bigframes.ml import core as ml_core |
28 | 31 | from bigframes.operations import ai_ops, output_schemas |
29 | 32 |
|
30 | 33 | PROMPT_TYPE = Union[ |
@@ -548,6 +551,91 @@ def score( |
548 | 551 | return series_list[0]._apply_nary_op(operator, series_list[1:]) |
549 | 552 |
|
550 | 553 |
|
| 554 | +@log_adapter.method_logger(custom_base_name="bigquery_ai") |
| 555 | +def forecast( |
| 556 | + df: dataframe.DataFrame | pd.DataFrame, |
| 557 | + *, |
| 558 | + data_col: str, |
| 559 | + timestamp_col: str, |
| 560 | + model: str = "TimesFM 2.0", |
| 561 | + id_cols: Iterable[str] | None = None, |
| 562 | + horizon: int = 10, |
| 563 | + confidence_level: float = 0.95, |
| 564 | + context_window: int | None = None, |
| 565 | +) -> dataframe.DataFrame: |
| 566 | + """ |
| 567 | + Forecast time series at future horizon. Using Google Research's open source TimesFM(https://github.com/google-research/timesfm) model. |
| 568 | +
|
| 569 | + .. note:: |
| 570 | +
|
| 571 | + This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the |
| 572 | + Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is" |
| 573 | + and might have limited support. For more information, see the launch stage descriptions |
| 574 | + (https://cloud.google.com/products#product-launch-stages). |
| 575 | +
|
| 576 | + Args: |
| 577 | + df (DataFrame): |
| 578 | + The dataframe that contains the data that you want to forecast. It could be either a BigFrames Dataframe or |
| 579 | + a pandas DataFrame. If it's a pandas DataFrame, the global BigQuery session will be used to load the data. |
| 580 | + data_col (str): |
| 581 | + A str value that specifies the name of the data column. The data column contains the data to forecast. |
| 582 | + The data column must use one of the following data types: INT64, NUMERIC and FLOAT64 |
| 583 | + timestamp_col (str): |
| 584 | + A str value that specified the name of the time points column. |
| 585 | + The time points column provides the time points used to generate the forecast. |
| 586 | + The time points column must use one of the following data types: TIMESTAMP, DATE and DATETIME |
| 587 | + model (str, default "TimesFM 2.0"): |
| 588 | + A str value that specifies the name of the model. TimesFM 2.0 is the only supported value, and is the default value. |
| 589 | + id_cols (Iterable[str], optional): |
| 590 | + An iterable of str value that specifies the names of one or more ID columns. Each ID identifies a unique time series to forecast. |
| 591 | + Specify one or more values for this argument in order to forecast multiple time series using a single query. |
| 592 | + The columns that you specify must use one of the following data types: STRING, INT64, ARRAY<STRING> and ARRAY<INT64> |
| 593 | + horizon (int, default 10): |
| 594 | + An int value that specifies the number of time points to forecast. The default value is 10. The valid input range is [1, 10,000]. |
| 595 | + confidence_level (float, default 0.95): |
| 596 | + A FLOAT64 value that specifies the percentage of the future values that fall in the prediction interval. |
| 597 | + The default value is 0.95. The valid input range is [0, 1). |
| 598 | + context_window (int, optional): |
| 599 | + An int value that specifies the context window length used by BigQuery ML's built-in TimesFM model. |
| 600 | + The context window length determines how many of the most recent data points from the input time series are use by the model. |
| 601 | + If you don't specify a value, the AI.FORECAST function automatically chooses the smallest possible context window length to use |
| 602 | + that is still large enough to cover the number of time series data points in your input data. |
| 603 | +
|
| 604 | + Returns: |
| 605 | + DataFrame: |
| 606 | + The forecast dataframe matches that of the BigQuery AI.FORECAST function. |
| 607 | + See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-forecast |
| 608 | +
|
| 609 | + Raises: |
| 610 | + ValueError: when any column ID does not exist in the dataframe. |
| 611 | + """ |
| 612 | + |
| 613 | + if isinstance(df, pd.DataFrame): |
| 614 | + # Load the pandas DataFrame with global session |
| 615 | + df = bpd.read_pandas(df) |
| 616 | + |
| 617 | + columns = [timestamp_col, data_col] |
| 618 | + if id_cols: |
| 619 | + columns += id_cols |
| 620 | + for column in columns: |
| 621 | + if column not in df.columns: |
| 622 | + raise ValueError(f"Column `{column}` not found") |
| 623 | + |
| 624 | + options: dict[str, Union[int, float, str, Iterable[str]]] = { |
| 625 | + "data_col": data_col, |
| 626 | + "timestamp_col": timestamp_col, |
| 627 | + "model": model, |
| 628 | + "horizon": horizon, |
| 629 | + "confidence_level": confidence_level, |
| 630 | + } |
| 631 | + if id_cols: |
| 632 | + options["id_cols"] = id_cols |
| 633 | + if context_window: |
| 634 | + options["context_window"] = context_window |
| 635 | + |
| 636 | + return ml_core.BaseBqml(df._session).ai_forecast(input_data=df, options=options) |
| 637 | + |
| 638 | + |
551 | 639 | def _separate_context_and_series( |
552 | 640 | prompt: PROMPT_TYPE, |
553 | 641 | ) -> Tuple[List[str | None], List[series.Series]]: |
|
0 commit comments