63 changes: 63 additions & 0 deletions bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
@@ -386,6 +386,17 @@ def _(
    return apply_window_if_present(sge.func("MIN", column.expr), window)


@UNARY_OP_REGISTRATION.register(agg_ops.NuniqueOp)
def _(
    op: agg_ops.NuniqueOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    return apply_window_if_present(
        sge.func("COUNT", sge.Distinct(expressions=[column.expr])), window
    )


@UNARY_OP_REGISTRATION.register(agg_ops.PopVarOp)
def _(
    op: agg_ops.PopVarOp,
@@ -400,6 +411,58 @@ def _(
    return apply_window_if_present(expr, window)


@UNARY_OP_REGISTRATION.register(agg_ops.ProductOp)
def _(
    op: agg_ops.ProductOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    # Need to short-circuit, as taking the log of zero is illegal SQL.
    is_zero = sge.EQ(this=column.expr, expression=sge.convert(0))

    # There is no product SQL aggregate function, so implement it as a sum of logs
    # and then apply the exponential. Note: the log and power bases must match;
    # this implementation uses the natural log.
    logs = (
        sge.Case()
        .when(is_zero, sge.convert(0))
        .else_(sge.func("LN", sge.func("ABS", column.expr)))
    )
    logs_sum = apply_window_if_present(sge.func("SUM", logs), window)
    magnitude = sge.func("EXP", logs_sum)

    # The sign can't be recovered from the logs, so derive it from the parity of the
    # count of negative inputs.
    is_negative = (
        sge.Case()
        .when(
            sge.LT(this=sge.func("SIGN", column.expr), expression=sge.convert(0)),
            sge.convert(1),
        )
        .else_(sge.convert(0))
    )
    negative_count = apply_window_if_present(sge.func("SUM", is_negative), window)
    negative_count_parity = sge.Mod(
        this=negative_count, expression=sge.convert(2)
    )  # 1 if the result should be negative, otherwise 0.

    any_zeroes = apply_window_if_present(sge.func("LOGICAL_OR", is_zero), window)

    float_result = (
        sge.Case()
        .when(any_zeroes, sge.convert(0))
        .else_(
            sge.Mul(
                this=magnitude,
                expression=sge.If(
                    this=sge.EQ(this=negative_count_parity, expression=sge.convert(1)),
                    true=sge.convert(-1),
                    false=sge.convert(1),
                ),
            )
        )
    )
    return float_result


@UNARY_OP_REGISTRATION.register(agg_ops.QcutOp)
def _(
    op: agg_ops.QcutOp,
@@ -0,0 +1,12 @@
WITH `bfcte_0` AS (
  SELECT
    `int64_col`
  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
  SELECT
    COUNT(DISTINCT `int64_col`) AS `bfcol_1`
  FROM `bfcte_0`
)
SELECT
  `bfcol_1` AS `int64_col`
FROM `bfcte_1`
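
The COUNT(DISTINCT ...) in this snapshot is the rendered form of the NuniqueOp registration above. A minimal standalone sqlglot sketch of the same expression construction (the column name comes from the snapshot; exact identifier quoting depends on generator settings):

    import sqlglot.expressions as sge

    # Build COUNT(DISTINCT int64_col) the same way the compiler does.
    column = sge.column("int64_col")
    count_distinct = sge.func("COUNT", sge.Distinct(expressions=[column]))

    # Render for BigQuery; output is roughly: COUNT(DISTINCT int64_col)
    print(count_distinct.sql(dialect="bigquery"))
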
@@ -0,0 +1,16 @@
WITH `bfcte_0` AS (
  SELECT
    `int64_col`
  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
  SELECT
    CASE
      WHEN LOGICAL_OR(`int64_col` = 0)
      THEN 0
      ELSE EXP(SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END)) * IF(MOD(SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END), 2) = 1, -1, 1)
    END AS `bfcol_1`
  FROM `bfcte_0`
)
SELECT
  `bfcol_1` AS `int64_col`
FROM `bfcte_1`
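
The CASE expression above encodes the sum-of-logs strategy from ProductOp: magnitude via EXP(SUM(LN(ABS(x)))), sign via the parity of the negative-input count, and a short-circuit to 0 when any input is zero. A minimal pure-Python sketch of the same math (plain floats, illustrative values only):

    import math

    def product_via_logs(values):
        # Zero short-circuit: LN(0) is illegal, and the product is 0 anyway.
        if any(v == 0 for v in values):
            return 0.0
        # Magnitude from the sum of natural logs of absolute values.
        magnitude = math.exp(sum(math.log(abs(v)) for v in values))
        # Sign from the parity of the count of negative inputs.
        negative_count = sum(1 for v in values if v < 0)
        return -magnitude if negative_count % 2 == 1 else magnitude

    assert abs(product_via_logs([2, -3, 4]) - (-24.0)) < 1e-9
    assert product_via_logs([2, 0, -5]) == 0.0

Because the SQL routes through floating-point logarithms, the result is a FLOAT64 approximation of the exact integer product.
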
@@ -0,0 +1,27 @@
WITH `bfcte_0` AS (
  SELECT
    `int64_col`,
    `string_col`
  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
  SELECT
    *,
    CASE
      WHEN LOGICAL_OR(`int64_col` = 0) OVER (PARTITION BY `string_col`)
      THEN 0
      ELSE EXP(
        SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END) OVER (PARTITION BY `string_col`)
      ) * IF(
        MOD(
          SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END) OVER (PARTITION BY `string_col`),
          2
        ) = 1,
        -1,
        1
      )
    END AS `bfcol_2`
  FROM `bfcte_0`
)
SELECT
  `bfcol_2` AS `agg_int64`
FROM `bfcte_1`
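
The OVER (PARTITION BY `string_col`) variant computes one product per group while keeping a row for every input, which is analogous to a grouped transform in pandas. A hedged in-memory sketch of those semantics (illustrative data only, not how BigFrames executes it):

    import pandas as pd

    df = pd.DataFrame(
        {"string_col": ["a", "a", "b", "b"], "int64_col": [2, -3, 4, 0]}
    )
    # One product per partition, broadcast back to every row of that partition.
    df["agg_int64"] = df.groupby("string_col")["int64_col"].transform("prod")
    print(df)  # group "a" -> -6 on both rows, group "b" -> 0 on both rows
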
@@ -412,6 +412,15 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot):
    snapshot.assert_match(sql_window_partition, "window_partition_out.sql")


def test_nunique(scalar_types_df: bpd.DataFrame, snapshot):
    col_name = "int64_col"
    bf_df = scalar_types_df[[col_name]]
    agg_expr = agg_ops.NuniqueOp().as_expr(col_name)
    sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])

    snapshot.assert_match(sql, "out.sql")


def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot):
    col_names = ["int64_col", "bool_col"]
    bf_df = scalar_types_df[col_names]
@@ -434,6 +443,25 @@ def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot):
    snapshot.assert_match(sql_window, "window_out.sql")


def test_product(scalar_types_df: bpd.DataFrame, snapshot):
    col_name = "int64_col"
    bf_df = scalar_types_df[[col_name]]
    agg_expr = agg_ops.ProductOp().as_expr(col_name)
    sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])

    snapshot.assert_match(sql, "out.sql")

    bf_df_str = scalar_types_df[[col_name, "string_col"]]
    window_partition = window_spec.WindowSpec(
        grouping_keys=(expression.deref("string_col"),),
    )
    sql_window_partition = _apply_unary_window_op(
        bf_df_str, agg_expr, window_partition, "agg_int64"
    )

    snapshot.assert_match(sql_window_partition, "window_partition_out.sql")


def test_qcut(scalar_types_df: bpd.DataFrame, snapshot):
    if sys.version_info < (3, 12):
        pytest.skip(
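
At the user-facing level, these compiler registrations are meant to back the pandas-style aggregations on the BigFrames surface. A hedged sketch of calls that would exercise them, assuming the pandas-compatible API and using the table name from the snapshots above:

    import bigframes.pandas as bpd

    df = bpd.read_gbq("bigframes-dev.sqlglot_test.scalar_types")

    # NuniqueOp -> COUNT(DISTINCT ...)
    distinct_count = df["int64_col"].nunique()

    # ProductOp -> the sum-of-logs CASE expression
    product = df["int64_col"].prod()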