Skip to content

Commit 26dddb2

Browse files
committed
address comments
1 parent 2bb5fdd commit 26dddb2

File tree

1 file changed

+58
-36
lines changed

1 file changed

+58
-36
lines changed

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 58 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,46 +33,15 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
3333
sg_expr = expr.expr
3434

3535
if to_type == dtypes.JSON_DTYPE:
36-
if from_type == dtypes.STRING_DTYPE:
37-
func_name = "PARSE_JSON_IN_SAFE" if op.safe else "PARSE_JSON"
38-
return sge.func(func_name, sg_expr)
39-
if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE):
40-
sg_expr = sge.Cast(this=sg_expr, to="STRING")
41-
return sge.func("PARSE_JSON", sg_expr)
42-
raise TypeError(f"Cannot cast from {from_type} to {to_type}")
36+
return _cast_to_json(expr, op)
4337

4438
if from_type == dtypes.JSON_DTYPE:
45-
func_name = ""
46-
if to_type == dtypes.INT_DTYPE:
47-
func_name = "INT64"
48-
elif to_type == dtypes.FLOAT_DTYPE:
49-
func_name = "FLOAT64"
50-
elif to_type == dtypes.BOOL_DTYPE:
51-
func_name = "BOOL"
52-
elif to_type == dtypes.STRING_DTYPE:
53-
func_name = "STRING"
54-
if func_name:
55-
func_name = "SAFE." + func_name if op.safe else func_name
56-
return sge.func(func_name, sg_expr)
57-
raise TypeError(f"Cannot cast from {from_type} to {to_type}")
39+
return _cast_from_json(expr, op)
5840

5941
if to_type == dtypes.INT_DTYPE:
60-
# Cannot cast DATETIME to INT directly so need to convert to TIMESTAMP first.
61-
if from_type == dtypes.DATETIME_DTYPE:
62-
sg_expr = _cast(sg_expr, "TIMESTAMP", op.safe)
63-
return sge.func("UNIX_MICROS", sg_expr)
64-
if from_type == dtypes.TIMESTAMP_DTYPE:
65-
return sge.func("UNIX_MICROS", sg_expr)
66-
if from_type == dtypes.TIME_DTYPE:
67-
return sge.func(
68-
"TIME_DIFF",
69-
_cast(sg_expr, "TIME", op.safe),
70-
sge.convert("00:00:00"),
71-
"MICROSECOND",
72-
)
73-
if from_type == dtypes.NUMERIC_DTYPE or from_type == dtypes.FLOAT_DTYPE:
74-
sg_expr = sge.func("TRUNC", sg_expr)
75-
return _cast(sg_expr, sg_to_type, op.safe)
42+
result = _cast_to_int(expr, op)
43+
if result is not None:
44+
return result
7645

7746
if to_type == dtypes.FLOAT_DTYPE and from_type == dtypes.BOOL_DTYPE:
7847
sg_expr = _cast(sg_expr, "INT64", op.safe)
@@ -124,6 +93,59 @@ def _(expr: TypedExpr) -> sge.Expression:
12493

12594

12695
# Helper functions
96+
def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
    """Build the expression that converts ``expr`` to JSON.

    Args:
        expr: The typed operand; ``expr.dtype`` selects the conversion and
            ``expr.expr`` is the expression being wrapped.
        op: The cast operation. ``op.safe`` is only consulted for STRING
            inputs.

    Returns:
        A ``PARSE_JSON`` call (``SAFE.PARSE_JSON`` for a safe cast from
        STRING) wrapping the operand.

    Raises:
        TypeError: If the source dtype is not STRING, INT, BOOL or FLOAT.
    """
    from_type = expr.dtype
    sg_expr = expr.expr

    if from_type == dtypes.STRING_DTYPE:
        # Use the "SAFE." prefix form, matching how _cast_from_json builds its
        # safe variants; "PARSE_JSON_IN_SAFE" is not a BigQuery function name.
        func_name = "SAFE.PARSE_JSON" if op.safe else "PARSE_JSON"
        return sge.func(func_name, sg_expr)

    if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE):
        # Scalars are rendered as a string first and then parsed as JSON.
        # op.safe is not consulted on this path.
        sg_expr = sge.Cast(this=sg_expr, to="STRING")
        return sge.func("PARSE_JSON", sg_expr)

    raise TypeError(f"Cannot cast from {from_type} to {dtypes.JSON_DTYPE}")
107+
108+
109+
def _cast_from_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
    """Build the expression that converts a JSON operand to ``op.to_type``.

    Args:
        expr: The typed JSON operand; only ``expr.expr`` is used.
        op: The cast operation. ``op.to_type`` picks the conversion function
            and ``op.safe`` adds the ``SAFE.`` prefix to it.

    Returns:
        A call to ``INT64``, ``FLOAT64``, ``BOOL`` or ``STRING`` (optionally
        ``SAFE.``-prefixed) applied to the operand.

    Raises:
        TypeError: If ``op.to_type`` is not INT, FLOAT, BOOL or STRING.
    """
    to_type = op.to_type
    json_operand = expr.expr

    # Ordered (target dtype, conversion function) pairs; compared with ``==``
    # in sequence so the first matching dtype wins.
    conversions = (
        (dtypes.INT_DTYPE, "INT64"),
        (dtypes.FLOAT_DTYPE, "FLOAT64"),
        (dtypes.BOOL_DTYPE, "BOOL"),
        (dtypes.STRING_DTYPE, "STRING"),
    )
    for candidate, converter in conversions:
        if to_type == candidate:
            if op.safe:
                converter = "SAFE." + converter
            return sge.func(converter, json_operand)

    raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}")
125+
126+
127+
def _cast_to_int(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression | None:
    """Build the integer conversion for source dtypes that need special handling.

    Args:
        expr: The typed operand; ``expr.dtype`` selects the conversion and
            ``expr.expr`` is the expression being wrapped.
        op: The cast operation; ``op.safe`` is forwarded to ``_cast``.

    Returns:
        The converted expression for DATETIME, TIMESTAMP, TIME, NUMERIC and
        FLOAT sources, or ``None`` when the source dtype has no special case
        here.
    """
    source_dtype = expr.dtype
    operand = expr.expr

    if source_dtype == dtypes.DATETIME_DTYPE:
        # Cannot cast DATETIME to INT directly, so convert to TIMESTAMP first.
        return sge.func("UNIX_MICROS", _cast(operand, "TIMESTAMP", op.safe))

    if source_dtype == dtypes.TIMESTAMP_DTYPE:
        return sge.func("UNIX_MICROS", operand)

    if source_dtype == dtypes.TIME_DTYPE:
        # Difference between the time and "00:00:00", in MICROSECOND units.
        as_time = _cast(operand, "TIME", op.safe)
        midnight = sge.convert("00:00:00")
        return sge.func("TIME_DIFF", as_time, midnight, "MICROSECOND")

    needs_trunc = (
        source_dtype == dtypes.NUMERIC_DTYPE or source_dtype == dtypes.FLOAT_DTYPE
    )
    if not needs_trunc:
        return None

    # Apply TRUNC to the operand before the INT64 cast.
    return _cast(sge.func("TRUNC", operand), "INT64", op.safe)
147+
148+
127149
def _cast(expr: sge.Expression, to: str, safe: bool):
128150
if safe:
129151
return sge.TryCast(this=expr, to=to)

0 commit comments

Comments
 (0)