diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index 36e2973565..e33702b08c 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -305,6 +305,18 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: return result +@register_binary_op(ops.euclidean_distance_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + return sge.Anonymous( + this="ML.DISTANCE", + expressions=[ + left.expr, + right.expr, + sge.Literal.string("EUCLIDEAN"), + ], + ) + + @register_binary_op(ops.floordiv_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = _coerce_bool_to_int(left) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_euclidean_distance/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_euclidean_distance/out.sql new file mode 100644 index 0000000000..3327a99f4b --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_euclidean_distance/out.sql @@ -0,0 +1,16 @@ +WITH `bfcte_0` AS ( + SELECT + `int_list_col`, + `numeric_list_col` + FROM `bigframes-dev`.`sqlglot_test`.`repeated_types` +), `bfcte_1` AS ( + SELECT + *, + ML.DISTANCE(`int_list_col`, `int_list_col`, 'EUCLIDEAN') AS `bfcol_2`, + ML.DISTANCE(`numeric_list_col`, `numeric_list_col`, 'EUCLIDEAN') AS `bfcol_3` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `int_list_col`, + `bfcol_3` AS `numeric_list_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py index 06731bcbfa..c58ce9e2f1 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py @@ -315,6 +315,21 @@ def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(bf_df.sql, "out.sql") +def test_euclidean_distance(repeated_types_df: bpd.DataFrame, snapshot): + col_names = ["int_list_col", "numeric_list_col"] + bf_df = repeated_types_df[col_names] + + sql = utils._apply_ops_to_sql( + bf_df, + [ + ops.euclidean_distance_op.as_expr("int_list_col", "int_list_col"), + ops.euclidean_distance_op.as_expr("numeric_list_col", "numeric_list_col"), + ], + ["int_list_col", "numeric_list_col"], + ) + snapshot.assert_match(sql, "out.sql") + + def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col"]]