From cdec057f9c7dca31b608ff5b938730c99330ea40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 24 Jul 2025 16:25:37 +0000 Subject: [PATCH 01/21] feat: add `allow_large_results` option to `read_gbq_query` --- bigframes/session/__init__.py | 89 +++++++++++++++++-- .../bigframes_vendored/pandas/io/gbq.py | 6 ++ 2 files changed, 89 insertions(+), 6 deletions(-) diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 2c9dea2d19..127a9ca548 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -394,6 +394,7 @@ def read_gbq( # type: ignore[overload-overlap] use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[False] = ..., + allow_large_results: bool = ..., ) -> dataframe.DataFrame: ... @@ -410,6 +411,7 @@ def read_gbq( use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[True] = ..., + allow_large_results: bool = ..., ) -> pandas.Series: ... @@ -424,8 +426,8 @@ def read_gbq( filters: third_party_pandas_gbq.FiltersType = (), use_cache: Optional[bool] = None, col_order: Iterable[str] = (), - dry_run: bool = False - # Add a verify index argument that fails if the index is not unique. + dry_run: bool = False, + allow_large_results: bool = True, ) -> dataframe.DataFrame | pandas.Series: # TODO(b/281571214): Generate prompt to show the progress of read_gbq. if columns and col_order: @@ -445,6 +447,7 @@ def read_gbq( use_cache=use_cache, filters=filters, dry_run=dry_run, + allow_large_results=allow_large_results, ) else: if configuration is not None: @@ -551,6 +554,7 @@ def read_gbq_query( # type: ignore[overload-overlap] col_order: Iterable[str] = ..., filters: third_party_pandas_gbq.FiltersType = ..., dry_run: Literal[False] = ..., + allow_large_results: bool = ..., ) -> dataframe.DataFrame: ... @@ -567,6 +571,7 @@ def read_gbq_query( col_order: Iterable[str] = ..., filters: third_party_pandas_gbq.FiltersType = ..., dry_run: Literal[True] = ..., + allow_large_results: bool = ..., ) -> pandas.Series: ... @@ -582,6 +587,7 @@ def read_gbq_query( col_order: Iterable[str] = (), filters: third_party_pandas_gbq.FiltersType = (), dry_run: bool = False, + allow_large_results: bool = True, ) -> dataframe.DataFrame | pandas.Series: """Turn a SQL query into a DataFrame. @@ -631,9 +637,48 @@ def read_gbq_query( See also: :meth:`Session.read_gbq`. + Args: + query (str): + A SQL query to execute. + index_col (Iterable[str] or str, optional): + The column(s) to use as the index for the DataFrame. This can be + a single column name or a list of column names. If not provided, + a default index will be used. + columns (Iterable[str], optional): + The columns to read from the query result. If not + specified, all columns will be read. + configuration (dict, optional): + A dictionary of query job configuration options. See the + BigQuery REST API documentation for a list of available options: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query + max_results (int, optional): + The maximum number of rows to retrieve from the query + result. If not specified, all rows will be loaded. + use_cache (bool, optional): + Whether to use cached results for the query. Defaults to ``True``. + Setting this to ``False`` will force a re-execution of the query. + col_order (Iterable[str], optional): + The desired order of columns in the resulting DataFrame. This + parameter is deprecated and will be removed in a future version. + Use ``columns`` instead. 
+ filters (list[tuple], optional): + A list of filters to apply to the data. Filters are specified + as a list of tuples, where each tuple contains a column name, + an operator (e.g., '==', '!='), and a value. + dry_run (bool, optional): + If ``True``, the function will not actually execute the query but + will instead return statistics about the query. Defaults to + ``False``. + allow_large_results (bool, optional): + Whether to allow large query results. If ``True``, the query + results can be larger than the maximum response size. + Defaults to ``True``. + Returns: - bigframes.pandas.DataFrame: - A DataFrame representing results of the query or table. + bigframes.pandas.DataFrame or pandas.Series: + A DataFrame representing the result of the query. If ``dry_run`` + is ``True``, a ``pandas.Series`` containing query statistics is + returned. Raises: ValueError: @@ -657,6 +702,7 @@ def read_gbq_query( use_cache=use_cache, filters=filters, dry_run=dry_run, + allow_large_results=allow_large_results, ) @overload @@ -714,9 +760,40 @@ def read_gbq_table( See also: :meth:`Session.read_gbq`. + Args: + table_id (str): + The identifier of the BigQuery table to read. + index_col (Iterable[str] or str, optional): + The column(s) to use as the index for the DataFrame. This can be + a single column name or a list of column names. If not provided, + a default index will be used. + columns (Iterable[str], optional): + The columns to read from the table. If not specified, all + columns will be read. + max_results (int, optional): + The maximum number of rows to retrieve from the table. If not + specified, all rows will be loaded. + filters (list[tuple], optional): + A list of filters to apply to the data. Filters are specified + as a list of tuples, where each tuple contains a column name, + an operator (e.g., '==', '!='), and a value. + use_cache (bool, optional): + Whether to use cached results for the query. Defaults to ``True``. + Setting this to ``False`` will force a re-execution of the query. + col_order (Iterable[str], optional): + The desired order of columns in the resulting DataFrame. This + parameter is deprecated and will be removed in a future version. + Use ``columns`` instead. + dry_run (bool, optional): + If ``True``, the function will not actually execute the query but + will instead return statistics about the table. Defaults to + ``False``. + Returns: - bigframes.pandas.DataFrame: - A DataFrame representing results of the query or table. + bigframes.pandas.DataFrame or pandas.Series: + A DataFrame representing the contents of the table. If + ``dry_run`` is ``True``, a ``pandas.Series`` containing table + statistics is returned. Raises: ValueError: diff --git a/third_party/bigframes_vendored/pandas/io/gbq.py b/third_party/bigframes_vendored/pandas/io/gbq.py index 3dae2b6bbe..c9b9ab9292 100644 --- a/third_party/bigframes_vendored/pandas/io/gbq.py +++ b/third_party/bigframes_vendored/pandas/io/gbq.py @@ -25,6 +25,7 @@ def read_gbq( filters: FiltersType = (), use_cache: Optional[bool] = None, col_order: Iterable[str] = (), + allow_large_results: bool = True, ): """Loads a DataFrame from BigQuery. @@ -156,6 +157,11 @@ def read_gbq( `configuration` to avoid conflicts. col_order (Iterable[str]): Alias for columns, retained for backwards compatibility. + allow_large_results (bool, optional): + Whether to allow large query results. If ``True``, the query + results can be larger than the maximum response size. This + option is only applicable when ``query_or_table`` is a query. 
+ Defaults to ``True``. Raises: bigframes.exceptions.DefaultIndexWarning: From 3b78da770d6b1bd5cbb6291bd8eb2bd6de8a6d5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 24 Jul 2025 18:25:18 +0000 Subject: [PATCH 02/21] add system test --- tests/system/small/test_session.py | 32 ++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py index 4bb1c6589a..fddb5b1d79 100644 --- a/tests/system/small/test_session.py +++ b/tests/system/small/test_session.py @@ -33,6 +33,7 @@ import pytest import bigframes +import bigframes.core.nodes as nodes import bigframes.dataframe import bigframes.dtypes import bigframes.ml.linear_model @@ -640,6 +641,37 @@ def test_read_gbq_with_configuration( assert df.shape == (9, 3) +def test_read_gbq_query_w_allow_large_results(session: bigframes.Session): + if not hasattr(session.bqclient, "default_job_creation_mode"): + pytest.skip("Jobless query only available on newer google-cloud-bigquery.") + + query = "SELECT 1" + + # Make sure we don't get a cached table. + configuration = {"query": {"useQueryCache": False}} + + # Very small results should wrap a local node. + df_false = session.read_gbq( + query, + configuration=configuration, + allow_large_results=False, + ) + assert df_false.shape == (1, 1) + roots_false = df_false._get_block().expr.node.roots + assert any(isinstance(node, nodes.ReadLocalNode) for node in roots_false) + assert not any(isinstance(node, nodes.ReadTableNode) for node in roots_false) + + # Large results allowed should wrap a table. + df_true = session.read_gbq( + query, + configuration=configuration, + allow_large_results=True, + ) + assert df_true.shape == (1, 1) + roots_true = df_true._get_block().expr.node.roots + assert any(isinstance(node, nodes.ReadTableNode) for node in roots_true) + + def test_read_gbq_with_custom_global_labels( session: bigframes.Session, scalars_table_id: str ): From 68f1a109cf68aea995eb870e3701fb3d7f5914bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 24 Jul 2025 18:30:31 +0000 Subject: [PATCH 03/21] add to pandas module --- bigframes/pandas/io/api.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bigframes/pandas/io/api.py b/bigframes/pandas/io/api.py index 65435bd902..2acb4ae961 100644 --- a/bigframes/pandas/io/api.py +++ b/bigframes/pandas/io/api.py @@ -184,6 +184,7 @@ def read_gbq( # type: ignore[overload-overlap] use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[False] = ..., + allow_large_results: bool = ..., ) -> bigframes.dataframe.DataFrame: ... @@ -200,6 +201,7 @@ def read_gbq( use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[True] = ..., + allow_large_results: bool = ..., ) -> pandas.Series: ... @@ -215,6 +217,7 @@ def read_gbq( use_cache: Optional[bool] = None, col_order: Iterable[str] = (), dry_run: bool = False, + allow_large_results: bool = True, ) -> bigframes.dataframe.DataFrame | pandas.Series: _set_default_session_location_if_possible(query_or_table) return global_session.with_default_session( @@ -228,6 +231,7 @@ def read_gbq( use_cache=use_cache, col_order=col_order, dry_run=dry_run, + allow_large_results=allow_large_results, ) @@ -391,6 +395,7 @@ def read_gbq_query( # type: ignore[overload-overlap] col_order: Iterable[str] = ..., filters: vendored_pandas_gbq.FiltersType = ..., dry_run: Literal[False] = ..., + allow_large_results: bool = ..., ) -> bigframes.dataframe.DataFrame: ... 
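
The patches up to this point expose the new ``allow_large_results`` keyword through both the ``Session`` methods and the ``bigframes.pandas`` module. A minimal sketch of the user-facing call, assuming the usual ``bpd`` alias and a placeholder table name:

    import bigframes.pandas as bpd

    # Small result sets can skip the destination table and come back via the
    # local (jobless) path, as exercised by the system test above.
    df_small = bpd.read_gbq_query("SELECT 1 AS x", allow_large_results=False)

    # With allow_large_results=True (the initial default), results are routed
    # through a destination table, so result sets larger than the maximum
    # response size keep working.
    df_large = bpd.read_gbq_query(
        "SELECT * FROM `my-project.my_dataset.wide_table`",  # placeholder table
        allow_large_results=True,
    )
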
@@ -407,6 +412,7 @@ def read_gbq_query( col_order: Iterable[str] = ..., filters: vendored_pandas_gbq.FiltersType = ..., dry_run: Literal[True] = ..., + allow_large_results: bool = ..., ) -> pandas.Series: ... @@ -422,6 +428,7 @@ def read_gbq_query( col_order: Iterable[str] = (), filters: vendored_pandas_gbq.FiltersType = (), dry_run: bool = False, + allow_large_results: bool = True, ) -> bigframes.dataframe.DataFrame | pandas.Series: _set_default_session_location_if_possible(query) return global_session.with_default_session( @@ -435,6 +442,7 @@ def read_gbq_query( col_order=col_order, filters=filters, dry_run=dry_run, + allow_large_results=allow_large_results, ) From 3b770ae7aca91ada39a540024b1fe8d58a4dda1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Wed, 20 Aug 2025 20:32:23 +0000 Subject: [PATCH 04/21] default to global option --- bigframes/pandas/io/api.py | 6 +++--- bigframes/session/__init__.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/bigframes/pandas/io/api.py b/bigframes/pandas/io/api.py index 2acb4ae961..7ef82a410b 100644 --- a/bigframes/pandas/io/api.py +++ b/bigframes/pandas/io/api.py @@ -395,7 +395,7 @@ def read_gbq_query( # type: ignore[overload-overlap] col_order: Iterable[str] = ..., filters: vendored_pandas_gbq.FiltersType = ..., dry_run: Literal[False] = ..., - allow_large_results: bool = ..., + allow_large_results: Optional[bool] = ..., ) -> bigframes.dataframe.DataFrame: ... @@ -412,7 +412,7 @@ def read_gbq_query( col_order: Iterable[str] = ..., filters: vendored_pandas_gbq.FiltersType = ..., dry_run: Literal[True] = ..., - allow_large_results: bool = ..., + allow_large_results: Optional[bool] = ..., ) -> pandas.Series: ... @@ -428,7 +428,7 @@ def read_gbq_query( col_order: Iterable[str] = (), filters: vendored_pandas_gbq.FiltersType = (), dry_run: bool = False, - allow_large_results: bool = True, + allow_large_results: Optional[bool] = None, ) -> bigframes.dataframe.DataFrame | pandas.Series: _set_default_session_location_if_possible(query) return global_session.with_default_session( diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index de97a81344..f82ca854bf 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -59,6 +59,7 @@ from bigframes import exceptions as bfe from bigframes import version +import bigframes._config import bigframes._config.bigquery_options as bigquery_options import bigframes.clients import bigframes.constants @@ -554,7 +555,7 @@ def read_gbq_query( # type: ignore[overload-overlap] col_order: Iterable[str] = ..., filters: third_party_pandas_gbq.FiltersType = ..., dry_run: Literal[False] = ..., - allow_large_results: bool = ..., + allow_large_results: Optional[bool] = ..., ) -> dataframe.DataFrame: ... @@ -571,7 +572,7 @@ def read_gbq_query( col_order: Iterable[str] = ..., filters: third_party_pandas_gbq.FiltersType = ..., dry_run: Literal[True] = ..., - allow_large_results: bool = ..., + allow_large_results: Optional[bool] = ..., ) -> pandas.Series: ... @@ -587,7 +588,7 @@ def read_gbq_query( col_order: Iterable[str] = (), filters: third_party_pandas_gbq.FiltersType = (), dry_run: bool = False, - allow_large_results: bool = True, + allow_large_results: Optional[bool] = None, ) -> dataframe.DataFrame | pandas.Series: """Turn a SQL query into a DataFrame. @@ -672,7 +673,7 @@ def read_gbq_query( allow_large_results (bool, optional): Whether to allow large query results. 
If ``True``, the query results can be larger than the maximum response size. - Defaults to ``True``. + Defaults to ``bpd.options.compute.allow_large_results``. Returns: bigframes.pandas.DataFrame or pandas.Series: @@ -693,6 +694,9 @@ def read_gbq_query( elif col_order: columns = col_order + if allow_large_results is None: + allow_large_results = bigframes._config.options._allow_large_results + return self._loader.read_gbq_query( # type: ignore # for dry_run overload query=query, index_col=index_col, From 10a8302f35a451bf650fefb564be843a4b8357af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Wed, 20 Aug 2025 21:04:12 +0000 Subject: [PATCH 05/21] fix unit test --- tests/unit/session/test_read_gbq_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/session/test_read_gbq_query.py b/tests/unit/session/test_read_gbq_query.py index afd9922426..1f9d2fb945 100644 --- a/tests/unit/session/test_read_gbq_query.py +++ b/tests/unit/session/test_read_gbq_query.py @@ -25,7 +25,7 @@ def test_read_gbq_query_sets_destination_table(): # Use partial ordering mode to skip column uniqueness checks. session = mocks.create_bigquery_session(ordering_mode="partial") - _ = session.read_gbq_query("SELECT 'my-test-query';") + _ = session.read_gbq_query("SELECT 'my-test-query';", allow_large_results=True) queries = session._queries # type: ignore configs = session._job_configs # type: ignore From 938fb89b40aaf2a77250651a77a9a5755871dcbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 14:58:25 +0000 Subject: [PATCH 06/21] tweak imports so I can manually run doctest pytest --doctest-modules bigframes/session/__init__.py::bigframes.session.Session.read_gbq_query --- bigframes/pandas/io/api.py | 20 ++++++++++++++++---- bigframes/session/__init__.py | 22 ++++++++++++---------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/bigframes/pandas/io/api.py b/bigframes/pandas/io/api.py index d9aad0d9e2..49c3e97540 100644 --- a/bigframes/pandas/io/api.py +++ b/bigframes/pandas/io/api.py @@ -625,7 +625,11 @@ def from_glob_path( def _get_bqclient() -> bigquery.Client: - clients_provider = bigframes.session.clients.ClientsProvider( + # Address circular imports in doctest due to bigframes/session/__init__.py + # containing a lot of logic and samples. + from bigframes.session import clients + + clients_provider = clients.ClientsProvider( project=config.options.bigquery.project, location=config.options.bigquery.location, use_regional_endpoints=config.options.bigquery.use_regional_endpoints, @@ -639,11 +643,15 @@ def _get_bqclient() -> bigquery.Client: def _dry_run(query, bqclient) -> bigquery.QueryJob: + # Address circular imports in doctest due to bigframes/session/__init__.py + # containing a lot of logic and samples. + from bigframes.session import metrics as bf_metrics + job = bqclient.query(query, bigquery.QueryJobConfig(dry_run=True)) # Fix for b/435183833. Log metrics even if a Session isn't available. - if bigframes.session.metrics.LOGGING_NAME_ENV_VAR in os.environ: - metrics = bigframes.session.metrics.ExecutionMetrics() + if bf_metrics.LOGGING_NAME_ENV_VAR in os.environ: + metrics = bf_metrics.ExecutionMetrics() metrics.count_job_stats(job) return job @@ -653,6 +661,10 @@ def _set_default_session_location_if_possible(query): def _set_default_session_location_if_possible_deferred_query(create_query): + # Address circular imports in doctest due to bigframes/session/__init__.py + # containing a lot of logic and samples. 
+ from bigframes.session._io import bigquery + # Set the location as per the query if this is the first query the user is # running and: # (1) Default session has not started yet, and @@ -674,7 +686,7 @@ def _set_default_session_location_if_possible_deferred_query(create_query): query = create_query() bqclient = _get_bqclient() - if bigframes.session._io.bigquery.is_query(query): + if bigquery.is_query(query): # Intentionally run outside of the session so that we can detect the # location before creating the session. Since it's a dry_run, labels # aren't necessary. diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 446bf56d67..77d37924e1 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -133,6 +133,10 @@ def __init__( context: Optional[bigquery_options.BigQueryOptions] = None, clients_provider: Optional[bigframes.session.clients.ClientsProvider] = None, ): + # Address circular imports in doctest due to bigframes/session/__init__.py + # containing a lot of logic and samples. + from bigframes.session import anonymous_dataset, clients, loader, metrics + _warn_if_bf_version_is_obsolete() if context is None: @@ -168,7 +172,7 @@ def __init__( if clients_provider: self._clients_provider = clients_provider else: - self._clients_provider = bigframes.session.clients.ClientsProvider( + self._clients_provider = clients.ClientsProvider( project=context.project, location=self._location, use_regional_endpoints=context.use_regional_endpoints, @@ -220,15 +224,13 @@ def __init__( else bigframes.enums.DefaultIndexKind.NULL ) - self._metrics = bigframes.session.metrics.ExecutionMetrics() + self._metrics = metrics.ExecutionMetrics() self._function_session = bff_session.FunctionSession() - self._anon_dataset_manager = ( - bigframes.session.anonymous_dataset.AnonymousDatasetManager( - self._clients_provider.bqclient, - location=self._location, - session_id=self._session_id, - kms_key=self._bq_kms_key_name, - ) + self._anon_dataset_manager = anonymous_dataset.AnonymousDatasetManager( + self._clients_provider.bqclient, + location=self._location, + session_id=self._session_id, + kms_key=self._bq_kms_key_name, ) # Session temp tables don't support specifying kms key, so use anon dataset if kms key specified self._session_resource_manager = ( @@ -242,7 +244,7 @@ def __init__( self._temp_storage_manager = ( self._session_resource_manager or self._anon_dataset_manager ) - self._loader = bigframes.session.loader.GbqDataLoader( + self._loader = loader.GbqDataLoader( session=self, bqclient=self._clients_provider.bqclient, storage_manager=self._temp_storage_manager, From 78d29f0730dde5ea167c53a60363e64ad04c8cab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 15:22:50 +0000 Subject: [PATCH 07/21] support index_col and columns --- .../session/_io/bigquery/read_gbq_query.py | 50 ++++++++++++++----- bigframes/session/loader.py | 2 + 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/bigframes/session/_io/bigquery/read_gbq_query.py b/bigframes/session/_io/bigquery/read_gbq_query.py index 70c83d7875..0cd4177255 100644 --- a/bigframes/session/_io/bigquery/read_gbq_query.py +++ b/bigframes/session/_io/bigquery/read_gbq_query.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Optional +from typing import cast, Iterable, Optional, Tuple from google.cloud import bigquery import google.cloud.bigquery.table @@ -28,6 +28,7 @@ import bigframes.core.blocks as blocks import bigframes.core.guid import 
bigframes.core.schema as schemata +import bigframes.enums import bigframes.session @@ -53,7 +54,11 @@ def create_dataframe_from_query_job_stats( def create_dataframe_from_row_iterator( - rows: google.cloud.bigquery.table.RowIterator, *, session: bigframes.session.Session + rows: google.cloud.bigquery.table.RowIterator, + *, + session: bigframes.session.Session, + index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind, + columns: Iterable[str], ) -> dataframe.DataFrame: """Convert a RowIterator into a DataFrame wrapping a LocalNode. @@ -61,11 +66,23 @@ def create_dataframe_from_row_iterator( 'jobless' case where there's no destination table. """ pa_table = rows.to_arrow() + bq_schema = list(rows.schema) - # TODO(tswast): Use array_value.promote_offsets() instead once that node is - # supported by the local engine. - offsets_col = bigframes.core.guid.generate_guid() - pa_table = pyarrow_utils.append_offsets(pa_table, offsets_col=offsets_col) + if not index_col or isinstance(index_col, bigframes.enums.DefaultIndexKind): + # We get a sequential index for free, so use that if no index is specified. + # TODO(tswast): Use array_value.promote_offsets() instead once that node is + # supported by the local engine. + offsets_col = bigframes.core.guid.generate_guid() + pa_table = pyarrow_utils.append_offsets(pa_table, offsets_col=offsets_col) + bq_schema += [bigquery.SchemaField(offsets_col, "INTEGER")] + index_columns: Tuple[str, ...] = (offsets_col,) + index_labels: Tuple[Optional[str], ...] = (None,) + elif isinstance(index_col, str): + index_columns = (index_col,) + index_labels = (index_col,) + else: + index_columns = tuple(index_col) + index_labels = cast(Tuple[Optional[str], ...], tuple(index_col)) # We use the ManagedArrowTable constructor directly, because the # results of to_arrow() should be the source of truth with regards @@ -74,17 +91,24 @@ def create_dataframe_from_row_iterator( # like the output of the BQ Storage Read API. 
mat = local_data.ManagedArrowTable( pa_table, - schemata.ArraySchema.from_bq_schema( - list(rows.schema) + [bigquery.SchemaField(offsets_col, "INTEGER")] - ), + schemata.ArraySchema.from_bq_schema(bq_schema), ) mat.validate() + column_labels = [ + field.name for field in rows.schema if field.name not in index_columns + ] + array_value = core.ArrayValue.from_managed(mat, session) block = blocks.Block( array_value, - (offsets_col,), - [field.name for field in rows.schema], - (None,), + index_columns=index_columns, + column_labels=column_labels, + index_labels=index_labels, ) - return dataframe.DataFrame(block) + df = dataframe.DataFrame(block) + + if columns: + df = df[list(columns)] + + return df diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index 6500701324..82ab56a126 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -1044,6 +1044,8 @@ def read_gbq_query( return bf_read_gbq_query.create_dataframe_from_row_iterator( rows, session=self._session, + index_col=index_col, + columns=columns, ) # If there was no destination table and we've made it this far, that From 435c602fcd33c3bed6388c1abc7bc9ae486dd1c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 16:02:17 +0000 Subject: [PATCH 08/21] add system tests and fix pandas warning --- bigframes/dataframe.py | 2 +- .../small/session/test_read_gbq_query.py | 113 ++++++++++++++++++ tests/system/small/test_session.py | 32 ----- 3 files changed, 114 insertions(+), 33 deletions(-) create mode 100644 tests/system/small/session/test_read_gbq_query.py diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index c58cbaba6a..b4a0d0f5c4 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -4419,7 +4419,7 @@ def to_dict( allow_large_results: Optional[bool] = None, **kwargs, ) -> dict | list[dict]: - return self.to_pandas(allow_large_results=allow_large_results).to_dict(orient, into, **kwargs) # type: ignore + return self.to_pandas(allow_large_results=allow_large_results).to_dict(orient=orient, into=into, **kwargs) # type: ignore def to_excel( self, diff --git a/tests/system/small/session/test_read_gbq_query.py b/tests/system/small/session/test_read_gbq_query.py new file mode 100644 index 0000000000..c1408febca --- /dev/null +++ b/tests/system/small/session/test_read_gbq_query.py @@ -0,0 +1,113 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import pytest + +import bigframes +import bigframes.core.nodes as nodes + + +def test_read_gbq_query_w_allow_large_results(session: bigframes.Session): + if not hasattr(session.bqclient, "default_job_creation_mode"): + pytest.skip("Jobless query only available on newer google-cloud-bigquery.") + + query = "SELECT 1" + + # Make sure we don't get a cached table. + configuration = {"query": {"useQueryCache": False}} + + # Very small results should wrap a local node. 
+ df_false = session.read_gbq( + query, + configuration=configuration, + allow_large_results=False, + ) + assert df_false.shape == (1, 1) + roots_false = df_false._get_block().expr.node.roots + assert any(isinstance(node, nodes.ReadLocalNode) for node in roots_false) + assert not any(isinstance(node, nodes.ReadTableNode) for node in roots_false) + + # Large results allowed should wrap a table. + df_true = session.read_gbq( + query, + configuration=configuration, + allow_large_results=True, + ) + assert df_true.shape == (1, 1) + roots_true = df_true._get_block().expr.node.roots + assert any(isinstance(node, nodes.ReadTableNode) for node in roots_true) + + +def test_read_gbq_query_w_columns(session: bigframes.Session): + query = """ + SELECT 1 as int_col, + 'a' as str_col, + TIMESTAMP('2025-08-21 10:41:32.123456') as timestamp_col + """ + + result = session.read_gbq( + query, + columns=["timestamp_col", "int_col"], + ) + assert list(result.columns) == ["timestamp_col", "int_col"] + assert result.to_dict(orient="records") == [ + { + "timestamp_col": datetime.datetime( + 2025, 8, 21, 10, 41, 32, 123456, tzinfo=datetime.timezone.utc + ), + "int_col": 1, + } + ] + + +@pytest.mark.parametrize( + ("index_col", "expected_index_names"), + ( + pytest.param( + "my_custom_index", + ("my_custom_index",), + id="string", + ), + pytest.param( + ("my_custom_index",), + ("my_custom_index",), + id="iterable", + ), + pytest.param( + ("my_custom_index", "int_col"), + ("my_custom_index", "int_col"), + id="multiindex", + ), + ), +) +def test_read_gbq_query_w_index_col( + session: bigframes.Session, index_col, expected_index_names +): + query = """ + SELECT 1 as int_col, + 'a' as str_col, + 0 as my_custom_index, + TIMESTAMP('2025-08-21 10:41:32.123456') as timestamp_col + """ + + result = session.read_gbq( + query, + index_col=index_col, + ) + assert tuple(result.index.names) == expected_index_names + assert frozenset(result.columns) == frozenset( + {"int_col", "str_col", "my_custom_index", "timestamp_col"} + ) - frozenset(expected_index_names) diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py index f480984c5b..a04da64af0 100644 --- a/tests/system/small/test_session.py +++ b/tests/system/small/test_session.py @@ -33,7 +33,6 @@ import pytest import bigframes -import bigframes.core.nodes as nodes import bigframes.dataframe import bigframes.dtypes import bigframes.ml.linear_model @@ -641,37 +640,6 @@ def test_read_gbq_with_configuration( assert df.shape == (9, 3) -def test_read_gbq_query_w_allow_large_results(session: bigframes.Session): - if not hasattr(session.bqclient, "default_job_creation_mode"): - pytest.skip("Jobless query only available on newer google-cloud-bigquery.") - - query = "SELECT 1" - - # Make sure we don't get a cached table. - configuration = {"query": {"useQueryCache": False}} - - # Very small results should wrap a local node. - df_false = session.read_gbq( - query, - configuration=configuration, - allow_large_results=False, - ) - assert df_false.shape == (1, 1) - roots_false = df_false._get_block().expr.node.roots - assert any(isinstance(node, nodes.ReadLocalNode) for node in roots_false) - assert not any(isinstance(node, nodes.ReadTableNode) for node in roots_false) - - # Large results allowed should wrap a table. 
- df_true = session.read_gbq( - query, - configuration=configuration, - allow_large_results=True, - ) - assert df_true.shape == (1, 1) - roots_true = df_true._get_block().expr.node.roots - assert any(isinstance(node, nodes.ReadTableNode) for node in roots_true) - - def test_read_gbq_with_custom_global_labels( session: bigframes.Session, scalars_table_id: str ): From 778746ffdf4458039c15aaf890c151f7b1b86a43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 16:17:14 +0000 Subject: [PATCH 09/21] use global option in remaining functions --- bigframes/pandas/io/api.py | 6 +++--- bigframes/session/__init__.py | 16 +++++++++------- bigframes/session/loader.py | 6 +++--- third_party/bigframes_vendored/pandas/io/gbq.py | 4 ++-- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/bigframes/pandas/io/api.py b/bigframes/pandas/io/api.py index 49c3e97540..483bc5e530 100644 --- a/bigframes/pandas/io/api.py +++ b/bigframes/pandas/io/api.py @@ -187,7 +187,7 @@ def read_gbq( # type: ignore[overload-overlap] use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[False] = ..., - allow_large_results: bool = ..., + allow_large_results: Optional[bool] = ..., ) -> bigframes.dataframe.DataFrame: ... @@ -204,7 +204,7 @@ def read_gbq( use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[True] = ..., - allow_large_results: bool = ..., + allow_large_results: Optional[bool] = ..., ) -> pandas.Series: ... @@ -220,7 +220,7 @@ def read_gbq( use_cache: Optional[bool] = None, col_order: Iterable[str] = (), dry_run: bool = False, - allow_large_results: bool = True, + allow_large_results: Optional[bool] = None, ) -> bigframes.dataframe.DataFrame | pandas.Series: _set_default_session_location_if_possible(query_or_table) return global_session.with_default_session( diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 77d37924e1..e1307dc9fa 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -398,7 +398,7 @@ def read_gbq( # type: ignore[overload-overlap] use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[False] = ..., - allow_large_results: bool = ..., + allow_large_results: Optional[bool] = ..., ) -> dataframe.DataFrame: ... @@ -415,7 +415,7 @@ def read_gbq( use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[True] = ..., - allow_large_results: bool = ..., + allow_large_results: Optional[bool] = ..., ) -> pandas.Series: ... @@ -431,7 +431,7 @@ def read_gbq( use_cache: Optional[bool] = None, col_order: Iterable[str] = (), dry_run: bool = False, - allow_large_results: bool = True, + allow_large_results: Optional[bool] = None, ) -> dataframe.DataFrame | pandas.Series: # TODO(b/281571214): Generate prompt to show the progress of read_gbq. 
if columns and col_order: @@ -441,6 +441,9 @@ def read_gbq( elif col_order: columns = col_order + if allow_large_results is None: + allow_large_results = bigframes._config.options._allow_large_results + if bf_io_bigquery.is_query(query_or_table): return self._loader.read_gbq_query( # type: ignore # for dry_run overload query_or_table, @@ -527,6 +530,8 @@ def _read_gbq_colab( if pyformat_args is None: pyformat_args = {} + allow_large_results = bigframes._config.options._allow_large_results + query = bigframes.core.pyformat.pyformat( query, pyformat_args=pyformat_args, @@ -539,10 +544,7 @@ def _read_gbq_colab( index_col=bigframes.enums.DefaultIndexKind.NULL, force_total_order=False, dry_run=typing.cast(Union[Literal[False], Literal[True]], dry_run), - # TODO(tswast): we may need to allow allow_large_results to be overwritten - # or possibly a general configuration object for an explicit - # destination table and write disposition. - allow_large_results=False, + allow_large_results=allow_large_results, ) @overload diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index 82ab56a126..131965eeb9 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -895,7 +895,7 @@ def read_gbq_query( # type: ignore[overload-overlap] filters: third_party_pandas_gbq.FiltersType = ..., dry_run: Literal[False] = ..., force_total_order: Optional[bool] = ..., - allow_large_results: bool = ..., + allow_large_results: bool, ) -> dataframe.DataFrame: ... @@ -912,7 +912,7 @@ def read_gbq_query( filters: third_party_pandas_gbq.FiltersType = ..., dry_run: Literal[True] = ..., force_total_order: Optional[bool] = ..., - allow_large_results: bool = ..., + allow_large_results: bool, ) -> pandas.Series: ... @@ -928,7 +928,7 @@ def read_gbq_query( filters: third_party_pandas_gbq.FiltersType = (), dry_run: bool = False, force_total_order: Optional[bool] = None, - allow_large_results: bool = True, + allow_large_results: bool, ) -> dataframe.DataFrame | pandas.Series: configuration = _transform_read_gbq_configuration(configuration) diff --git a/third_party/bigframes_vendored/pandas/io/gbq.py b/third_party/bigframes_vendored/pandas/io/gbq.py index c9b9ab9292..0fdca4dde1 100644 --- a/third_party/bigframes_vendored/pandas/io/gbq.py +++ b/third_party/bigframes_vendored/pandas/io/gbq.py @@ -25,7 +25,7 @@ def read_gbq( filters: FiltersType = (), use_cache: Optional[bool] = None, col_order: Iterable[str] = (), - allow_large_results: bool = True, + allow_large_results: Optional[bool] = None, ): """Loads a DataFrame from BigQuery. @@ -161,7 +161,7 @@ def read_gbq( Whether to allow large query results. If ``True``, the query results can be larger than the maximum response size. This option is only applicable when ``query_or_table`` is a query. - Defaults to ``True``. + Defaults to ``bpd.options.compute.allow_large_results``. 
Raises: bigframes.exceptions.DefaultIndexWarning: From 95b2fdf94d0d8c2c76a499888a639c4544dc3242 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 17:24:24 +0000 Subject: [PATCH 10/21] supply allow_large_results=False when max_results is set --- bigframes/session/loader.py | 12 +++++++++++- tests/system/small/test_session.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index 131965eeb9..0d6d4aa448 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -721,6 +721,9 @@ def read_gbq_table( columns=columns, use_cache=use_cache, dry_run=dry_run, + # If max_results has been set, we almost certainly have < 10 GB + # of results. + allow_large_results=False, ) return df @@ -1040,7 +1043,14 @@ def read_gbq_query( # local node. Likely there are a wide range of sizes in which it # makes sense to download the results beyond the first page, even if # there is a job and destination table available. - if rows is not None and destination is None: + if ( + rows is not None + and destination is None + and ( + query_job_for_metrics is None + or query_job_for_metrics.statement_type == "SELECT" + ) + ): return bf_read_gbq_query.create_dataframe_from_row_iterator( rows, session=self._session, diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py index a04da64af0..40fcb150f6 100644 --- a/tests/system/small/test_session.py +++ b/tests/system/small/test_session.py @@ -619,7 +619,7 @@ def test_read_gbq_wildcard( pytest.param( {"query": {"useQueryCache": False, "maximumBytesBilled": "100"}}, marks=pytest.mark.xfail( - raises=google.api_core.exceptions.InternalServerError, + raises=google.api_core.exceptions.BadRequest, reason="Expected failure when the query exceeds the maximum bytes billed limit.", ), ), From 05c145c9bdf65381b07fea045980159c3eaf524e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 18:08:56 +0000 Subject: [PATCH 11/21] fix last? failing test --- tests/system/small/test_pandas_options.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/system/small/test_pandas_options.py b/tests/system/small/test_pandas_options.py index 1d360e0d4f..7a750ddfd3 100644 --- a/tests/system/small/test_pandas_options.py +++ b/tests/system/small/test_pandas_options.py @@ -280,6 +280,17 @@ def test_credentials_need_reauthentication( session = bpd.get_global_session() assert session.bqclient._http.credentials.valid + # We look at the thread-local session because of the + # reset_default_session_and_location fixture and that this test mutates + # state that might otherwise be used by tests running in parallel. + current_session = ( + bigframes.core.global_session._global_session_state.thread_local_session + ) + assert current_session is not None + + # Force a temp table to be created, so there is something to cleanup. + current_session._anon_dataset_manager.create_temp_table(schema=()) + with monkeypatch.context() as m: # Simulate expired credentials to trigger the credential refresh flow m.setattr( @@ -303,15 +314,6 @@ def test_credentials_need_reauthentication( with pytest.raises(google.auth.exceptions.RefreshError): bpd.read_gbq(test_query) - # Now verify that closing the session works We look at the - # thread-local session because of the - # reset_default_session_and_location fixture and that this test mutates - # state that might otherwise be used by tests running in parallel. 
- assert ( - bigframes.core.global_session._global_session_state.thread_local_session - is not None - ) - with warnings.catch_warnings(record=True) as warned: bpd.close_session() # CleanupFailedWarning: can't clean up From 58f45f835990bab7e89690f62b36eab67bde07a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 22:02:50 +0000 Subject: [PATCH 12/21] fix vector search tests --- .../session/_io/bigquery/read_gbq_query.py | 9 +- bigframes/session/loader.py | 1 + .../small/bigquery/test_vector_search.py | 107 ++++++------------ tests/system/small/test_unordered.py | 2 +- 4 files changed, 45 insertions(+), 74 deletions(-) diff --git a/bigframes/session/_io/bigquery/read_gbq_query.py b/bigframes/session/_io/bigquery/read_gbq_query.py index 0cd4177255..aed77615ce 100644 --- a/bigframes/session/_io/bigquery/read_gbq_query.py +++ b/bigframes/session/_io/bigquery/read_gbq_query.py @@ -67,8 +67,11 @@ def create_dataframe_from_row_iterator( """ pa_table = rows.to_arrow() bq_schema = list(rows.schema) + is_default_index = not index_col or isinstance( + index_col, bigframes.enums.DefaultIndexKind + ) - if not index_col or isinstance(index_col, bigframes.enums.DefaultIndexKind): + if is_default_index: # We get a sequential index for free, so use that if no index is specified. # TODO(tswast): Use array_value.promote_offsets() instead once that node is # supported by the local engine. @@ -81,6 +84,7 @@ def create_dataframe_from_row_iterator( index_columns = (index_col,) index_labels = (index_col,) else: + index_col = cast(Iterable[str], index_col) index_columns = tuple(index_col) index_labels = cast(Tuple[Optional[str], ...], tuple(index_col)) @@ -111,4 +115,7 @@ def create_dataframe_from_row_iterator( if columns: df = df[list(columns)] + if not is_default_index: + df = df.sort_index() + return df diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index 0d6d4aa448..49b1195235 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -956,6 +956,7 @@ def read_gbq_query( True if use_cache is None else use_cache ) + _check_duplicates("columns", columns) index_cols = _to_index_cols(index_col) _check_index_col_param(index_cols, columns) diff --git a/tests/system/small/bigquery/test_vector_search.py b/tests/system/small/bigquery/test_vector_search.py index a282135fa6..608294ba46 100644 --- a/tests/system/small/bigquery/test_vector_search.py +++ b/tests/system/small/bigquery/test_vector_search.py @@ -157,80 +157,43 @@ def test_vector_search_basic_params_with_df(): ) -def test_vector_search_different_params_with_query(): - search_query = bpd.Series([[1.0, 2.0], [3.0, 5.2]]) - vector_search_result = bbq.vector_search( - base_table="bigframes-dev.bigframes_tests_sys.base_table", - column_to_search="my_embedding", - query=search_query, - distance_type="cosine", - top_k=2, - ).to_pandas() # type:ignore - expected = pd.DataFrame( +def test_vector_search_different_params_with_query(session): + base_df = bpd.DataFrame( { - "0": [ - np.array([1.0, 2.0]), - np.array([1.0, 2.0]), - np.array([3.0, 5.2]), - np.array([3.0, 5.2]), - ], - "id": [2, 1, 1, 2], + "id": [1, 2, 3, 4], "my_embedding": [ - np.array([2.0, 4.0]), - np.array([1.0, 2.0]), - np.array([1.0, 2.0]), - np.array([2.0, 4.0]), + np.array([0.0, 1.0]), + np.array([1.0, 0.0]), + np.array([0.0, -1.0]), + np.array([-1.0, 0.0]), ], - "distance": [0.0, 0.0, 0.001777, 0.001777], }, - index=pd.Index([0, 0, 1, 1], dtype="Int64"), - ) - pd.testing.assert_frame_equal( - vector_search_result, expected, 
check_dtype=False, rtol=0.1 - ) - - -def test_vector_search_df_with_query_column_to_search(): - search_query = bpd.DataFrame( - { - "query_id": ["dog", "cat"], - "embedding": [[1.0, 2.0], [3.0, 5.2]], - "another_embedding": [[1.0, 2.5], [3.3, 5.2]], - } - ) - vector_search_result = bbq.vector_search( - base_table="bigframes-dev.bigframes_tests_sys.base_table", - column_to_search="my_embedding", - query=search_query, - query_column_to_search="another_embedding", - top_k=2, - ).to_pandas() # type:ignore - expected = pd.DataFrame( - { - "query_id": ["dog", "dog", "cat", "cat"], - "embedding": [ - np.array([1.0, 2.0]), - np.array([1.0, 2.0]), - np.array([3.0, 5.2]), - np.array([3.0, 5.2]), - ], - "another_embedding": [ - np.array([1.0, 2.5]), - np.array([1.0, 2.5]), - np.array([3.3, 5.2]), - np.array([3.3, 5.2]), - ], - "id": [1, 4, 2, 5], - "my_embedding": [ - np.array([1.0, 2.0]), - np.array([1.0, 3.2]), - np.array([2.0, 4.0]), - np.array([5.0, 5.4]), - ], - "distance": [0.5, 0.7, 1.769181, 1.711724], - }, - index=pd.Index([0, 0, 1, 1], dtype="Int64"), - ) - pd.testing.assert_frame_equal( - vector_search_result, expected, check_dtype=False, rtol=0.1 + session=session, ) + base_table = base_df.to_gbq() + try: + search_query = bpd.Series([[0.75, 0.25], [-0.25, -0.75]], session=session) + vector_search_result = bbq.vector_search( + base_table=base_table, + column_to_search="my_embedding", + query=search_query, + distance_type="cosine", + top_k=2, + ).to_pandas() # type:ignore + expected = pd.DataFrame( + { + "0": {np.int64(0): [0.75, 0.25], np.int64(1): [-0.25, -0.75]}, + "id": {np.int64(0): 1, np.int64(1): 4}, + "my_embedding": {np.int64(0): [0.0, 1.0], np.int64(1): [-1.0, 0.0]}, + "distance": { + np.int64(0): 0.683772233983162, + np.int64(1): 0.683772233983162, + }, + }, + index=pd.Index([0, 0, 1, 1], dtype="Int64"), + ) + pd.testing.assert_frame_equal( + vector_search_result, expected, check_dtype=False, rtol=0.1 + ) + finally: + session.bqclient.delete_table(base_table, not_found_ok=True) diff --git a/tests/system/small/test_unordered.py b/tests/system/small/test_unordered.py index 0825b78037..c4f6521642 100644 --- a/tests/system/small/test_unordered.py +++ b/tests/system/small/test_unordered.py @@ -103,7 +103,7 @@ def test_unordered_mode_read_gbq(unordered_session): } ) # Don't need ignore_order as there is only 1 row - assert_pandas_df_equal(df.to_pandas(), expected) + assert_pandas_df_equal(df.to_pandas(), expected, check_index_type=False) @pytest.mark.parametrize( From 6d0fe48ac9f2e711f87edb10fca5dc0feff79fc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 22:37:49 +0000 Subject: [PATCH 13/21] fix vector search tests again --- .../small/bigquery/test_vector_search.py | 62 +++++++++++++------ 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/tests/system/small/bigquery/test_vector_search.py b/tests/system/small/bigquery/test_vector_search.py index 608294ba46..3107795730 100644 --- a/tests/system/small/bigquery/test_vector_search.py +++ b/tests/system/small/bigquery/test_vector_search.py @@ -123,12 +123,17 @@ def test_vector_search_basic_params_with_df(): "embedding": [[1.0, 2.0], [3.0, 5.2]], } ) - vector_search_result = bbq.vector_search( - base_table="bigframes-dev.bigframes_tests_sys.base_table", - column_to_search="my_embedding", - query=search_query, - top_k=2, - ).to_pandas() # type:ignore + vector_search_result = ( + bbq.vector_search( + base_table="bigframes-dev.bigframes_tests_sys.base_table", + column_to_search="my_embedding", + 
query=search_query, + top_k=2, + ) + .sort_values("distance") + .sort_index() + .to_pandas() + ) # type:ignore expected = pd.DataFrame( { "query_id": ["cat", "dog", "dog", "cat"], @@ -173,22 +178,39 @@ def test_vector_search_different_params_with_query(session): base_table = base_df.to_gbq() try: search_query = bpd.Series([[0.75, 0.25], [-0.25, -0.75]], session=session) - vector_search_result = bbq.vector_search( - base_table=base_table, - column_to_search="my_embedding", - query=search_query, - distance_type="cosine", - top_k=2, - ).to_pandas() # type:ignore + vector_search_result = ( + bbq.vector_search( + base_table=base_table, + column_to_search="my_embedding", + query=search_query, + distance_type="cosine", + top_k=2, + ) + .sort_values("distance") + .sort_index() + .to_pandas() + ) # type:ignore expected = pd.DataFrame( { - "0": {np.int64(0): [0.75, 0.25], np.int64(1): [-0.25, -0.75]}, - "id": {np.int64(0): 1, np.int64(1): 4}, - "my_embedding": {np.int64(0): [0.0, 1.0], np.int64(1): [-1.0, 0.0]}, - "distance": { - np.int64(0): 0.683772233983162, - np.int64(1): 0.683772233983162, - }, + "0": [ + [0.75, 0.25], + [0.75, 0.25], + [-0.25, -0.75], + [-0.25, -0.75], + ], + "id": [2, 1, 3, 4], + "my_embedding": [ + [1.0, 0.0], + [0.0, 1.0], + [0.0, -1.0], + [-1.0, 0.0], + ], + "distance": [ + 0.051317, + 0.683772, + 0.051317, + 0.683772, + ], }, index=pd.Index([0, 0, 1, 1], dtype="Int64"), ) From 8c27c2c2a3ed3cee64d6b862fb9ffedd23655e83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 23:08:41 +0000 Subject: [PATCH 14/21] fix arima tests --- tests/system/small/ml/test_forecasting.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index d1b6b18fbe..be8ccbf749 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -432,8 +432,10 @@ def test_arima_plus_detect_anomalies_params( }, ) pd.testing.assert_frame_equal( - anomalies[["is_anomaly", "lower_bound", "upper_bound", "anomaly_probability"]], - expected, + anomalies[["is_anomaly", "lower_bound", "upper_bound", "anomaly_probability"]] + .sort_values("anomaly_probability") + .reset_index(), + expected.sort_values("anomaly_probability").reset_index(), rtol=0.1, check_index_type=False, check_dtype=False, @@ -484,8 +486,8 @@ def test_arima_plus_score( dtype="Float64", ) pd.testing.assert_frame_equal( - result, - expected, + result.sort_values("id").reset_index(), + expected.sort_values("id").reset_index(), rtol=0.1, check_index_type=False, ) @@ -577,8 +579,8 @@ def test_arima_plus_score_series( dtype="Float64", ) pd.testing.assert_frame_equal( - result, - expected, + result.sort_values("id").reset_index(), + expected.sort_values("id").reset_index(), rtol=0.1, check_index_type=False, ) From 5ab21a9cc199149fff7a9169901143fa105f767e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 23:12:59 +0000 Subject: [PATCH 15/21] fix more tests --- tests/system/small/ml/test_forecasting.py | 2 ++ tests/system/small/ml/test_preprocessing.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index be8ccbf749..bf5684f5d9 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -490,6 +490,7 @@ def test_arima_plus_score( expected.sort_values("id").reset_index(), rtol=0.1, 
check_index_type=False, + check_dtype=False, ) @@ -583,6 +584,7 @@ def test_arima_plus_score_series( expected.sort_values("id").reset_index(), rtol=0.1, check_index_type=False, + check_dtype=False, ) diff --git a/tests/system/small/ml/test_preprocessing.py b/tests/system/small/ml/test_preprocessing.py index 34be48be1e..65a851efc3 100644 --- a/tests/system/small/ml/test_preprocessing.py +++ b/tests/system/small/ml/test_preprocessing.py @@ -245,7 +245,7 @@ def test_max_abs_scaler_save_load(new_penguins_df, dataset_id): index=pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64"), ) - pd.testing.assert_frame_equal(result, expected, rtol=0.1) + pd.testing.assert_frame_equal(result.sort_index(), expected.sort_index(), rtol=0.1) def test_min_max_scaler_normalized_fit_transform(new_penguins_df): From 2ca9e9e954ca2b8527017496ef184df8374206b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 23:17:16 +0000 Subject: [PATCH 16/21] try again --- tests/system/small/ml/test_forecasting.py | 87 ++++++++++++++--------- 1 file changed, 52 insertions(+), 35 deletions(-) diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index bf5684f5d9..03da743ba4 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -451,26 +451,35 @@ def test_arima_plus_score( id_col_name, ): if id_col_name: - result = time_series_arima_plus_model_w_id.score( - new_time_series_df_w_id[["parsed_date"]], - new_time_series_df_w_id[["total_visits"]], - new_time_series_df_w_id[["id"]], - ).to_pandas() + result = ( + time_series_arima_plus_model_w_id.score( + new_time_series_df_w_id[["parsed_date"]], + new_time_series_df_w_id[["total_visits"]], + new_time_series_df_w_id[["id"]], + ) + .to_pandas() + .sort_values("id") + .reset_index() + ) else: result = time_series_arima_plus_model.score( new_time_series_df[["parsed_date"]], new_time_series_df[["total_visits"]] ).to_pandas() if id_col_name: - expected = pd.DataFrame( - { - "id": ["2", "1"], - "mean_absolute_error": [120.011007, 120.011007], - "mean_squared_error": [14562.562359, 14562.562359], - "root_mean_squared_error": [120.675442, 120.675442], - "mean_absolute_percentage_error": [4.80044, 4.80044], - "symmetric_mean_absolute_percentage_error": [4.744332, 4.744332], - }, - dtype="Float64", + expected = ( + pd.DataFrame( + { + "id": ["2", "1"], + "mean_absolute_error": [120.011007, 120.011007], + "mean_squared_error": [14562.562359, 14562.562359], + "root_mean_squared_error": [120.675442, 120.675442], + "mean_absolute_percentage_error": [4.80044, 4.80044], + "symmetric_mean_absolute_percentage_error": [4.744332, 4.744332], + }, + dtype="Float64", + ) + .sort_values("id") + .reset_index() ) expected["id"] = expected["id"].astype(str).str.replace(r"\.0$", "", regex=True) expected["id"] = expected["id"].astype("string[pyarrow]") @@ -486,8 +495,8 @@ def test_arima_plus_score( dtype="Float64", ) pd.testing.assert_frame_equal( - result.sort_values("id").reset_index(), - expected.sort_values("id").reset_index(), + result, + expected, rtol=0.1, check_index_type=False, check_dtype=False, @@ -545,26 +554,35 @@ def test_arima_plus_score_series( id_col_name, ): if id_col_name: - result = time_series_arima_plus_model_w_id.score( - new_time_series_df_w_id["parsed_date"], - new_time_series_df_w_id["total_visits"], - new_time_series_df_w_id["id"], - ).to_pandas() + result = ( + time_series_arima_plus_model_w_id.score( + new_time_series_df_w_id["parsed_date"], + 
new_time_series_df_w_id["total_visits"], + new_time_series_df_w_id["id"], + ) + .to_pandas() + .sort_values("id") + .reset_index() + ) else: result = time_series_arima_plus_model.score( new_time_series_df["parsed_date"], new_time_series_df["total_visits"] ).to_pandas() if id_col_name: - expected = pd.DataFrame( - { - "id": ["2", "1"], - "mean_absolute_error": [120.011007, 120.011007], - "mean_squared_error": [14562.562359, 14562.562359], - "root_mean_squared_error": [120.675442, 120.675442], - "mean_absolute_percentage_error": [4.80044, 4.80044], - "symmetric_mean_absolute_percentage_error": [4.744332, 4.744332], - }, - dtype="Float64", + expected = ( + pd.DataFrame( + { + "id": ["2", "1"], + "mean_absolute_error": [120.011007, 120.011007], + "mean_squared_error": [14562.562359, 14562.562359], + "root_mean_squared_error": [120.675442, 120.675442], + "mean_absolute_percentage_error": [4.80044, 4.80044], + "symmetric_mean_absolute_percentage_error": [4.744332, 4.744332], + }, + dtype="Float64", + ) + .sort_values("id") + .reset_index() ) expected["id"] = expected["id"].astype(str).str.replace(r"\.0$", "", regex=True) expected["id"] = expected["id"].astype("string[pyarrow]") @@ -580,11 +598,10 @@ def test_arima_plus_score_series( dtype="Float64", ) pd.testing.assert_frame_equal( - result.sort_values("id").reset_index(), - expected.sort_values("id").reset_index(), + result, + expected, rtol=0.1, check_index_type=False, - check_dtype=False, ) From 0c8b88cdd3ee85eab0b59ae5cbd223d11713907d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a?= Date: Thu, 21 Aug 2025 23:23:21 +0000 Subject: [PATCH 17/21] and again --- tests/system/small/ml/test_forecasting.py | 61 +++++++++++------------ 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index 03da743ba4..134f82e96e 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -434,8 +434,8 @@ def test_arima_plus_detect_anomalies_params( pd.testing.assert_frame_equal( anomalies[["is_anomaly", "lower_bound", "upper_bound", "anomaly_probability"]] .sort_values("anomaly_probability") - .reset_index(), - expected.sort_values("anomaly_probability").reset_index(), + .reset_index(drop=True), + expected.sort_values("anomaly_probability").reset_index(drop=True), rtol=0.1, check_index_type=False, check_dtype=False, @@ -459,30 +459,28 @@ def test_arima_plus_score( ) .to_pandas() .sort_values("id") - .reset_index() + .reset_index(drop=True) ) else: result = time_series_arima_plus_model.score( new_time_series_df[["parsed_date"]], new_time_series_df[["total_visits"]] ).to_pandas() if id_col_name: - expected = ( - pd.DataFrame( - { - "id": ["2", "1"], - "mean_absolute_error": [120.011007, 120.011007], - "mean_squared_error": [14562.562359, 14562.562359], - "root_mean_squared_error": [120.675442, 120.675442], - "mean_absolute_percentage_error": [4.80044, 4.80044], - "symmetric_mean_absolute_percentage_error": [4.744332, 4.744332], - }, - dtype="Float64", - ) - .sort_values("id") - .reset_index() + expected = pd.DataFrame( + { + "id": ["2", "1"], + "mean_absolute_error": [120.011007, 120.011007], + "mean_squared_error": [14562.562359, 14562.562359], + "root_mean_squared_error": [120.675442, 120.675442], + "mean_absolute_percentage_error": [4.80044, 4.80044], + "symmetric_mean_absolute_percentage_error": [4.744332, 4.744332], + }, + dtype="Float64", ) expected["id"] = expected["id"].astype(str).str.replace(r"\.0$", "", 
         expected["id"] = expected["id"].astype("string[pyarrow]")
+        expected = expected.sort_values("id")
+        expected = expected.reset_index(drop=True)
     else:
         expected = pd.DataFrame(
             {
@@ -562,30 +560,28 @@ def test_arima_plus_score_series(
             )
             .to_pandas()
             .sort_values("id")
-            .reset_index()
+            .reset_index(drop=True)
         )
     else:
         result = time_series_arima_plus_model.score(
             new_time_series_df["parsed_date"], new_time_series_df["total_visits"]
         ).to_pandas()

     if id_col_name:
-        expected = (
-            pd.DataFrame(
-                {
-                    "id": ["2", "1"],
-                    "mean_absolute_error": [120.011007, 120.011007],
-                    "mean_squared_error": [14562.562359, 14562.562359],
-                    "root_mean_squared_error": [120.675442, 120.675442],
-                    "mean_absolute_percentage_error": [4.80044, 4.80044],
-                    "symmetric_mean_absolute_percentage_error": [4.744332, 4.744332],
-                },
-                dtype="Float64",
-            )
-            .sort_values("id")
-            .reset_index()
+        expected = pd.DataFrame(
+            {
+                "id": ["2", "1"],
+                "mean_absolute_error": [120.011007, 120.011007],
+                "mean_squared_error": [14562.562359, 14562.562359],
+                "root_mean_squared_error": [120.675442, 120.675442],
+                "mean_absolute_percentage_error": [4.80044, 4.80044],
+                "symmetric_mean_absolute_percentage_error": [4.744332, 4.744332],
+            },
+            dtype="Float64",
         )
         expected["id"] = expected["id"].astype(str).str.replace(r"\.0$", "", regex=True)
         expected["id"] = expected["id"].astype("string[pyarrow]")
+        expected = expected.sort_values("id")
+        expected = expected.reset_index(drop=True)
     else:
         expected = pd.DataFrame(
             {
@@ -602,6 +598,7 @@ def test_arima_plus_score_series(
         expected,
         rtol=0.1,
         check_index_type=False,
+        check_dtype=False,
     )

From 936283890523bf073c5fc03e72c60c29eebebfe9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tim=20Swe=C3=B1a?=
Date: Tue, 2 Sep 2025 16:18:59 +0000
Subject: [PATCH 18/21] exclude ML results from jobless query path

---
 bigframes/ml/core.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py
index 73b8ba8dbc..da8600c7c4 100644
--- a/bigframes/ml/core.py
+++ b/bigframes/ml/core.py
@@ -95,7 +95,17 @@ def _apply_ml_tvf(
         )

         result_sql = apply_sql_tvf(input_sql)
-        df = self._session.read_gbq(result_sql, index_col=index_col_ids)
+        df = self._session.read_gbq_query(
+            result_sql,
+            index_col=index_col_ids,
+            # Many ML methods use nested JSON, which isn't yet compatible with
+            # joining local results. Also, there is a chance that the results
+            # are greater than 10 GB.
+            # TODO(b/395912450): Once the limitations with local data are
+            # resolved, consider setting allow_large_results only when expected
+            # data size is large.
+            allow_large_results=True,
+        )
         if df._has_index:
             df.index.names = index_labels
         # Restore column labels

From d3997ba244cdb3f4d01f240624649bf08e40f1dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tim=20Swe=C3=B1a?=
Date: Tue, 2 Sep 2025 16:34:36 +0000
Subject: [PATCH 19/21] fix unit tests

---
 bigframes/ml/core.py             | 80 ++++++++++++++++++++++++++------
 tests/unit/ml/test_golden_sql.py | 31 ++++++++-----
 2 files changed, 84 insertions(+), 27 deletions(-)

diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py
index da8600c7c4..28f795a0b6 100644
--- a/bigframes/ml/core.py
+++ b/bigframes/ml/core.py
@@ -45,7 +45,11 @@ def ai_forecast(
         result_sql = self._sql_generator.ai_forecast(
             source_sql=input_data.sql, options=options
         )
-        return self._session.read_gbq(result_sql)
+
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
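+        # As with _apply_ml_tvf, the forecast results can be larger than the
+        # maximum response size, so allow_large_results stays True for now.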
+        return self._session.read_gbq_query(result_sql, allow_large_results=True)


 class BqmlModel(BaseBqml):
@@ -169,7 +173,10 @@ def explain_predict(
     def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_global_explain(struct_options=options)
         return (
-            self._session.read_gbq(sql)
+            # TODO(b/395912450): Once the limitations with local data are
+            # resolved, consider setting allow_large_results only when expected
+            # data size is large.
+            self._session.read_gbq_query(sql, allow_large_results=True)
             .sort_values(by="attribution", ascending=False)
             .set_index("feature")
         )
@@ -244,26 +251,49 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_forecast(struct_options=options)
         timestamp_col_name = "forecast_timestamp"
         index_cols = [timestamp_col_name]
-        first_col_name = self._session.read_gbq(sql).columns.values[0]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        first_col_name = self._session.read_gbq_query(
+            sql, allow_large_results=True
+        ).columns.values[0]
         if timestamp_col_name != first_col_name:
             index_cols.append(first_col_name)
-        return self._session.read_gbq(sql, index_col=index_cols).reset_index()
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=index_cols, allow_large_results=True
+        ).reset_index()

     def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_explain_forecast(struct_options=options)
         timestamp_col_name = "time_series_timestamp"
         index_cols = [timestamp_col_name]
-        first_col_name = self._session.read_gbq(sql).columns.values[0]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        first_col_name = self._session.read_gbq_query(
+            sql, allow_large_results=True
+        ).columns.values[0]
         if timestamp_col_name != first_col_name:
             index_cols.append(first_col_name)
-        return self._session.read_gbq(sql, index_col=index_cols).reset_index()
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=index_cols, allow_large_results=True
+        ).reset_index()

     def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
         sql = self._sql_generator.ml_evaluate(
             input_data.sql if (input_data is not None) else None
         )
-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def llm_evaluate(
         self,
@@ -272,25 +302,37 @@ def llm_evaluate(
         self,
     ):
         sql = self._sql_generator.ml_llm_evaluate(input_data.sql, task_type)
-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def arima_evaluate(self, show_all_candidate_models: bool = False):
         sql = self._sql_generator.ml_arima_evaluate(show_all_candidate_models)
-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def arima_coefficients(self) -> bpd.DataFrame:
         sql = self._sql_generator.ml_arima_coefficients()
-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def centroids(self) -> bpd.DataFrame:
         assert self._model.model_type == "KMEANS"
         sql = self._sql_generator.ml_centroids()
-        return self._session.read_gbq(
-            sql, index_col=["centroid_id", "feature"]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=["centroid_id", "feature"], allow_large_results=True
         ).reset_index()

     def principal_components(self) -> bpd.DataFrame:
@@ -298,8 +340,13 @@ def principal_components(self) -> bpd.DataFrame:
         sql = self._sql_generator.ml_principal_components()
-        return self._session.read_gbq(
-            sql, index_col=["principal_component_id", "feature"]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql,
+            index_col=["principal_component_id", "feature"],
+            allow_large_results=True,
         ).reset_index()

     def principal_component_info(self) -> bpd.DataFrame:
@@ -307,7 +354,10 @@ def principal_component_info(self) -> bpd.DataFrame:
         sql = self._sql_generator.ml_principal_component_info()
-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel:
         job_config = self._session._prepare_copy_job_config()
diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py
index 10fefcc457..7f6843aacf 100644
--- a/tests/unit/ml/test_golden_sql.py
+++ b/tests/unit/ml/test_golden_sql.py
@@ -143,9 +143,10 @@ def test_linear_regression_predict(mock_session, bqml_model, mock_X):
     model._bqml_model = bqml_model
     model.predict(mock_X)
-    mock_session.read_gbq.assert_called_once_with(
+    mock_session.read_gbq_query.assert_called_once_with(
         "SELECT * FROM ML.PREDICT(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql))",
         index_col=["index_column_id"],
+        allow_large_results=True,
     )
@@ -154,8 +155,9 @@ def test_linear_regression_score(mock_session, bqml_model, mock_X, mock_y):
     model._bqml_model = bqml_model
     model.score(mock_X, mock_y)
-    mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))"
+    mock_session.read_gbq_query.assert_called_once_with(
+        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))",
+        allow_large_results=True,
     )
@@ -167,7 +169,7 @@ def test_logistic_regression_default_fit(
     model.fit(mock_X, mock_y)
     mock_session._start_query_ml_ddl.assert_called_once_with(
-        "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
+        "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql",
     )
@@ -198,9 +200,10 @@ def test_logistic_regression_predict(mock_session, bqml_model, mock_X):
     model._bqml_model = bqml_model
     model.predict(mock_X)
-    mock_session.read_gbq.assert_called_once_with(
+    mock_session.read_gbq_query.assert_called_once_with(
         "SELECT * FROM ML.PREDICT(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql))",
         index_col=["index_column_id"],
+        allow_large_results=True,
     )
@@ -209,8 +212,9 @@ def test_logistic_regression_score(mock_session, bqml_model, mock_X, mock_y):
     model._bqml_model = bqml_model
     model.score(mock_X, mock_y)
-    mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))"
+    mock_session.read_gbq_query.assert_called_once_with(
+        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))",
+        allow_large_results=True,
     )
@@ -243,9 +247,10 @@ def test_decomposition_mf_predict(mock_session, bqml_model, mock_X):
     model._bqml_model = bqml_model
     model.predict(mock_X)
-    mock_session.read_gbq.assert_called_once_with(
+    mock_session.read_gbq_query.assert_called_once_with(
         "SELECT * FROM ML.RECOMMEND(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql))",
         index_col=["index_column_id"],
+        allow_large_results=True,
     )
@@ -260,8 +265,9 @@ def test_decomposition_mf_score(mock_session, bqml_model):
     )
     model._bqml_model = bqml_model
     model.score()
-    mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`)"
+    mock_session.read_gbq_query.assert_called_once_with(
+        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`)",
+        allow_large_results=True,
     )
@@ -276,6 +282,7 @@ def test_decomposition_mf_score_with_x(mock_session, bqml_model, mock_X):
     )
     model._bqml_model = bqml_model
     model.score(mock_X)
-    mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql_property))"
+    mock_session.read_gbq_query.assert_called_once_with(
+        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql_property))",
+        allow_large_results=True,
     )

From 7cb3d550910ff7b357548c1849e87b9e932800fd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tim=20Swe=C3=B1a?=
Date: Tue, 2 Sep 2025 17:40:01 +0000
Subject: [PATCH 20/21] add parameter for vector_search too

---
 bigframes/bigquery/_operations/search.py | 11 +++++++++--
 bigframes/operations/ai.py               |  4 ++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/bigframes/bigquery/_operations/search.py b/bigframes/bigquery/_operations/search.py
index 9a1e4b5ac9..72c855f517 100644
--- a/bigframes/bigquery/_operations/search.py
+++ b/bigframes/bigquery/_operations/search.py
@@ -99,6 +99,7 @@ def vector_search(
     distance_type: Optional[Literal["euclidean", "cosine", "dot_product"]] = None,
     fraction_lists_to_search: Optional[float] = None,
     use_brute_force: Optional[bool] = None,
+    allow_large_results: Optional[bool] = None,
 ) -> dataframe.DataFrame:
     """
     Conduct vector search which searches embeddings to find semantically similar entities.
@@ -199,6 +200,10 @@ def vector_search(
         use_brute_force (bool):
             Determines whether to use brute force search by skipping the vector index
            if one is available. Default to False.
+        allow_large_results (bool, optional):
+            Whether to allow large query results. If ``True``, the query
+            results can be larger than the maximum response size.
+            Defaults to ``bpd.options.compute.allow_large_results``.

     Returns:
         bigframes.dataframe.DataFrame: A DataFrame containing vector search result.
@@ -236,9 +241,11 @@ def vector_search(
         options=options,
     )
     if index_col_ids is not None:
-        df = query._session.read_gbq(sql, index_col=index_col_ids)
+        df = query._session.read_gbq_query(
+            sql, index_col=index_col_ids, allow_large_results=allow_large_results
+        )
         df.index.names = index_labels
     else:
-        df = query._session.read_gbq(sql)
+        df = query._session.read_gbq_query(sql, allow_large_results=allow_large_results)

     return df
diff --git a/bigframes/operations/ai.py b/bigframes/operations/ai.py
index 8c7628059a..ac294b0fbd 100644
--- a/bigframes/operations/ai.py
+++ b/bigframes/operations/ai.py
@@ -566,6 +566,10 @@ def search(
                 column_to_search=embedding_result_column,
                 query=query_df,
                 top_k=top_k,
+                # TODO(tswast): set allow_large_results based on Series size.
+                # If we expect small results, it could be faster to set
+                # allow_large_results to False.
+                allow_large_results=True,
             )
             .rename(columns={"content": search_column})
             .set_index("index")

From 0a70173152892b9cb46e6db9a6c91ac295ff0584 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tim=20Swe=C3=B1a?=
Date: Tue, 2 Sep 2025 17:46:13 +0000
Subject: [PATCH 21/21] fix doctest

---
 bigframes/bigquery/_operations/search.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/bigframes/bigquery/_operations/search.py b/bigframes/bigquery/_operations/search.py
index 72c855f517..5063fc9118 100644
--- a/bigframes/bigquery/_operations/search.py
+++ b/bigframes/bigquery/_operations/search.py
@@ -164,12 +164,12 @@ def vector_search(
         ...             query=search_query,
         ...             distance_type="cosine",
         ...             query_column_to_search="another_embedding",
-        ...             top_k=2)
+        ...             top_k=2).sort_values("id")
           query_id  embedding  another_embedding  id my_embedding  distance
-        1      cat   [3. 5.2]          [3.3 5.2]   2     [2. 4.]   0.005181
-        0      dog    [1. 2.]          [0.7 2.2]   4    [1. 3.2]   0.000013
         1      cat   [3. 5.2]          [3.3 5.2]   1     [1. 2.]   0.005181
+        1      cat   [3. 5.2]          [3.3 5.2]   2     [2. 4.]   0.005181
         0      dog    [1. 2.]          [0.7 2.2]   3    [1.5 7. ]  0.004697
+        0      dog    [1. 2.]          [0.7 2.2]   4    [1. 3.2]   0.000013

         [4 rows x 6 columns]
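
A minimal usage sketch of the new option from the pandas-style entry point
(illustrative only: the query and public table are examples, and this assumes
the module-level ``read_gbq_query`` wrapper forwards the new keyword to
``Session.read_gbq_query``):

    import bigframes.pandas as bpd

    # A small aggregate result: reading it directly from the query response
    # (allow_large_results=False) avoids the destination-table round trip.
    top_names = bpd.read_gbq_query(
        """
        SELECT name, SUM(number) AS total
        FROM `bigquery-public-data.usa_names.usa_1910_2013`
        GROUP BY name
        ORDER BY total DESC
        LIMIT 10
        """,
        allow_large_results=False,
    )

``bigframes.bigquery.vector_search`` accepts the same keyword and defaults to
``bpd.options.compute.allow_large_results``, while the ML wrappers in
``bigframes/ml/core.py`` keep ``allow_large_results=True`` because their
results may be larger than the maximum response size.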