|
14 | 14 |
|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
| 17 | +import sqlglot as sg |
17 | 18 | import sqlglot.expressions as sge |
18 | 19 |
|
19 | 20 | from bigframes import dtypes |
@@ -80,6 +81,16 @@ def _(expr: TypedExpr) -> sge.Expression: |
80 | 81 | return sge.BitwiseNot(this=sge.paren(expr.expr)) |
81 | 82 |
|
82 | 83 |
|
@register_nary_op(ops.SqlScalarOp, pass_op=True)
def _(*operands: TypedExpr, op: ops.SqlScalarOp) -> sge.Expression:
    """Compile a SqlScalarOp by filling in its SQL template and re-parsing it.

    Every operand expression is rendered to BigQuery SQL text, the rendered
    strings are passed as positional arguments to ``op.sql_template.format``,
    and the resulting SQL string is parsed back into a sqlglot expression
    with the BigQuery dialect.
    """
    rendered_operands = [operand.expr.sql(dialect="bigquery") for operand in operands]
    filled_sql = op.sql_template.format(*rendered_operands)
    return sg.parse_one(filled_sql, dialect="bigquery")
| 92 | + |
| 93 | + |
@register_unary_op(ops.isnull_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile isnull_op into an ``IS NULL`` test on the operand expression."""
    operand = expr.expr
    return sge.Is(this=operand, expression=sge.Null())
@@ -148,6 +159,30 @@ def _(*cases_and_outputs: TypedExpr) -> sge.Expression: |
148 | 159 | ) |
149 | 160 |
|
150 | 161 |
|
@register_nary_op(ops.RowKey)
def _(*values: TypedExpr) -> sge.Expression:
    """Compile RowKey into a string key built from two row hashes and RAND().

    The result is the concatenation of three values, each cast to STRING:
    FARM_FINGERPRINT of the concatenated inputs, FARM_FINGERPRINT of the same
    concatenation with a trailing "_", and a RAND() call.
    """
    # Hash inputs must be non-null, otherwise the resulting hash is null.
    non_null_parts = [_convert_to_nonnull_string_sqlglot(value) for value in values]

    first_hash = sge.func("FARM_FINGERPRINT", sge.Concat(expressions=non_null_parts))

    # Appending a suffix to the hashed text yields a second hash that is
    # uncorrelated with the first one.
    salted_text = sge.Concat(expressions=[*non_null_parts, sge.convert("_")])
    second_hash = sge.func("FARM_FINGERPRINT", salted_text)

    # Identical rows hash identically; the random part tells them apart.
    random_part = sge.func("RAND")

    key_parts = [
        sge.Cast(this=part, to="STRING")
        for part in (first_hash, second_hash, random_part)
    ]
    return sge.Concat(expressions=key_parts)
| 184 | + |
| 185 | + |
151 | 186 | # Helper functions |
152 | 187 | def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: |
153 | 188 | from_type = expr.dtype |
@@ -207,3 +242,32 @@ def _cast(expr: sge.Expression, to: str, safe: bool): |
207 | 242 | return sge.TryCast(this=expr, to=to) |
208 | 243 | else: |
209 | 244 | return sge.Cast(this=expr, to=to) |
| 245 | + |
| 246 | + |
def _convert_to_nonnull_string_sqlglot(expr: TypedExpr) -> sge.Expression:
    """Build a non-null, backslash-prefixed string expression for a value.

    The value is first turned into a string according to its dtype (used
    as-is for strings, CAST for numeric/temporal/bytes values, ST_ASTEXT for
    geography, TO_JSON_STRING otherwise). Nulls are then replaced with the
    empty string, backslashes are doubled, and a single backslash is
    prepended as a delimiter.
    """
    dtype = expr.dtype
    value = expr.expr

    if dtype == dtypes.STRING_DTYPE:
        as_string = value
    elif (
        dtypes.is_numeric(dtype)
        or dtypes.is_time_or_date_like(dtype)
        or dtype == dtypes.BYTES_DTYPE
    ):
        as_string = sge.Cast(this=value, to="STRING")
    elif dtype == dtypes.GEO_DTYPE:
        as_string = sge.func("ST_ASTEXT", value)
    else:
        # TO_JSON_STRING accepts every data type but is not the most
        # efficient choice; it is needed for JSON, STRUCT and ARRAY values.
        as_string = sge.func("TO_JSON_STRING", value)

    # Substitute "" for NULL, then double every backslash so the single
    # leading backslash added below can act as a delimiter.
    null_free = sge.func("COALESCE", as_string, sge.convert(""))
    single_backslash = sge.convert("\\")
    double_backslash = sge.convert("\\\\")
    escaped_text = sge.func("REPLACE", null_free, single_backslash, double_backslash)

    delimiter = sge.convert("\\")
    return sge.Concat(expressions=[delimiter, escaped_text])
0 commit comments