Skip to content

Commit 23b769d

Browse files
committed
chore: Migrate euclidean_distance_op operator to SQLGlot
1 parent da9ba26 commit 23b769d

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,18 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
286286
return result
287287

288288

289+
@register_binary_op(ops.euclidean_distance_op)
290+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
291+
return sge.Anonymous(
292+
this="ML.DISTANCE",
293+
expressions=[
294+
left.expr,
295+
right.expr,
296+
sge.Literal.string("EUCLIDEAN"),
297+
],
298+
)
299+
300+
289301
@register_binary_op(ops.floordiv_op)
290302
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
291303
left_expr = _coerce_bool_to_int(left)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int_list_col` AS `bfcol_0`,
4+
`numeric_list_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`repeated_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
ML.DISTANCE(`bfcol_0`, `bfcol_0`, 'EUCLIDEAN') AS `bfcol_2`,
10+
ML.DISTANCE(`bfcol_1`, `bfcol_1`, 'EUCLIDEAN') AS `bfcol_3`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_2` AS `int_list_col`,
15+
`bfcol_3` AS `numeric_list_col`
16+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,21 @@ def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
286286
snapshot.assert_match(bf_df.sql, "out.sql")
287287

288288

289+
def test_euclidean_distance(repeated_types_df: bpd.DataFrame, snapshot):
290+
col_names = ["int_list_col", "numeric_list_col"]
291+
bf_df = repeated_types_df[col_names]
292+
293+
sql = utils._apply_ops_to_sql(
294+
bf_df,
295+
[
296+
ops.euclidean_distance_op.as_expr("int_list_col", "int_list_col"),
297+
ops.euclidean_distance_op.as_expr("numeric_list_col", "numeric_list_col"),
298+
],
299+
["int_list_col", "numeric_list_col"],
300+
)
301+
snapshot.assert_match(sql, "out.sql")
302+
303+
289304
def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot):
290305
bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col"]]
291306

0 commit comments

Comments
 (0)