Skip to content

Commit 5fbc3fa

Browse files
committed
chore: Migrate manhattan_distance_op operator to SQLGlot
1 parent 64995d6 commit 5fbc3fa

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,18 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
338338
return result
339339

340340

341+
@register_binary_op(ops.manhattan_distance_op)
342+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
343+
return sge.Anonymous(
344+
this="ML.DISTANCE",
345+
expressions=[
346+
left.expr,
347+
right.expr,
348+
sge.Literal.string("MANHATTAN"),
349+
],
350+
)
351+
352+
341353
@register_binary_op(ops.mod_op)
342354
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
343355
# In BigQuery returned value has the same sign as X. In pandas, the sign of y is used, so we need to flip the result if sign(x) != sign(y)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`float_list_col`,
4+
`numeric_list_col`
5+
FROM `bigframes-dev`.`sqlglot_test`.`repeated_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
ML.DISTANCE(`float_list_col`, `float_list_col`, 'MANHATTAN') AS `bfcol_2`,
10+
ML.DISTANCE(`numeric_list_col`, `numeric_list_col`, 'MANHATTAN') AS `bfcol_3`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_2` AS `float_list_col`,
15+
`bfcol_3` AS `numeric_list_col`
16+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,21 @@ def test_floordiv_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
339339
snapshot.assert_match(bf_df.sql, "out.sql")
340340

341341

342+
def test_manhattan_distance(repeated_types_df: bpd.DataFrame, snapshot):
343+
col_names = ["float_list_col", "numeric_list_col"]
344+
bf_df = repeated_types_df[col_names]
345+
346+
sql = utils._apply_ops_to_sql(
347+
bf_df,
348+
[
349+
ops.manhattan_distance_op.as_expr("float_list_col", "float_list_col"),
350+
ops.manhattan_distance_op.as_expr("numeric_list_col", "numeric_list_col"),
351+
],
352+
["float_list_col", "numeric_list_col"],
353+
)
354+
snapshot.assert_match(sql, "out.sql")
355+
356+
342357
def test_mul_numeric(scalar_types_df: bpd.DataFrame, snapshot):
343358
bf_df = scalar_types_df[["int64_col", "bool_col"]]
344359

0 commit comments

Comments
 (0)