From 00eee68f5921d1f83e4b77072d1a0d2503f7cea5 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 28 Oct 2025 18:21:33 +0000 Subject: [PATCH] refactor: add rowkey to the sqlglot compiler --- .../sqlglot/expressions/generic_ops.py | 53 ++++++++++++++ .../test_generic_ops/test_row_key/out.sql | 70 +++++++++++++++++++ .../sqlglot/expressions/test_generic_ops.py | 8 +++ 3 files changed, 131 insertions(+) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_row_key/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 7572a1e801..07505855e1 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -159,6 +159,30 @@ def _(*cases_and_outputs: TypedExpr) -> sge.Expression: ) +@register_nary_op(ops.RowKey) +def _(*values: TypedExpr) -> sge.Expression: + # All inputs into hash must be non-null or resulting hash will be null + str_values = [_convert_to_nonnull_string_sqlglot(value) for value in values] + + full_row_hash_p1 = sge.func("FARM_FINGERPRINT", sge.Concat(expressions=str_values)) + + # By modifying value slightly, we get another hash uncorrelated with the first + full_row_hash_p2 = sge.func( + "FARM_FINGERPRINT", sge.Concat(expressions=[*str_values, sge.convert("_")]) + ) + + # Used to disambiguate between identical rows (which will have identical hash) + random_hash_p3 = sge.func("RAND") + + return sge.Concat( + expressions=[ + sge.Cast(this=full_row_hash_p1, to="STRING"), + sge.Cast(this=full_row_hash_p2, to="STRING"), + sge.Cast(this=random_hash_p3, to="STRING"), + ] + ) + + # Helper functions def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: from_type = expr.dtype @@ -218,3 +242,32 @@ def _cast(expr: sge.Expression, to: str, safe: bool): return sge.TryCast(this=expr, to=to) else: return sge.Cast(this=expr, to=to) + + +def _convert_to_nonnull_string_sqlglot(expr: TypedExpr) -> sge.Expression: + col_type = expr.dtype + sg_expr = expr.expr + + if col_type == dtypes.STRING_DTYPE: + result = sg_expr + elif ( + dtypes.is_numeric(col_type) + or dtypes.is_time_or_date_like(col_type) + or col_type == dtypes.BYTES_DTYPE + ): + result = sge.Cast(this=sg_expr, to="STRING") + elif col_type == dtypes.GEO_DTYPE: + result = sge.func("ST_ASTEXT", sg_expr) + else: + # TO_JSON_STRING works with all data types, but isn't the most efficient + # Needed for JSON, STRUCT and ARRAY datatypes + result = sge.func("TO_JSON_STRING", sg_expr) + + # Escape backslashes and use backslash as delineator + escaped = sge.func( + "REPLACE", + sge.func("COALESCE", result, sge.convert("")), + sge.convert("\\"), + sge.convert("\\\\"), + ) + return sge.Concat(expressions=[sge.convert("\\"), escaped]) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_row_key/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_row_key/out.sql new file mode 100644 index 0000000000..080e35f68e --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_row_key/out.sql @@ -0,0 +1,70 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `bytes_col` AS `bfcol_1`, + `date_col` AS `bfcol_2`, + `datetime_col` AS `bfcol_3`, + `geography_col` AS `bfcol_4`, + `int64_col` AS `bfcol_5`, + `int64_too` AS `bfcol_6`, + `numeric_col` AS `bfcol_7`, + `float64_col` AS `bfcol_8`, + `rowindex` AS `bfcol_9`, + `rowindex_2` AS `bfcol_10`, + `string_col` AS `bfcol_11`, + `time_col` AS `bfcol_12`, + `timestamp_col` AS `bfcol_13`, + `duration_col` AS `bfcol_14` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CONCAT( + CAST(FARM_FINGERPRINT( + CONCAT( + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_9` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_0` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_1` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_2` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_3` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(ST_ASTEXT(`bfcol_4`), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_5` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_6` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_7` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_8` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_9` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_10` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(`bfcol_11`, ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_12` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_13` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_14` AS STRING), ''), '\\', '\\\\')) + ) + ) AS STRING), + CAST(FARM_FINGERPRINT( + CONCAT( + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_9` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_0` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_1` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_2` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_3` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(ST_ASTEXT(`bfcol_4`), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_5` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_6` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_7` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_8` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_9` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_10` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(`bfcol_11`, ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_12` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_13` AS STRING), ''), '\\', '\\\\')), + CONCAT('\\', REPLACE(COALESCE(CAST(`bfcol_14` AS STRING), ''), '\\', '\\\\')), + '_' + ) + ) AS STRING), + CAST(RAND() AS STRING) + ) AS `bfcol_31` + FROM `bfcte_0` +) +SELECT + `bfcol_31` AS `row_key` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py index 075416d664..fd9732bf89 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py @@ -261,6 +261,14 @@ def test_notnull(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_row_key(scalar_types_df: bpd.DataFrame, snapshot): + column_ids = (col for col in scalar_types_df._block.expr.column_ids) + sql = utils._apply_unary_ops( + scalar_types_df, [ops.RowKey().as_expr(*column_ids)], ["row_key"] + ) + snapshot.assert_match(sql, "out.sql") + + def test_sql_scalar_op(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["bool_col", "bytes_col"]] sql = utils._apply_nary_op(