Skip to content

Commit 00eee68

Browse files
committed
refactor: add rowkey to the sqlglot compiler
1 parent 4c98c95 commit 00eee68

File tree

3 files changed

+131
-0
lines changed

3 files changed

+131
-0
lines changed

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,30 @@ def _(*cases_and_outputs: TypedExpr) -> sge.Expression:
159159
)
160160

161161

162+
@register_nary_op(ops.RowKey)
163+
def _(*values: TypedExpr) -> sge.Expression:
164+
# All inputs into hash must be non-null or resulting hash will be null
165+
str_values = [_convert_to_nonnull_string_sqlglot(value) for value in values]
166+
167+
full_row_hash_p1 = sge.func("FARM_FINGERPRINT", sge.Concat(expressions=str_values))
168+
169+
# By modifying value slightly, we get another hash uncorrelated with the first
170+
full_row_hash_p2 = sge.func(
171+
"FARM_FINGERPRINT", sge.Concat(expressions=[*str_values, sge.convert("_")])
172+
)
173+
174+
# Used to disambiguate between identical rows (which will have identical hash)
175+
random_hash_p3 = sge.func("RAND")
176+
177+
return sge.Concat(
178+
expressions=[
179+
sge.Cast(this=full_row_hash_p1, to="STRING"),
180+
sge.Cast(this=full_row_hash_p2, to="STRING"),
181+
sge.Cast(this=random_hash_p3, to="STRING"),
182+
]
183+
)
184+
185+
162186
# Helper functions
163187
def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
164188
from_type = expr.dtype
@@ -218,3 +242,32 @@ def _cast(expr: sge.Expression, to: str, safe: bool):
218242
return sge.TryCast(this=expr, to=to)
219243
else:
220244
return sge.Cast(this=expr, to=to)
245+
246+
247+
def _convert_to_nonnull_string_sqlglot(expr: TypedExpr) -> sge.Expression:
248+
col_type = expr.dtype
249+
sg_expr = expr.expr
250+
251+
if col_type == dtypes.STRING_DTYPE:
252+
result = sg_expr
253+
elif (
254+
dtypes.is_numeric(col_type)
255+
or dtypes.is_time_or_date_like(col_type)
256+
or col_type == dtypes.BYTES_DTYPE
257+
):
258+
result = sge.Cast(this=sg_expr, to="STRING")
259+
elif col_type == dtypes.GEO_DTYPE:
260+
result = sge.func("ST_ASTEXT", sg_expr)
261+
else:
262+
# TO_JSON_STRING works with all data types, but isn't the most efficient
263+
# Needed for JSON, STRUCT and ARRAY datatypes
264+
result = sge.func("TO_JSON_STRING", sg_expr)
265+
266+
# Escape backslashes and use backslash as delineator
267+
escaped = sge.func(
268+
"REPLACE",
269+
sge.func("COALESCE", result, sge.convert("")),
270+
sge.convert("\\"),
271+
sge.convert("\\\\"),
272+
)
273+
return sge.Concat(expressions=[sge.convert("\\"), escaped])
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`bytes_col` AS `bfcol_1`,
5+
`date_col` AS `bfcol_2`,
6+
`datetime_col` AS `bfcol_3`,
7+
`geography_col` AS `bfcol_4`,
8+
`int64_col` AS `bfcol_5`,
9+
`int64_too` AS `bfcol_6`,
10+
`numeric_col` AS `bfcol_7`,
11+
`float64_col` AS `bfcol_8`,
12+
`rowindex` AS `bfcol_9`,
13+
`rowindex_2` AS `bfcol_10`,
14+
`string_col` AS `bfcol_11`,
15+
`time_col` AS `bfcol_12`,
16+
`timestamp_col` AS `bfcol_13`,
17+
`duration_col` AS `bfcol_14`
18+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
19+
), `bfcte_1` AS (
20+
SELECT
21+
*,
22+
CONCAT(
23+
CAST(FARM_FINGERPRINT(
24+
CONCAT(
25+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_9` AS STRING), ''), '\\', '\\\\')),
26+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_0` AS STRING), ''), '\\', '\\\\')),
27+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_1` AS STRING), ''), '\\', '\\\\')),
28+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_2` AS STRING), ''), '\\', '\\\\')),
29+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_3` AS STRING), ''), '\\', '\\\\')),
30+
CONCAT('\\', REPLACE(COALESCE(ST_ASTEXT(`bfcol_4`), ''), '\\', '\\\\')),
31+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_5` AS STRING), ''), '\\', '\\\\')),
32+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_6` AS STRING), ''), '\\', '\\\\')),
33+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_7` AS STRING), ''), '\\', '\\\\')),
34+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_8` AS STRING), ''), '\\', '\\\\')),
35+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_9` AS STRING), ''), '\\', '\\\\')),
36+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_10` AS STRING), ''), '\\', '\\\\')),
37+
CONCAT('\\', REPLACE(COALESCE(`bfcol_11`, ''), '\\', '\\\\')),
38+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_12` AS STRING), ''), '\\', '\\\\')),
39+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_13` AS STRING), ''), '\\', '\\\\')),
40+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_14` AS STRING), ''), '\\', '\\\\'))
41+
)
42+
) AS STRING),
43+
CAST(FARM_FINGERPRINT(
44+
CONCAT(
45+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_9` AS STRING), ''), '\\', '\\\\')),
46+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_0` AS STRING), ''), '\\', '\\\\')),
47+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_1` AS STRING), ''), '\\', '\\\\')),
48+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_2` AS STRING), ''), '\\', '\\\\')),
49+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_3` AS STRING), ''), '\\', '\\\\')),
50+
CONCAT('\\', REPLACE(COALESCE(ST_ASTEXT(`bfcol_4`), ''), '\\', '\\\\')),
51+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_5` AS STRING), ''), '\\', '\\\\')),
52+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_6` AS STRING), ''), '\\', '\\\\')),
53+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_7` AS STRING), ''), '\\', '\\\\')),
54+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_8` AS STRING), ''), '\\', '\\\\')),
55+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_9` AS STRING), ''), '\\', '\\\\')),
56+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_10` AS STRING), ''), '\\', '\\\\')),
57+
CONCAT('\\', REPLACE(COALESCE(`bfcol_11`, ''), '\\', '\\\\')),
58+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_12` AS STRING), ''), '\\', '\\\\')),
59+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_13` AS STRING), ''), '\\', '\\\\')),
60+
CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_14` AS STRING), ''), '\\', '\\\\')),
61+
'_'
62+
)
63+
) AS STRING),
64+
CAST(RAND() AS STRING)
65+
) AS `bfcol_31`
66+
FROM `bfcte_0`
67+
)
68+
SELECT
69+
`bfcol_31` AS `row_key`
70+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,14 @@ def test_notnull(scalar_types_df: bpd.DataFrame, snapshot):
261261
snapshot.assert_match(sql, "out.sql")
262262

263263

264+
def test_row_key(scalar_types_df: bpd.DataFrame, snapshot):
265+
column_ids = (col for col in scalar_types_df._block.expr.column_ids)
266+
sql = utils._apply_unary_ops(
267+
scalar_types_df, [ops.RowKey().as_expr(*column_ids)], ["row_key"]
268+
)
269+
snapshot.assert_match(sql, "out.sql")
270+
271+
264272
def test_sql_scalar_op(scalar_types_df: bpd.DataFrame, snapshot):
265273
bf_df = scalar_types_df[["bool_col", "bytes_col"]]
266274
sql = utils._apply_nary_op(

0 commit comments

Comments
 (0)