From cf62489ddf164f66d0090faa1ed713d347d99ab5 Mon Sep 17 00:00:00 2001 From: jialuo Date: Thu, 6 Nov 2025 21:18:37 +0000 Subject: [PATCH] chore: Migrate cosine_distance_op operator to SQLGlot --- .../compile/sqlglot/expressions/numeric_ops.py | 12 ++++++++++++ .../test_cosine_distance/out.sql | 16 ++++++++++++++++ .../sqlglot/expressions/test_numeric_ops.py | 15 +++++++++++++++ 3 files changed, 43 insertions(+) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_cosine_distance/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index afc0d9d01c..6492ca5683 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -118,6 +118,18 @@ def _(expr: TypedExpr) -> sge.Expression: ) +@register_binary_op(ops.cosine_distance_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + return sge.Anonymous( + this="ML.DISTANCE", + expressions=[ + left.expr, + right.expr, + sge.Literal.string("COSINE"), + ], + ) + + @register_unary_op(ops.exp_op) def _(expr: TypedExpr) -> sge.Expression: return sge.Case( diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_cosine_distance/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_cosine_distance/out.sql new file mode 100644 index 0000000000..eb46a16a83 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_cosine_distance/out.sql @@ -0,0 +1,16 @@ +WITH `bfcte_0` AS ( + SELECT + `int_list_col` AS `bfcol_0`, + `float_list_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`repeated_types` +), `bfcte_1` AS ( + SELECT + *, + ML.DISTANCE(`bfcol_0`, `bfcol_0`, 'COSINE') AS `bfcol_2`, + ML.DISTANCE(`bfcol_1`, `bfcol_1`, 'COSINE') AS `bfcol_3` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `int_list_col`, + `bfcol_3` AS `float_list_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py index c66fe15c16..09c8326f7f 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py @@ -103,6 +103,21 @@ def test_cosh(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_cosine_distance(repeated_types_df: bpd.DataFrame, snapshot): + col_names = ["int_list_col", "float_list_col"] + bf_df = repeated_types_df[col_names] + + sql = utils._apply_ops_to_sql( + bf_df, + [ + ops.cosine_distance_op.as_expr("int_list_col", "int_list_col"), + ops.cosine_distance_op.as_expr("float_list_col", "float_list_col"), + ], + ["int_list_col", "float_list_col"], + ) + snapshot.assert_match(sql, "out.sql") + + def test_exp(scalar_types_df: bpd.DataFrame, snapshot): col_name = "float64_col" bf_df = scalar_types_df[[col_name]]