Skip to content

Commit 0d793f9

Browse files
feat: Implement AI.GENERATE_EMBEDDING wrapper
This change implements the `bigframes.bigquery.ai.generate_embedding` function, which wraps the BigQuery `AI.GENERATE_EMBEDDING` TVF. It supports: - Generating embeddings from DataFrames and Series. - Generating embeddings from pandas DataFrames and Series. - Specifying model name and arguments like `output_dimensionality`, `start_second`, `end_second`, and `interval_seconds`. The function is exposed in `bigframes.bigquery.ai`. Unit tests have been added to verify the generated SQL and argument mapping.
1 parent 173b83d commit 0d793f9

File tree

2 files changed

+221
-1
lines changed

2 files changed

+221
-1
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from __future__ import annotations
2020

2121
import json
22-
from typing import Any, Iterable, List, Literal, Mapping, Tuple, Union
22+
from typing import Any, Iterable, List, Literal, Mapping, Optional, Tuple, Union
2323

2424
import pandas as pd
2525

@@ -387,6 +387,91 @@ def generate_double(
387387
return series_list[0]._apply_nary_op(operator, series_list[1:])
388388

389389

390+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
391+
def generate_embedding(
392+
model_name: str,
393+
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
394+
*,
395+
output_dimensionality: Optional[int] = None,
396+
start_second: Optional[float] = None,
397+
end_second: Optional[float] = None,
398+
interval_seconds: Optional[float] = None,
399+
) -> dataframe.DataFrame:
400+
"""
401+
Creates embeddings that describe an entity—for example, a piece of text or an image.
402+
403+
**Examples:**
404+
405+
>>> import bigframes.pandas as bpd
406+
>>> import bigframes.bigquery as bbq
407+
>>> df = bpd.DataFrame({"content": ["apple", "bear", "pear"]})
408+
>>> bbq.ai.generate_embedding(
409+
... "project.dataset.model_name",
410+
... df
411+
... ) # doctest: +SKIP
412+
413+
Args:
414+
model_name (str):
415+
The name of a remote model over a Vertex AI multimodalembedding@001 model.
416+
data (DataFrame or Series):
417+
The data to generate embeddings for. If a Series is provided, it is treated as the 'content' column.
418+
If a DataFrame is provided, it must contain a 'content' column, or you must rename the column you wish to embed to 'content'.
419+
output_dimensionality (int, optional):
420+
The number of dimensions to use when generating embeddings. Valid values are 128, 256, 512, and 1408. The default value is 1408.
421+
start_second (float, optional):
422+
The second in the video at which to start the embedding. The default value is 0.
423+
end_second (float, optional):
424+
The second in the video at which to end the embedding. The default value is 120.
425+
interval_seconds (float, optional):
426+
The interval to use when creating embeddings. The default value is 16.
427+
428+
Returns:
429+
bigframes.dataframe.DataFrame:
430+
A new DataFrame with the generated embeddings. It contains the input table columns and the following columns:
431+
* "embedding": an ARRAY<FLOAT64> value that contains the generated embedding vector.
432+
* "status": a STRING value that contains the API response status for the corresponding row.
433+
* "video_start_sec": for video content, an INT64 value that contains the starting second.
434+
* "video_end_sec": for video content, an INT64 value that contains the ending second.
435+
"""
436+
if isinstance(data, (pd.DataFrame, pd.Series)):
437+
data = bpd.read_pandas(data)
438+
439+
if isinstance(data, series.Series):
440+
# Rename series to 'content' and convert to DataFrame
441+
data_df = data.rename("content").to_frame()
442+
elif isinstance(data, dataframe.DataFrame):
443+
data_df = data
444+
else:
445+
raise ValueError(f"Unsupported data type: {type(data)}")
446+
447+
# We need to get the SQL for the input data to pass as a subquery to the TVF
448+
source_sql = data_df.sql
449+
450+
struct_fields = []
451+
if output_dimensionality is not None:
452+
struct_fields.append(f"{output_dimensionality} AS output_dimensionality")
453+
if start_second is not None:
454+
struct_fields.append(f"{start_second} AS start_second")
455+
if end_second is not None:
456+
struct_fields.append(f"{end_second} AS end_second")
457+
if interval_seconds is not None:
458+
struct_fields.append(f"{interval_seconds} AS interval_seconds")
459+
460+
struct_args = ", ".join(struct_fields)
461+
462+
# Construct the TVF query
463+
query = f"""
464+
SELECT *
465+
FROM AI.GENERATE_EMBEDDING(
466+
MODEL `{model_name}`,
467+
({source_sql}),
468+
STRUCT({struct_args})
469+
)
470+
"""
471+
472+
return data_df._session.read_gbq(query)
473+
474+
390475
@log_adapter.method_logger(custom_base_name="bigquery_ai")
391476
def if_(
392477
prompt: PROMPT_TYPE,

tests/unit/bigquery/test_ai.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import mock
16+
17+
import pandas as pd
18+
import pytest
19+
20+
import bigframes.bigquery._operations.ai as ai_ops
21+
import bigframes.dataframe
22+
import bigframes.series
23+
import bigframes.session
24+
25+
26+
@pytest.fixture
27+
def mock_session():
28+
return mock.create_autospec(spec=bigframes.session.Session)
29+
30+
31+
@pytest.fixture
32+
def mock_dataframe(mock_session):
33+
df = mock.create_autospec(spec=bigframes.dataframe.DataFrame)
34+
df._session = mock_session
35+
df.sql = "SELECT * FROM my_table"
36+
return df
37+
38+
39+
@pytest.fixture
40+
def mock_series(mock_session):
41+
s = mock.create_autospec(spec=bigframes.series.Series)
42+
s._session = mock_session
43+
# Mock to_frame to return a mock dataframe
44+
df = mock.create_autospec(spec=bigframes.dataframe.DataFrame)
45+
df._session = mock_session
46+
df.sql = "SELECT my_col AS content FROM my_table"
47+
s.rename.return_value.to_frame.return_value = df
48+
return s
49+
50+
51+
def test_generate_embedding_with_dataframe(mock_dataframe, mock_session):
52+
model_name = "project.dataset.model"
53+
54+
ai_ops.generate_embedding(
55+
model_name,
56+
mock_dataframe,
57+
output_dimensionality=256,
58+
)
59+
60+
mock_session.read_gbq.assert_called_once()
61+
query = mock_session.read_gbq.call_args[0][0]
62+
63+
# Normalize whitespace for comparison
64+
query = " ".join(query.split())
65+
66+
expected_part_1 = "SELECT * FROM AI.GENERATE_EMBEDDING("
67+
expected_part_2 = f"MODEL `{model_name}`,"
68+
expected_part_3 = "(SELECT * FROM my_table),"
69+
expected_part_4 = "STRUCT(256 AS output_dimensionality)"
70+
71+
assert expected_part_1 in query
72+
assert expected_part_2 in query
73+
assert expected_part_3 in query
74+
assert expected_part_4 in query
75+
76+
77+
def test_generate_embedding_with_series(mock_series, mock_session):
78+
model_name = "project.dataset.model"
79+
80+
ai_ops.generate_embedding(
81+
model_name,
82+
mock_series,
83+
start_second=0.0,
84+
end_second=10.0,
85+
interval_seconds=5.0
86+
)
87+
88+
mock_series.rename.assert_called_with("content")
89+
mock_series.rename.return_value.to_frame.assert_called_once()
90+
91+
mock_session.read_gbq.assert_called_once()
92+
query = mock_session.read_gbq.call_args[0][0]
93+
query = " ".join(query.split())
94+
95+
assert f"MODEL `{model_name}`" in query
96+
assert "(SELECT my_col AS content FROM my_table)" in query
97+
assert "STRUCT(0.0 AS start_second, 10.0 AS end_second, 5.0 AS interval_seconds)" in query
98+
99+
100+
def test_generate_embedding_defaults(mock_dataframe, mock_session):
101+
model_name = "project.dataset.model"
102+
103+
ai_ops.generate_embedding(
104+
model_name,
105+
mock_dataframe,
106+
)
107+
108+
mock_session.read_gbq.assert_called_once()
109+
query = mock_session.read_gbq.call_args[0][0]
110+
query = " ".join(query.split())
111+
112+
assert f"MODEL `{model_name}`" in query
113+
assert "STRUCT()" in query
114+
115+
116+
@mock.patch("bigframes.pandas.read_pandas")
117+
def test_generate_embedding_with_pandas_dataframe(read_pandas_mock, mock_dataframe, mock_session):
118+
# This tests that pandas input path works and calls read_pandas
119+
model_name = "project.dataset.model"
120+
121+
# Mock return value of read_pandas to be a BigFrames DataFrame
122+
read_pandas_mock.return_value = mock_dataframe
123+
124+
pandas_df = pd.DataFrame({"content": ["test"]})
125+
126+
ai_ops.generate_embedding(
127+
model_name,
128+
pandas_df,
129+
)
130+
131+
read_pandas_mock.assert_called_once()
132+
# Check that read_pandas was called with something (the pandas df)
133+
assert read_pandas_mock.call_args[0][0] is pandas_df
134+
135+
mock_session.read_gbq.assert_called_once()

0 commit comments

Comments
 (0)