13 changes: 3 additions & 10 deletions bigframes/testing/utils.py
@@ -448,12 +448,12 @@ def get_function_name(func, package_requirements=None, is_row_processor=False):
     return f"bigframes_{function_hash}"


-def _apply_unary_ops(
+def _apply_ops_to_sql(
     obj: bpd.DataFrame,
     ops_list: Sequence[ex.Expression],
     new_names: Sequence[str],
 ) -> str:
-    """Applies a list of unary ops to the given DataFrame and returns the SQL
+    """Applies a list of ops to the given DataFrame and returns the SQL
     representing the resulting DataFrame."""
     array_value = obj._block.expr
     result, old_names = array_value.compute_values(ops_list)
@@ -485,13 +485,6 @@ def _apply_nary_op(
 ) -> str:
     """Applies a nary op to the given DataFrame and return the SQL representing
     the resulting DataFrame."""
-    array_value = obj._block.expr
     op_expr = op.as_expr(*args)
-    result, col_ids = array_value.compute_values([op_expr])
-
-    # Rename columns for deterministic golden SQL results.
-    assert len(col_ids) == 1
-    result = result.rename_columns({col_ids[0]: args[0]}).select_columns([args[0]])
-
-    sql = result.session._executor.to_sql(result, enable_cache=False)
+    sql = _apply_ops_to_sql(obj, [op_expr], [args[0]])  # type: ignore
     return sql
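Read together, the two hunks above turn `_apply_nary_op` into a thin wrapper over the renamed helper. For orientation, here is how the two functions plausibly read after this change. The first hunk is truncated right after `compute_values`, so everything past that line in `_apply_ops_to_sql` is a hedged reconstruction inferred from the logic deleted out of `_apply_nary_op`, not verbatim source; `_apply_nary_op`'s parameter list is likewise elided in the hunk and assumed from how the body uses it. Imports (`bpd`, `ex`, `Sequence`) are as in the module.

def _apply_ops_to_sql(
    obj: bpd.DataFrame,
    ops_list: Sequence[ex.Expression],
    new_names: Sequence[str],
) -> str:
    """Applies a list of ops to the given DataFrame and returns the SQL
    representing the resulting DataFrame."""
    array_value = obj._block.expr
    result, old_names = array_value.compute_values(ops_list)
    # --- reconstruction below this line (the hunk is truncated here) ---
    # Rename columns for deterministic golden SQL results, mirroring the
    # rename/select logic removed from _apply_nary_op.
    result = result.rename_columns(dict(zip(old_names, new_names)))
    result = result.select_columns(list(new_names))
    return result.session._executor.to_sql(result, enable_cache=False)


def _apply_nary_op(obj, op, *args) -> str:  # signature assumed; elided in the hunk
    """Applies a nary op to the given DataFrame and return the SQL representing
    the resulting DataFrame."""
    op_expr = op.as_expr(*args)
    sql = _apply_ops_to_sql(obj, [op_expr], [args[0]])  # type: ignore
    return sql

Naming the output after `args[0]` preserves the behavior of the inlined rename that the old body performed with `col_ids[0]`.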
24 changes: 12 additions & 12 deletions tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
@@ -39,7 +39,7 @@ def test_ai_generate(scalar_types_df: dataframe.DataFrame, snapshot):
         output_schema=None,
     )

-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
     )

@@ -58,7 +58,7 @@ def test_ai_generate_with_output_schema(scalar_types_df: dataframe.DataFrame, sn
         output_schema="x INT64, y FLOAT64",
     )

-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
     )

@@ -82,7 +82,7 @@ def test_ai_generate_with_model_param(scalar_types_df: dataframe.DataFrame, snap
         output_schema=None,
     )

-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
     )

@@ -100,7 +100,7 @@ def test_ai_generate_bool(scalar_types_df: dataframe.DataFrame, snapshot):
         model_params=None,
     )

-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
     )

@@ -125,7 +125,7 @@ def test_ai_generate_bool_with_model_param(
         model_params=json.dumps(dict()),
     )

-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
     )

@@ -144,7 +144,7 @@ def test_ai_generate_int(scalar_types_df: dataframe.DataFrame, snapshot):
         model_params=None,
     )

-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
     )

@@ -170,7 +170,7 @@ def test_ai_generate_int_with_model_param(
         model_params=json.dumps(dict()),
     )

-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
     )

@@ -189,7 +189,7 @@ def test_ai_generate_double(scalar_types_df: dataframe.DataFrame, snapshot):
         model_params=None,
     )

-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
     )

@@ -215,7 +215,7 @@ def test_ai_generate_double_with_model_param(
         model_params=json.dumps(dict()),
     )

-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
     )

@@ -230,7 +230,7 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot):
         connection_id=CONNECTION_ID,
     )

-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
     )

@@ -246,7 +246,7 @@ def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot):
         connection_id=CONNECTION_ID,
     )

-    sql = utils._apply_unary_ops(scalar_types_df, [op.as_expr(col_name)], ["result"])
+    sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"])

     snapshot.assert_match(sql, "out.sql")

@@ -259,7 +259,7 @@ def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot):
         connection_id=CONNECTION_ID,
     )

-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
     )
[file header not captured in this view — the hunks below are from the array-op expression tests]
@@ -25,7 +25,7 @@
 def test_array_to_string(repeated_types_df: bpd.DataFrame, snapshot):
     col_name = "string_list_col"
     bf_df = repeated_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [ops.ArrayToStringOp(delimiter=".").as_expr(col_name)], [col_name]
     )

@@ -35,7 +35,7 @@ def test_array_to_string(repeated_types_df: bpd.DataFrame, snapshot):
 def test_array_index(repeated_types_df: bpd.DataFrame, snapshot):
     col_name = "string_list_col"
     bf_df = repeated_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [convert_index(1).as_expr(col_name)], [col_name]
     )

@@ -45,7 +45,7 @@ def test_array_index(repeated_types_df: bpd.DataFrame, snapshot):
 def test_array_slice_with_only_start(repeated_types_df: bpd.DataFrame, snapshot):
     col_name = "string_list_col"
     bf_df = repeated_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [convert_slice(slice(1, None)).as_expr(col_name)], [col_name]
     )

@@ -55,7 +55,7 @@ def test_array_slice_with_only_start(repeated_types_df: bpd.DataFrame, snapshot)
 def test_array_slice_with_start_and_stop(repeated_types_df: bpd.DataFrame, snapshot):
     col_name = "string_list_col"
     bf_df = repeated_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [convert_slice(slice(1, 5)).as_expr(col_name)], [col_name]
     )
[file header not captured in this view — the hunk below is from the IsInOp expression tests]
@@ -40,7 +40,7 @@ def test_is_in(scalar_types_df: bpd.DataFrame, snapshot):
         "float_in_ints": ops.IsInOp(values=(1, 2, 3, None)).as_expr(float_col),
     }

-    sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
+    sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys()))
     snapshot.assert_match(sql, "out.sql")
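A note on the `new_names` argument visible at these call sites: per the comment deleted from `_apply_nary_op` above, outputs are renamed so the golden SQL is deterministic rather than keyed to compiler-generated column ids. Multi-op tests like `test_is_in` keep each expression paired with its label through a dict; a minimal sketch of the pattern, reusing names from this hunk (`int_col` and `float_col` are column-name variables defined earlier in the test and not shown here):

# Keys become the stable output column names in out.sql; values are the
# expressions to compile. Dict insertion order keeps the two lists aligned.
ops_map = {
    "int_in_ints": ops.IsInOp(values=(1, 2, 3, None)).as_expr(int_col),
    "float_in_ints": ops.IsInOp(values=(1, 2, 3, None)).as_expr(float_col),
}
sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys()))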
50 changes: 27 additions & 23 deletions tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py
@@ -25,15 +25,15 @@
 def test_date(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.date_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(bf_df, [ops.date_op.as_expr(col_name)], [col_name])

     snapshot.assert_match(sql, "out.sql")


 def test_day(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.day_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(bf_df, [ops.day_op.as_expr(col_name)], [col_name])

     snapshot.assert_match(sql, "out.sql")

@@ -43,14 +43,14 @@ def test_dayofweek(scalar_types_df: bpd.DataFrame, snapshot):
     bf_df = scalar_types_df[col_names]
     ops_map = {col_name: ops.dayofweek_op.as_expr(col_name) for col_name in col_names}

-    sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
+    sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys()))
     snapshot.assert_match(sql, "out.sql")


 def test_dayofyear(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [ops.dayofyear_op.as_expr(col_name)], [col_name]
     )

@@ -75,7 +75,7 @@ def test_floor_dt(scalar_types_df: bpd.DataFrame, snapshot):
         "datetime_col_us": ops.FloorDtOp("us").as_expr("datetime_col"),
     }

-    sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
+    sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys()))
     snapshot.assert_match(sql, "out.sql")


@@ -85,7 +85,7 @@ def test_floor_dt_op_invalid_freq(scalar_types_df: bpd.DataFrame):
     with pytest.raises(
         NotImplementedError, match="Unsupported freq paramater: invalid"
     ):
-        utils._apply_unary_ops(
+        utils._apply_ops_to_sql(
             bf_df,
             [ops.FloorDtOp(freq="invalid").as_expr(col_name)],  # type:ignore
             [col_name],
@@ -95,31 +95,31 @@ def test_hour(scalar_types_df: bpd.DataFrame, snapshot):
 def test_hour(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.hour_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(bf_df, [ops.hour_op.as_expr(col_name)], [col_name])

     snapshot.assert_match(sql, "out.sql")


 def test_minute(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.minute_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(bf_df, [ops.minute_op.as_expr(col_name)], [col_name])

     snapshot.assert_match(sql, "out.sql")


 def test_month(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.month_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(bf_df, [ops.month_op.as_expr(col_name)], [col_name])

     snapshot.assert_match(sql, "out.sql")


 def test_normalize(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [ops.normalize_op.as_expr(col_name)], [col_name]
     )

@@ -129,23 +129,23 @@ def test_normalize(scalar_types_df: bpd.DataFrame, snapshot):
 def test_quarter(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.quarter_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(bf_df, [ops.quarter_op.as_expr(col_name)], [col_name])

     snapshot.assert_match(sql, "out.sql")


 def test_second(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.second_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(bf_df, [ops.second_op.as_expr(col_name)], [col_name])

     snapshot.assert_match(sql, "out.sql")


 def test_strftime(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [ops.StrftimeOp("%Y-%m-%d").as_expr(col_name)], [col_name]
     )

@@ -155,15 +155,15 @@ def test_strftime(scalar_types_df: bpd.DataFrame, snapshot):
 def test_time(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.time_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(bf_df, [ops.time_op.as_expr(col_name)], [col_name])

     snapshot.assert_match(sql, "out.sql")


 def test_to_datetime(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "int64_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [ops.ToDatetimeOp().as_expr(col_name)], [col_name]
     )

@@ -173,7 +173,7 @@ def test_to_datetime(scalar_types_df: bpd.DataFrame, snapshot):
 def test_to_timestamp(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "int64_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [ops.ToTimestampOp().as_expr(col_name)], [col_name]
     )

@@ -183,7 +183,7 @@ def test_to_timestamp(scalar_types_df: bpd.DataFrame, snapshot):
 def test_unix_micros(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [ops.UnixMicros().as_expr(col_name)], [col_name]
     )

@@ -193,7 +193,7 @@ def test_unix_micros(scalar_types_df: bpd.DataFrame, snapshot):
 def test_unix_millis(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [ops.UnixMillis().as_expr(col_name)], [col_name]
     )

@@ -203,7 +203,7 @@ def test_unix_millis(scalar_types_df: bpd.DataFrame, snapshot):
 def test_unix_seconds(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(
+    sql = utils._apply_ops_to_sql(
         bf_df, [ops.UnixSeconds().as_expr(col_name)], [col_name]
     )

@@ -213,31 +213,35 @@ def test_unix_seconds(scalar_types_df: bpd.DataFrame, snapshot):
 def test_year(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.year_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(bf_df, [ops.year_op.as_expr(col_name)], [col_name])

     snapshot.assert_match(sql, "out.sql")


 def test_iso_day(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.iso_day_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(bf_df, [ops.iso_day_op.as_expr(col_name)], [col_name])

     snapshot.assert_match(sql, "out.sql")


 def test_iso_week(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.iso_week_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(
+        bf_df, [ops.iso_week_op.as_expr(col_name)], [col_name]
+    )

     snapshot.assert_match(sql, "out.sql")


 def test_iso_year(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "timestamp_col"
     bf_df = scalar_types_df[[col_name]]
-    sql = utils._apply_unary_ops(bf_df, [ops.iso_year_op.as_expr(col_name)], [col_name])
+    sql = utils._apply_ops_to_sql(
+        bf_df, [ops.iso_year_op.as_expr(col_name)], [col_name]
+    )

     snapshot.assert_match(sql, "out.sql")