@@ -33,46 +33,15 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
3333 sg_expr = expr .expr
3434
3535 if to_type == dtypes .JSON_DTYPE :
36- if from_type == dtypes .STRING_DTYPE :
37- func_name = "PARSE_JSON_IN_SAFE" if op .safe else "PARSE_JSON"
38- return sge .func (func_name , sg_expr )
39- if from_type in (dtypes .INT_DTYPE , dtypes .BOOL_DTYPE , dtypes .FLOAT_DTYPE ):
40- sg_expr = sge .Cast (this = sg_expr , to = "STRING" )
41- return sge .func ("PARSE_JSON" , sg_expr )
42- raise TypeError (f"Cannot cast from { from_type } to { to_type } " )
36+ return _cast_to_json (expr , op )
4337
4438 if from_type == dtypes .JSON_DTYPE :
45- func_name = ""
46- if to_type == dtypes .INT_DTYPE :
47- func_name = "INT64"
48- elif to_type == dtypes .FLOAT_DTYPE :
49- func_name = "FLOAT64"
50- elif to_type == dtypes .BOOL_DTYPE :
51- func_name = "BOOL"
52- elif to_type == dtypes .STRING_DTYPE :
53- func_name = "STRING"
54- if func_name :
55- func_name = "SAFE." + func_name if op .safe else func_name
56- return sge .func (func_name , sg_expr )
57- raise TypeError (f"Cannot cast from { from_type } to { to_type } " )
39+ return _cast_from_json (expr , op )
5840
5941 if to_type == dtypes .INT_DTYPE :
60- # Cannot cast DATETIME to INT directly so need to convert to TIMESTAMP first.
61- if from_type == dtypes .DATETIME_DTYPE :
62- sg_expr = _cast (sg_expr , "TIMESTAMP" , op .safe )
63- return sge .func ("UNIX_MICROS" , sg_expr )
64- if from_type == dtypes .TIMESTAMP_DTYPE :
65- return sge .func ("UNIX_MICROS" , sg_expr )
66- if from_type == dtypes .TIME_DTYPE :
67- return sge .func (
68- "TIME_DIFF" ,
69- _cast (sg_expr , "TIME" , op .safe ),
70- sge .convert ("00:00:00" ),
71- "MICROSECOND" ,
72- )
73- if from_type == dtypes .NUMERIC_DTYPE or from_type == dtypes .FLOAT_DTYPE :
74- sg_expr = sge .func ("TRUNC" , sg_expr )
75- return _cast (sg_expr , sg_to_type , op .safe )
42+ result = _cast_to_int (expr , op )
43+ if result is not None :
44+ return result
7645
7746 if to_type == dtypes .FLOAT_DTYPE and from_type == dtypes .BOOL_DTYPE :
7847 sg_expr = _cast (sg_expr , "INT64" , op .safe )
@@ -124,6 +93,59 @@ def _(expr: TypedExpr) -> sge.Expression:
12493
12594
12695# Helper functions
96+ def _cast_to_json (expr : TypedExpr , op : ops .AsTypeOp ) -> sge .Expression :
97+ from_type = expr .dtype
98+ sg_expr = expr .expr
99+
100+ if from_type == dtypes .STRING_DTYPE :
101+ func_name = "PARSE_JSON_IN_SAFE" if op .safe else "PARSE_JSON"
102+ return sge .func (func_name , sg_expr )
103+ if from_type in (dtypes .INT_DTYPE , dtypes .BOOL_DTYPE , dtypes .FLOAT_DTYPE ):
104+ sg_expr = sge .Cast (this = sg_expr , to = "STRING" )
105+ return sge .func ("PARSE_JSON" , sg_expr )
106+ raise TypeError (f"Cannot cast from { from_type } to { dtypes .JSON_DTYPE } " )
107+
108+
109+ def _cast_from_json (expr : TypedExpr , op : ops .AsTypeOp ) -> sge .Expression :
110+ to_type = op .to_type
111+ sg_expr = expr .expr
112+ func_name = ""
113+ if to_type == dtypes .INT_DTYPE :
114+ func_name = "INT64"
115+ elif to_type == dtypes .FLOAT_DTYPE :
116+ func_name = "FLOAT64"
117+ elif to_type == dtypes .BOOL_DTYPE :
118+ func_name = "BOOL"
119+ elif to_type == dtypes .STRING_DTYPE :
120+ func_name = "STRING"
121+ if func_name :
122+ func_name = "SAFE." + func_name if op .safe else func_name
123+ return sge .func (func_name , sg_expr )
124+ raise TypeError (f"Cannot cast from { dtypes .JSON_DTYPE } to { to_type } " )
125+
126+
127+ def _cast_to_int (expr : TypedExpr , op : ops .AsTypeOp ) -> sge .Expression | None :
128+ from_type = expr .dtype
129+ sg_expr = expr .expr
130+ # Cannot cast DATETIME to INT directly so need to convert to TIMESTAMP first.
131+ if from_type == dtypes .DATETIME_DTYPE :
132+ sg_expr = _cast (sg_expr , "TIMESTAMP" , op .safe )
133+ return sge .func ("UNIX_MICROS" , sg_expr )
134+ if from_type == dtypes .TIMESTAMP_DTYPE :
135+ return sge .func ("UNIX_MICROS" , sg_expr )
136+ if from_type == dtypes .TIME_DTYPE :
137+ return sge .func (
138+ "TIME_DIFF" ,
139+ _cast (sg_expr , "TIME" , op .safe ),
140+ sge .convert ("00:00:00" ),
141+ "MICROSECOND" ,
142+ )
143+ if from_type == dtypes .NUMERIC_DTYPE or from_type == dtypes .FLOAT_DTYPE :
144+ sg_expr = sge .func ("TRUNC" , sg_expr )
145+ return _cast (sg_expr , "INT64" , op .safe )
146+ return None
147+
148+
127149def _cast (expr : sge .Expression , to : str , safe : bool ):
128150 if safe :
129151 return sge .TryCast (this = expr , to = to )
0 commit comments