Skip to content

Commit 040a2c9

Browse files
committed
refactor: fix agg_ops.ProductOp for test_dataframe_groupby_analytic
1 parent a553b47 commit 040a2c9

File tree

3 files changed

+20
-20
lines changed

3 files changed

+20
-20
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -426,24 +426,28 @@ def _(
426426
column: typed_expr.TypedExpr,
427427
window: typing.Optional[window_spec.WindowSpec] = None,
428428
) -> sge.Expression:
429+
expr = column.expr
430+
if column.dtype == dtypes.BOOL_DTYPE:
431+
expr = sge.Cast(this=expr, to="INT64")
432+
429433
# Need to short-circuit as log with zeroes is illegal sql
430-
is_zero = sge.EQ(this=column.expr, expression=sge.convert(0))
434+
is_zero = sge.EQ(this=expr, expression=sge.convert(0))
431435

432436
# There is no product sql aggregate function, so must implement as a sum of logs, and then
433437
# apply power after. Note, log and power base must be equal! This impl uses natural log.
434-
logs = (
435-
sge.Case()
436-
.when(is_zero, sge.convert(0))
437-
.else_(sge.func("LN", sge.func("ABS", column.expr)))
438+
logs = sge.If(
439+
this=is_zero,
440+
true=sge.convert(0),
441+
false=sge.func("LOG", sge.convert(2), sge.func("ABS", expr)),
438442
)
439443
logs_sum = apply_window_if_present(sge.func("SUM", logs), window)
440-
magnitude = sge.func("EXP", logs_sum)
444+
magnitude = sge.func("POWER", sge.convert(2), logs_sum)
441445

442446
# Can't determine sign from logs, so have to determine parity of count of negative inputs
443447
is_negative = (
444448
sge.Case()
445449
.when(
446-
sge.LT(this=sge.func("SIGN", column.expr), expression=sge.convert(0)),
450+
sge.EQ(this=sge.func("SIGN", expr), expression=sge.convert(-1)),
447451
sge.convert(1),
448452
)
449453
.else_(sge.convert(0))
@@ -461,11 +465,7 @@ def _(
461465
.else_(
462466
sge.Mul(
463467
this=magnitude,
464-
expression=sge.If(
465-
this=sge.EQ(this=negative_count_parity, expression=sge.convert(1)),
466-
true=sge.convert(-1),
467-
false=sge.convert(1),
468-
),
468+
expression=sge.func("POWER", sge.convert(-1), negative_count_parity),
469469
)
470470
)
471471
)

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
77
CASE
88
WHEN LOGICAL_OR(`int64_col` = 0)
99
THEN 0
10-
ELSE EXP(SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END)) * IF(MOD(SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END), 2) = 1, -1, 1)
10+
ELSE POWER(2, SUM(IF(`int64_col` = 0, 0, LOG(ABS(`int64_col`), 2)))) * POWER(-1, MOD(SUM(CASE WHEN SIGN(`int64_col`) = -1 THEN 1 ELSE 0 END), 2))
1111
END AS `bfcol_1`
1212
FROM `bfcte_0`
1313
)

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/window_partition_out.sql

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ WITH `bfcte_0` AS (
99
CASE
1010
WHEN LOGICAL_OR(`int64_col` = 0) OVER (PARTITION BY `string_col`)
1111
THEN 0
12-
ELSE EXP(
13-
SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END) OVER (PARTITION BY `string_col`)
14-
) * IF(
12+
ELSE POWER(
13+
2,
14+
SUM(IF(`int64_col` = 0, 0, LOG(ABS(`int64_col`), 2))) OVER (PARTITION BY `string_col`)
15+
) * POWER(
16+
-1,
1517
MOD(
16-
SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END) OVER (PARTITION BY `string_col`),
18+
SUM(CASE WHEN SIGN(`int64_col`) = -1 THEN 1 ELSE 0 END) OVER (PARTITION BY `string_col`),
1719
2
18-
) = 1,
19-
-1,
20-
1
20+
)
2121
)
2222
END AS `bfcol_2`
2323
FROM `bfcte_0`

0 commit comments

Comments
 (0)