Skip to content

Commit 54393fb

Browse files
Merge remote-tracking branch 'github/main' into plot_kinds
2 parents 5b9bf8c + 7600001 commit 54393fb

39 files changed

+2002
-3626
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from bigframes.operations import ai_ops, output_schemas
2929

3030
PROMPT_TYPE = Union[
31+
str,
3132
series.Series,
3233
pd.Series,
3334
List[Union[str, series.Series, pd.Series]],
@@ -73,7 +74,7 @@ def generate(
7374
dtype: struct<is_herbivore: bool, number_of_legs: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
7475
7576
Args:
76-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
77+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
7778
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
7879
or pandas Series.
7980
connection_id (str, optional):
@@ -165,7 +166,7 @@ def generate_bool(
165166
Name: result, dtype: boolean
166167
167168
Args:
168-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
169+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
169170
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
170171
or pandas Series.
171172
connection_id (str, optional):
@@ -240,7 +241,7 @@ def generate_int(
240241
Name: result, dtype: Int64
241242
242243
Args:
243-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
244+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
244245
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
245246
or pandas Series.
246247
connection_id (str, optional):
@@ -315,7 +316,7 @@ def generate_double(
315316
Name: result, dtype: Float64
316317
317318
Args:
318-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
319+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
319320
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
320321
or pandas Series.
321322
connection_id (str, optional):
@@ -386,7 +387,7 @@ def if_(
386387
dtype: string
387388
388389
Args:
389-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
390+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
390391
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
391392
or pandas Series.
392393
connection_id (str, optional):
@@ -433,7 +434,7 @@ def classify(
433434
[2 rows x 2 columns]
434435
435436
Args:
436-
input (Series | List[str|Series] | Tuple[str|Series, ...]):
437+
input (str | Series | List[str|Series] | Tuple[str|Series, ...]):
437438
A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
438439
or pandas Series.
439440
categories (tuple[str, ...] | list[str]):
@@ -482,7 +483,7 @@ def score(
482483
dtype: Float64
483484
484485
Args:
485-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
486+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
486487
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
487488
or pandas Series.
488489
connection_id (str, optional):
@@ -514,9 +515,12 @@ def _separate_context_and_series(
514515
Input: ("str1", series1, "str2", "str3", series2)
515516
Output: ["str1", None, "str2", "str3", None], [series1, series2]
516517
"""
517-
if not isinstance(prompt, (list, tuple, series.Series)):
518+
if not isinstance(prompt, (str, list, tuple, series.Series)):
518519
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
519520

521+
if isinstance(prompt, str):
522+
return [None], [series.Series([prompt])]
523+
520524
if isinstance(prompt, series.Series):
521525
if prompt.dtype == dtypes.OBJ_REF_DTYPE:
522526
# Multi-model support

bigframes/blob/_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def _create_udf(self):
9999
project=None,
100100
timeout=None,
101101
query_with_job=True,
102+
publisher=self._session._publisher,
102103
)
103104

104105
return udf_name

bigframes/core/events.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
import datetime
19+
import threading
20+
from typing import Any, Callable, Optional, Set
21+
import uuid
22+
23+
import google.cloud.bigquery._job_helpers
24+
import google.cloud.bigquery.job.query
25+
import google.cloud.bigquery.table
26+
27+
import bigframes.session.executor
28+
29+
30+
class Subscriber:
    """Handle for one callback registered with a publisher.

    The handle is callable (it forwards to the wrapped callback), hashable
    (so a publisher can keep it in a set) and usable as a context manager
    that unsubscribes on exit.

    Args:
        callback: Callable invoked with each published event.
        publisher: Object whose ``unsubscribe(subscriber)`` method is called
            when this handle is closed.
    """

    def __init__(self, callback: Callable[[Event], None], *, publisher: Publisher):
        self._publisher = publisher
        self._callback = callback
        # Random per-instance id backing __hash__/__eq__, so two handles that
        # wrap the same callback are still distinct set members.
        self._subscriber_id = uuid.uuid4()

    def __call__(self, *args, **kwargs):
        """Forward the call to the wrapped callback and return its result."""
        return self._callback(*args, **kwargs)

    def __hash__(self) -> int:
        return hash(self._subscriber_id)

    def __eq__(self, value: object):
        if not isinstance(value, Subscriber):
            return NotImplemented
        return value._subscriber_id == self._subscriber_id

    def close(self):
        """Unsubscribe from the publisher and drop the held references.

        Safe to call more than once: calls after the first are no-ops. This
        matters because ``__exit__`` always calls ``close()``, even when the
        body of the ``with`` block already closed the subscriber explicitly.
        """
        if not hasattr(self, "_publisher"):
            # Already closed.
            return
        try:
            self._publisher.unsubscribe(self)
        finally:
            # Drop the references even if unsubscribe fails so the publisher
            # and callback are not kept alive by a dead handle.
            del self._publisher
            del self._callback

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Report an in-flight exception to the callback, then close.

        Returns ``None`` so the exception, if any, keeps propagating.
        """
        try:
            # Skip the notification if the subscriber was closed inside the
            # ``with`` body; there is no callback left to deliver it to.
            if exc_value is not None and hasattr(self, "_callback"):
                self(
                    UnknownErrorEvent(
                        exc_type=exc_type,
                        exc_value=exc_value,
                        traceback=traceback,
                    )
                )
        finally:
            # Always unsubscribe, even when the callback itself raises.
            self.close()
65+
66+
67+
class Publisher:
    """Fan-out of events to a set of registered subscribers.

    The subscriber set is guarded by a lock; callbacks themselves are invoked
    without the lock held.
    """

    def __init__(self):
        self._subscribers_lock = threading.Lock()
        self._subscribers: Set[Subscriber] = set()

    def subscribe(self, callback: Callable[[Event], None]) -> Subscriber:
        """Register ``callback`` and return the handle used to unsubscribe it.

        Args:
            callback: Callable invoked with each event passed to ``publish``.

        Returns:
            The ``Subscriber`` wrapping ``callback``.
        """
        # TODO(b/448176657): figure out how to handle subscribers/publishers in
        # a background thread. Maybe subscribers should be thread-local?
        subscriber = Subscriber(callback, publisher=self)
        with self._subscribers_lock:
            self._subscribers.add(subscriber)
        return subscriber

    def unsubscribe(self, subscriber: Subscriber):
        """Remove ``subscriber`` from the set of registered subscribers.

        Raises:
            KeyError: If ``subscriber`` is not currently registered.
        """
        with self._subscribers_lock:
            self._subscribers.remove(subscriber)

    def publish(self, event: Event):
        """Deliver ``event`` to every currently registered subscriber.

        Callbacks run outside the lock: ``threading.Lock`` is not reentrant,
        so a callback that subscribes or unsubscribes would otherwise
        deadlock, and it would also mutate the set while it is iterated.
        """
        with self._subscribers_lock:
            snapshot = tuple(self._subscribers)

        for subscriber in snapshot:
            # A callback earlier in this loop may have unsubscribed this
            # one; do not deliver to subscribers that are no longer
            # registered.
            with self._subscribers_lock:
                still_subscribed = subscriber in self._subscribers
            if still_subscribed:
                subscriber(event)
88+
89+
90+
class Event:
    """Common base class for the event types defined in this module."""
92+
93+
94+
@dataclasses.dataclass(frozen=True)
class SessionClosed(Event):
    """Immutable event carrying the id of a session.

    NOTE(review): presumably published when that session is closed; confirm
    against the code that publishes it.
    """

    # Identifier of the session this event refers to.
    session_id: str
97+
98+
99+
class ExecutionStarted(Event):
    """Marker event with no payload.

    NOTE(review): presumably published when an execution begins; confirm
    against the code that publishes it.
    """
101+
102+
103+
class ExecutionRunning(Event):
    """Payload-free base class for the ``BigQuery*Event`` types in this module."""
105+
106+
107+
@dataclasses.dataclass(frozen=True)
class ExecutionFinished(Event):
    """Immutable event that optionally carries an execution result.

    NOTE(review): presumably published once an execution completes; confirm
    against the code that publishes it.
    """

    # Result of the execution; defaults to None when none is attached.
    result: Optional[bigframes.session.executor.ExecuteResult] = None
110+
111+
112+
@dataclasses.dataclass(frozen=True)
class UnknownErrorEvent(Event):
    """Immutable event describing an exception.

    ``Subscriber.__exit__`` builds one of these from its three arguments when
    a ``with`` block exits with an exception, so the fields mirror the
    ``(exc_type, exc_value, traceback)`` triple passed to ``__exit__``.
    """

    # Exception class.
    exc_type: Any
    # Exception instance.
    exc_value: Any
    # Traceback object associated with the exception.
    traceback: Any
117+
118+
119+
@dataclasses.dataclass(frozen=True)
class BigQuerySentEvent(ExecutionRunning):
    """Query sent to BigQuery.

    Only ``query`` is required; every other field defaults to ``None``.
    """

    query: str
    billing_project: Optional[str] = None
    location: Optional[str] = None
    job_id: Optional[str] = None
    request_id: Optional[str] = None

    @classmethod
    def from_bqclient(cls, event: google.cloud.bigquery._job_helpers.QuerySentEvent):
        """Build an instance by copying the same-named attributes of ``event``."""
        copied_names = (
            "query",
            "billing_project",
            "location",
            "job_id",
            "request_id",
        )
        return cls(**{name: getattr(event, name) for name in copied_names})
138+
139+
140+
@dataclasses.dataclass(frozen=True)
class BigQueryRetryEvent(ExecutionRunning):
    """Query sent another time because the previous attempt failed.

    Only ``query`` is required; every other field defaults to ``None``.
    """

    query: str
    billing_project: Optional[str] = None
    location: Optional[str] = None
    job_id: Optional[str] = None
    request_id: Optional[str] = None

    @classmethod
    def from_bqclient(cls, event: google.cloud.bigquery._job_helpers.QueryRetryEvent):
        """Build an instance by copying the same-named attributes of ``event``."""
        copied_names = (
            "query",
            "billing_project",
            "location",
            "job_id",
            "request_id",
        )
        return cls(**{name: getattr(event, name) for name in copied_names})
159+
160+
161+
@dataclasses.dataclass(frozen=True)
class BigQueryReceivedEvent(ExecutionRunning):
    """Query received and acknowledged by the BigQuery API.

    Every field is optional and defaults to ``None``.
    """

    billing_project: Optional[str] = None
    location: Optional[str] = None
    job_id: Optional[str] = None
    statement_type: Optional[str] = None
    state: Optional[str] = None
    query_plan: Optional[list[google.cloud.bigquery.job.query.QueryPlanEntry]] = None
    created: Optional[datetime.datetime] = None
    started: Optional[datetime.datetime] = None
    ended: Optional[datetime.datetime] = None

    @classmethod
    def from_bqclient(
        cls, event: google.cloud.bigquery._job_helpers.QueryReceivedEvent
    ):
        """Build an instance by copying the same-named attributes of ``event``."""
        copied_names = (
            "billing_project",
            "location",
            "job_id",
            "statement_type",
            "state",
            "query_plan",
            "created",
            "started",
            "ended",
        )
        return cls(**{name: getattr(event, name) for name in copied_names})
190+
191+
192+
@dataclasses.dataclass(frozen=True)
class BigQueryFinishedEvent(ExecutionRunning):
    """Query finished successfully.

    Every field is optional and defaults to ``None``.
    """

    billing_project: Optional[str] = None
    location: Optional[str] = None
    query_id: Optional[str] = None
    job_id: Optional[str] = None
    destination: Optional[google.cloud.bigquery.table.TableReference] = None
    total_rows: Optional[int] = None
    total_bytes_processed: Optional[int] = None
    slot_millis: Optional[int] = None
    created: Optional[datetime.datetime] = None
    started: Optional[datetime.datetime] = None
    ended: Optional[datetime.datetime] = None

    @classmethod
    def from_bqclient(
        cls, event: google.cloud.bigquery._job_helpers.QueryFinishedEvent
    ):
        """Build an instance by copying the same-named attributes of ``event``."""
        copied_names = (
            "billing_project",
            "location",
            "query_id",
            "job_id",
            "destination",
            "total_rows",
            "total_bytes_processed",
            "slot_millis",
            "created",
            "started",
            "ended",
        )
        return cls(**{name: getattr(event, name) for name in copied_names})
225+
226+
227+
@dataclasses.dataclass(frozen=True)
class BigQueryUnknownEvent(ExecutionRunning):
    """Got unknown event from the BigQuery client library.

    The original client-library object is kept unchanged in ``event``.
    """

    # TODO: should we just skip sending unknown events?

    event: object

    @classmethod
    def from_bqclient(cls, event):
        """Wrap ``event`` as-is in a new instance."""
        return cls(event=event)

bigframes/core/global_session.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,19 @@
1414

1515
"""Utilities for managing a default, globally available Session object."""
1616

17+
from __future__ import annotations
18+
1719
import threading
1820
import traceback
19-
from typing import Callable, Optional, TypeVar
21+
from typing import Callable, Optional, TYPE_CHECKING, TypeVar
2022
import warnings
2123

2224
import google.auth.exceptions
2325

24-
import bigframes._config
2526
import bigframes.exceptions as bfe
26-
import bigframes.session
27+
28+
if TYPE_CHECKING:
29+
import bigframes.session
2730

2831
_global_session: Optional[bigframes.session.Session] = None
2932
_global_session_lock = threading.Lock()
@@ -56,6 +59,9 @@ def close_session() -> None:
5659
Returns:
5760
None
5861
"""
62+
# Avoid troubles with circular imports.
63+
import bigframes._config
64+
5965
global _global_session, _global_session_lock, _global_session_state
6066

6167
if bigframes._config.options.is_bigquery_thread_local:
@@ -88,6 +94,10 @@ def get_global_session():
8894
8995
Creates the global session if it does not exist.
9096
"""
97+
# Avoid troubles with circular imports.
98+
import bigframes._config
99+
import bigframes.session
100+
91101
global _global_session, _global_session_lock, _global_session_state
92102

93103
if bigframes._config.options.is_bigquery_thread_local:

0 commit comments

Comments
 (0)