Skip to content

Commit 89ef3d8

Browse files
committed
refactor: fix the __getitem__ of string methods
1 parent 0b14b17 commit 89ef3d8

File tree

2 files changed

+116
-85
lines changed

2 files changed

+116
-85
lines changed

bigframes/core/compile/sqlglot/expressions/array_ops.py

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
import sqlglot.expressions as sge
2121

2222
from bigframes import operations as ops
23+
from bigframes.core.compile.sqlglot.expressions.string_ops import (
24+
string_index,
25+
string_slice,
26+
)
2327
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2428
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2529
import bigframes.dtypes as dtypes
@@ -31,7 +35,7 @@
3135
@register_unary_op(ops.ArrayIndexOp, pass_op=True)
3236
def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
3337
if expr.dtype == dtypes.STRING_DTYPE:
34-
return _string_index(expr, op)
38+
return string_index(expr, op.index)
3539

3640
return sge.Bracket(
3741
this=expr.expr,
@@ -71,29 +75,10 @@ def _(expr: TypedExpr, op: ops.ArrayReduceOp) -> sge.Expression:
7175

7276
@register_unary_op(ops.ArraySliceOp, pass_op=True)
7377
def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
74-
slice_idx = sg.to_identifier("slice_idx")
75-
76-
conditions: typing.List[sge.Predicate] = [slice_idx >= op.start]
77-
78-
if op.stop is not None:
79-
conditions.append(slice_idx < op.stop)
80-
81-
# local name for each element in the array
82-
el = sg.to_identifier("el")
83-
84-
selected_elements = (
85-
sge.select(el)
86-
.from_(
87-
sge.Unnest(
88-
expressions=[expr.expr],
89-
alias=sge.TableAlias(columns=[el]),
90-
offset=slice_idx,
91-
)
92-
)
93-
.where(*conditions)
94-
)
95-
96-
return sge.array(selected_elements)
78+
if expr.dtype == dtypes.STRING_DTYPE:
79+
return string_slice(expr, op.start, op.stop)
80+
else:
81+
return _array_slice(expr, op)
9782

9883

9984
@register_unary_op(ops.ArrayToStringOp, pass_op=True)
@@ -120,14 +105,51 @@ def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
120105
return typed_expr.expr
121106

122107

123-
def _string_index(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
124-
sub_str = sge.Substring(
125-
this=expr.expr,
126-
start=sge.convert(op.index + 1),
127-
length=sge.convert(1),
108+
def _string_slice(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
109+
# local name for each element in the array
110+
el = sg.to_identifier("el")
111+
# local name for the index in the array
112+
slice_idx = sg.to_identifier("slice_idx")
113+
114+
conditions: typing.List[sge.Predicate] = [slice_idx >= op.start]
115+
if op.stop is not None:
116+
conditions.append(slice_idx < op.stop)
117+
118+
selected_elements = (
119+
sge.select(el)
120+
.from_(
121+
sge.Unnest(
122+
expressions=[expr.expr],
123+
alias=sge.TableAlias(columns=[el]),
124+
offset=slice_idx,
125+
)
126+
)
127+
.where(*conditions)
128128
)
129-
return sge.If(
130-
this=sge.NEQ(this=sub_str, expression=sge.convert("")),
131-
true=sub_str,
132-
false=sge.Null(),
129+
130+
return sge.array(selected_elements)
131+
132+
133+
def _array_slice(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
134+
# local name for each element in the array
135+
el = sg.to_identifier("el")
136+
# local name for the index in the array
137+
slice_idx = sg.to_identifier("slice_idx")
138+
139+
conditions: typing.List[sge.Predicate] = [slice_idx >= op.start]
140+
if op.stop is not None:
141+
conditions.append(slice_idx < op.stop)
142+
143+
selected_elements = (
144+
sge.select(el)
145+
.from_(
146+
sge.Unnest(
147+
expressions=[expr.expr],
148+
alias=sge.TableAlias(columns=[el]),
149+
offset=slice_idx,
150+
)
151+
)
152+
.where(*conditions)
133153
)
154+
155+
return sge.array(selected_elements)

bigframes/core/compile/sqlglot/expressions/string_ops.py

Lines changed: 61 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -253,32 +253,81 @@ def _(expr: TypedExpr, op: ops.StringSplitOp) -> sge.Expression:
253253

254254
@register_unary_op(ops.StrGetOp, pass_op=True)
255255
def _(expr: TypedExpr, op: ops.StrGetOp) -> sge.Expression:
256+
return string_index(expr, op.i)
257+
258+
259+
@register_unary_op(ops.StrSliceOp, pass_op=True)
260+
def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
261+
return string_slice(expr, op.start, op.end)
262+
263+
264+
@register_unary_op(ops.upper_op)
265+
def _(expr: TypedExpr) -> sge.Expression:
266+
return sge.Upper(this=expr.expr)
267+
268+
269+
@register_binary_op(ops.strconcat_op)
270+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
271+
return sge.Concat(expressions=[left.expr, right.expr])
272+
273+
274+
@register_unary_op(ops.ZfillOp, pass_op=True)
275+
def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression:
276+
length_expr = sge.Greatest(
277+
expressions=[sge.Length(this=expr.expr), sge.convert(op.width)]
278+
)
279+
return sge.Case(
280+
ifs=[
281+
sge.If(
282+
this=sge.func(
283+
"STARTS_WITH",
284+
expr.expr,
285+
sge.convert("-"),
286+
),
287+
true=sge.Concat(
288+
expressions=[
289+
sge.convert("-"),
290+
sge.func(
291+
"LPAD",
292+
sge.Substring(this=expr.expr, start=sge.convert(2)),
293+
length_expr - 1,
294+
sge.convert("0"),
295+
),
296+
]
297+
),
298+
)
299+
],
300+
default=sge.func("LPAD", expr.expr, length_expr, sge.convert("0")),
301+
)
302+
303+
304+
def string_index(expr: TypedExpr, index: int) -> sge.Expression:
256305
sub_str = sge.Substring(
257306
this=expr.expr,
258-
start=sge.convert(op.i + 1),
307+
start=sge.convert(index + 1),
259308
length=sge.convert(1),
260309
)
261-
262310
return sge.If(
263311
this=sge.NEQ(this=sub_str, expression=sge.convert("")),
264312
true=sub_str,
265313
false=sge.Null(),
266314
)
267315

268316

269-
@register_unary_op(ops.StrSliceOp, pass_op=True)
270-
def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
317+
def string_slice(
318+
expr: TypedExpr, op_start: typing.Optional[int], op_end: typing.Optional[int]
319+
) -> sge.Expression:
271320
column_length = sge.Length(this=expr.expr)
272-
if op.start is None:
321+
if op_start is None:
273322
start = 0
274323
else:
275-
start = op.start
324+
start = op_start
276325

277326
start_expr = sge.convert(start) if start < 0 else sge.convert(start + 1)
278327
length_expr: typing.Optional[sge.Expression]
279-
if op.end is None:
328+
if op_end is None:
280329
length_expr = None
281-
elif op.end < 0:
330+
elif op_end < 0:
282331
if start < 0:
283332
start_expr = sge.Greatest(
284333
expressions=[
@@ -289,7 +338,7 @@ def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
289338
length_expr = sge.Greatest(
290339
expressions=[
291340
sge.convert(0),
292-
column_length + sge.convert(op.end),
341+
column_length + sge.convert(op_end),
293342
]
294343
) - sge.Greatest(
295344
expressions=[
@@ -301,7 +350,7 @@ def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
301350
length_expr = sge.Greatest(
302351
expressions=[
303352
sge.convert(0),
304-
column_length + sge.convert(op.end - start),
353+
column_length + sge.convert(op_end - start),
305354
]
306355
)
307356
else: # op.end >= 0
@@ -312,57 +361,17 @@ def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
312361
column_length + sge.convert(start + 1),
313362
]
314363
)
315-
length_expr = sge.convert(op.end) - sge.Greatest(
364+
length_expr = sge.convert(op_end) - sge.Greatest(
316365
expressions=[
317366
sge.convert(0),
318367
column_length + sge.convert(start),
319368
]
320369
)
321370
else:
322-
length_expr = sge.convert(op.end - start)
371+
length_expr = sge.convert(op_end - start)
323372

324373
return sge.Substring(
325374
this=expr.expr,
326375
start=start_expr,
327376
length=length_expr,
328377
)
329-
330-
331-
@register_unary_op(ops.upper_op)
332-
def _(expr: TypedExpr) -> sge.Expression:
333-
return sge.Upper(this=expr.expr)
334-
335-
336-
@register_binary_op(ops.strconcat_op)
337-
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
338-
return sge.Concat(expressions=[left.expr, right.expr])
339-
340-
341-
@register_unary_op(ops.ZfillOp, pass_op=True)
342-
def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression:
343-
length_expr = sge.Greatest(
344-
expressions=[sge.Length(this=expr.expr), sge.convert(op.width)]
345-
)
346-
return sge.Case(
347-
ifs=[
348-
sge.If(
349-
this=sge.func(
350-
"STARTS_WITH",
351-
expr.expr,
352-
sge.convert("-"),
353-
),
354-
true=sge.Concat(
355-
expressions=[
356-
sge.convert("-"),
357-
sge.func(
358-
"LPAD",
359-
sge.Substring(this=expr.expr, start=sge.convert(2)),
360-
length_expr - 1,
361-
sge.convert("0"),
362-
),
363-
]
364-
),
365-
)
366-
],
367-
default=sge.func("LPAD", expr.expr, length_expr, sge.convert("0")),
368-
)

0 commit comments

Comments
 (0)