diff --git a/bigframes/_config/auth.py b/bigframes/_config/auth.py new file mode 100644 index 0000000000..1574fc4883 --- /dev/null +++ b/bigframes/_config/auth.py @@ -0,0 +1,57 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import threading +from typing import Optional + +import google.auth.credentials +import google.auth.transport.requests +import pydata_google_auth + +_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] + +# Put the lock here rather than in BigQueryOptions so that BigQueryOptions +# remains deepcopy-able. +_AUTH_LOCK = threading.Lock() +_cached_credentials: Optional[google.auth.credentials.Credentials] = None +_cached_project_default: Optional[str] = None + + +def get_default_credentials_with_project() -> tuple[ + google.auth.credentials.Credentials, Optional[str] +]: + global _AUTH_LOCK, _cached_credentials, _cached_project_default + + with _AUTH_LOCK: + if _cached_credentials is not None: + return _cached_credentials, _cached_project_default + + _cached_credentials, _cached_project_default = pydata_google_auth.default( + scopes=_SCOPES, use_local_webserver=False + ) + + # Ensure an access token is available. + _cached_credentials.refresh(google.auth.transport.requests.Request()) + + return _cached_credentials, _cached_project_default + + +def reset_default_credentials_and_project(): + global _AUTH_LOCK, _cached_credentials, _cached_project_default + + with _AUTH_LOCK: + _cached_credentials = None + _cached_project_default = None diff --git a/bigframes/_config/bigquery_options.py b/bigframes/_config/bigquery_options.py index 648b69dea7..2456a88073 100644 --- a/bigframes/_config/bigquery_options.py +++ b/bigframes/_config/bigquery_options.py @@ -22,6 +22,7 @@ import google.auth.credentials import requests.adapters +import bigframes._config.auth import bigframes._importing import bigframes.enums import bigframes.exceptions as bfe @@ -37,6 +38,7 @@ def _get_validated_location(value: Optional[str]) -> Optional[str]: import bigframes._tools.strings + import bigframes.constants if value is None or value in bigframes.constants.ALL_BIGQUERY_LOCATIONS: return value @@ -141,20 +143,52 @@ def application_name(self, value: Optional[str]): ) self._application_name = value + def _try_set_default_credentials_and_project( + self, + ) -> tuple[google.auth.credentials.Credentials, Optional[str]]: + # Don't fetch credentials or project if credentials is already set. + # If it's set, we've already authenticated, so if the user wants to + # re-auth, they should explicitly reset the credentials. + if self._credentials is not None: + return self._credentials, self._project + + ( + credentials, + credentials_project, + ) = bigframes._config.auth.get_default_credentials_with_project() + self._credentials = credentials + + # Avoid overriding an explicitly set project with a default value. + if self._project is None: + self._project = credentials_project + + return credentials, self._project + @property - def credentials(self) -> Optional[google.auth.credentials.Credentials]: + def credentials(self) -> google.auth.credentials.Credentials: """The OAuth2 credentials to use for this client. + Set to None to force re-authentication. + Returns: None or google.auth.credentials.Credentials: google.auth.credentials.Credentials if exists; otherwise None. """ - return self._credentials + if self._credentials: + return self._credentials + + credentials, _ = self._try_set_default_credentials_and_project() + return credentials @credentials.setter def credentials(self, value: Optional[google.auth.credentials.Credentials]): if self._session_started and self._credentials is not value: raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="credentials")) + + if value is None: + # The user has _explicitly_ asked that we re-authenticate. + bigframes._config.auth.reset_default_credentials_and_project() + self._credentials = value @property @@ -183,7 +217,11 @@ def project(self) -> Optional[str]: None or str: Google Cloud project ID as a string; otherwise None. """ - return self._project + if self._project: + return self._project + + _, project = self._try_set_default_credentials_and_project() + return project @project.setter def project(self, value: Optional[str]): diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index df67e64e9e..080252d9eb 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -49,7 +49,6 @@ import bigframes_vendored.pandas.io.parsers.readers as third_party_pandas_readers import bigframes_vendored.pandas.io.pickle as third_party_pandas_pickle import google.cloud.bigquery as bigquery -import google.cloud.storage as storage # type: ignore import numpy as np import pandas from pandas._typing import ( @@ -1424,7 +1423,7 @@ def _check_file_size(self, filepath: str): if filepath.startswith("gs://"): # GCS file path bucket_name, blob_path = filepath.split("/", 3)[2:] - client = storage.Client() + client = self._clients_provider.storageclient bucket = client.bucket(bucket_name) list_blobs_params = inspect.signature(bucket.list_blobs).parameters diff --git a/bigframes/session/clients.py b/bigframes/session/clients.py index d680b94b8a..42bfab2682 100644 --- a/bigframes/session/clients.py +++ b/bigframes/session/clients.py @@ -29,9 +29,10 @@ import google.cloud.bigquery_storage_v1 import google.cloud.functions_v2 import google.cloud.resourcemanager_v3 -import pydata_google_auth +import google.cloud.storage # type: ignore import requests +import bigframes._config import bigframes.constants import bigframes.version @@ -39,7 +40,6 @@ _ENV_DEFAULT_PROJECT = "GOOGLE_CLOUD_PROJECT" _APPLICATION_NAME = f"bigframes/{bigframes.version.__version__} ibis/9.2.0" -_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] # BigQuery is a REST API, which requires the protocol as part of the URL. @@ -50,10 +50,6 @@ _BIGQUERYSTORAGE_REGIONAL_ENDPOINT = "bigquerystorage.{location}.rep.googleapis.com" -def _get_default_credentials_with_project(): - return pydata_google_auth.default(scopes=_SCOPES, use_local_webserver=False) - - def _get_application_names(): apps = [_APPLICATION_NAME] @@ -88,10 +84,8 @@ def __init__( ): credentials_project = None if credentials is None: - credentials, credentials_project = _get_default_credentials_with_project() - - # Ensure an access token is available. - credentials.refresh(google.auth.transport.requests.Request()) + credentials = bigframes._config.options.bigquery.credentials + credentials_project = bigframes._config.options.bigquery.project # Prefer the project in this order: # 1. Project explicitly specified by the user @@ -165,6 +159,9 @@ def __init__( google.cloud.resourcemanager_v3.ProjectsClient ] = None + self._storageclient_lock = threading.Lock() + self._storageclient: Optional[google.cloud.storage.Client] = None + def _create_bigquery_client(self): bq_options = None if "bqclient" in self._client_endpoints_override: @@ -347,3 +344,17 @@ def resourcemanagerclient(self): ) return self._resourcemanagerclient + + @property + def storageclient(self): + with self._storageclient_lock: + if not self._storageclient: + storage_info = google.api_core.client_info.ClientInfo( + user_agent=self._application_name + ) + self._storageclient = google.cloud.storage.Client( + client_info=storage_info, + credentials=self._credentials, + ) + + return self._storageclient diff --git a/tests/unit/pandas/io/test_api.py b/tests/unit/pandas/io/test_api.py index 1e69fa9df3..ba401d1ce6 100644 --- a/tests/unit/pandas/io/test_api.py +++ b/tests/unit/pandas/io/test_api.py @@ -14,11 +14,14 @@ from unittest import mock +import google.cloud.bigquery import pytest import bigframes.dataframe +import bigframes.pandas import bigframes.pandas.io.api as bf_io_api import bigframes.session +import bigframes.session.clients # _read_gbq_colab requires the polars engine. pytest.importorskip("polars") @@ -47,6 +50,49 @@ def test_read_gbq_colab_dry_run_doesnt_call_set_location( mock_set_location.assert_not_called() +@mock.patch("bigframes._config.auth.get_default_credentials_with_project") +@mock.patch("bigframes.core.global_session.with_default_session") +def test_read_gbq_colab_dry_run_doesnt_authenticate_multiple_times( + mock_with_default_session, mock_get_credentials, monkeypatch +): + """ + Ensure that we authenticate too often, which is an expensive operation, + performance-wise (2+ seconds). + """ + bigframes.pandas.close_session() + + mock_get_credentials.return_value = (mock.Mock(), "unit-test-project") + mock_create_bq_client = mock.Mock() + mock_bq_client = mock.create_autospec(google.cloud.bigquery.Client, instance=True) + mock_create_bq_client.return_value = mock_bq_client + mock_query_job = mock.create_autospec(google.cloud.bigquery.QueryJob, instance=True) + type(mock_query_job).schema = mock.PropertyMock(return_value=[]) + mock_query_job._properties = {} + mock_bq_client.query.return_value = mock_query_job + monkeypatch.setattr( + bigframes.session.clients.ClientsProvider, + "_create_bigquery_client", + mock_create_bq_client, + ) + mock_df = mock.create_autospec(bigframes.dataframe.DataFrame) + mock_with_default_session.return_value = mock_df + + query_or_table = "SELECT {param1} AS param1" + sample_pyformat_args = {"param1": "value1"} + bf_io_api._read_gbq_colab( + query_or_table, pyformat_args=sample_pyformat_args, dry_run=True + ) + + mock_with_default_session.assert_not_called() + mock_get_credentials.reset_mock() + + # Repeat the operation so that the credentials would have have been cached. + bf_io_api._read_gbq_colab( + query_or_table, pyformat_args=sample_pyformat_args, dry_run=True + ) + mock_get_credentials.assert_not_called() + + @mock.patch( "bigframes.pandas.io.api._set_default_session_location_if_possible_deferred_query" )