Skip to content

Commit cf62489

Browse files
committed
chore: Migrate cosine_distance_op operator to SQLGlot
1 parent 5663d2a commit cf62489

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,18 @@ def _(expr: TypedExpr) -> sge.Expression:
118118
)
119119

120120

121+
@register_binary_op(ops.cosine_distance_op)
122+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
123+
return sge.Anonymous(
124+
this="ML.DISTANCE",
125+
expressions=[
126+
left.expr,
127+
right.expr,
128+
sge.Literal.string("COSINE"),
129+
],
130+
)
131+
132+
121133
@register_unary_op(ops.exp_op)
122134
def _(expr: TypedExpr) -> sge.Expression:
123135
return sge.Case(
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int_list_col` AS `bfcol_0`,
4+
`float_list_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`repeated_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
ML.DISTANCE(`bfcol_0`, `bfcol_0`, 'COSINE') AS `bfcol_2`,
10+
ML.DISTANCE(`bfcol_1`, `bfcol_1`, 'COSINE') AS `bfcol_3`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_2` AS `int_list_col`,
15+
`bfcol_3` AS `float_list_col`
16+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,21 @@ def test_cosh(scalar_types_df: bpd.DataFrame, snapshot):
103103
snapshot.assert_match(sql, "out.sql")
104104

105105

106+
def test_cosine_distance(repeated_types_df: bpd.DataFrame, snapshot):
107+
col_names = ["int_list_col", "float_list_col"]
108+
bf_df = repeated_types_df[col_names]
109+
110+
sql = utils._apply_ops_to_sql(
111+
bf_df,
112+
[
113+
ops.cosine_distance_op.as_expr("int_list_col", "int_list_col"),
114+
ops.cosine_distance_op.as_expr("float_list_col", "float_list_col"),
115+
],
116+
["int_list_col", "float_list_col"],
117+
)
118+
snapshot.assert_match(sql, "out.sql")
119+
120+
106121
def test_exp(scalar_types_df: bpd.DataFrame, snapshot):
107122
col_name = "float64_col"
108123
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)