19 changes: 11 additions & 8 deletions bigframes/core/blocks.py
@@ -1999,7 +1999,7 @@ def _generate_resample_label(
         return block.set_index([resample_label_id])
 
     def _create_stack_column(self, col_label: typing.Tuple, stack_labels: pd.Index):
-        dtype = None
+        input_dtypes = []
         input_columns: list[Optional[str]] = []
         for uvalue in utils.index_as_tuples(stack_labels):
             label_to_match = (*col_label, *uvalue)
@@ -2009,15 +2009,18 @@ def _create_stack_column(self, col_label: typing.Tuple, stack_labels: pd.Index):
             matching_ids = self.label_to_col_id.get(label_to_match, [])
             input_id = matching_ids[0] if len(matching_ids) > 0 else None
             if input_id:
-                if dtype and dtype != self._column_type(input_id):
-                    raise NotImplementedError(
-                        "Cannot stack columns with non-matching dtypes."
-                    )
-                else:
-                    dtype = self._column_type(input_id)
+                input_dtypes.append(self._column_type(input_id))
             input_columns.append(input_id)
             # Input column i is the first one that
-        return tuple(input_columns), dtype or pd.Float64Dtype()
+        if len(input_dtypes) > 0:
+            output_dtype = bigframes.dtypes.lcd_type(*input_dtypes)
+            if output_dtype is None:
+                raise NotImplementedError(
+                    "Cannot stack columns with non-matching dtypes."
+                )
+        else:
+            output_dtype = pd.Float64Dtype()
+        return tuple(input_columns), output_dtype
 
     def _column_type(self, col_id: str) -> bigframes.dtypes.Dtype:
         col_offset = self.value_columns.index(col_id)
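Net effect of the blocks.py change: instead of rejecting any dtype mismatch outright, the stacked output dtype is now resolved through bigframes.dtypes.lcd_type, and only an irreconcilable mix raises. A minimal sketch of the new resolution path, assuming the usual numeric coercion ladder; the concrete dtypes below are illustrative, not taken from the PR:

import pandas as pd

import bigframes.dtypes

# Dtypes collected from the columns that feed one stacked output column.
input_dtypes = [pd.Int64Dtype(), pd.Float64Dtype()]

if len(input_dtypes) > 0:
    # lcd_type picks a dtype every input can be cast to, or returns None
    # when no common dtype exists (e.g. numeric mixed with string).
    output_dtype = bigframes.dtypes.lcd_type(*input_dtypes)
    if output_dtype is None:
        raise NotImplementedError("Cannot stack columns with non-matching dtypes.")
else:
    # No matching input columns at all: fall back to nullable Float64.
    output_dtype = pd.Float64Dtype()

# Assumption: Int64 and Float64 resolve to Float64 under the numeric ladder.
print(output_dtype)

Collecting every input dtype and deferring the decision to lcd_type keeps the coercion policy in one place (bigframes.dtypes) rather than re-implementing pairwise checks in the block layer.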
46 changes: 38 additions & 8 deletions bigframes/dataframe.py
@@ -3004,14 +3004,44 @@ def agg(
         if utils.is_dict_like(func):
             # Must check dict-like first because dictionaries are list-like
             # according to Pandas.
-            agg_cols = []
-            for col_label, agg_func in func.items():
-                agg_cols.append(self[col_label].agg(agg_func))
-
-            from bigframes.core.reshape import api as reshape
-
-            return reshape.concat(agg_cols, axis=1)
 
+            aggs = []
+            labels = []
+            funcnames = []
+            for col_label, agg_func in func.items():
+                agg_func_list = agg_func if utils.is_list_like(agg_func) else [agg_func]
+                col_id = self._block.resolve_label_exact(col_label)
+                if col_id is None:
+                    raise KeyError(f"Column {col_label} does not exist")
+                for agg_func in agg_func_list:
+                    agg_op = agg_ops.lookup_agg_func(typing.cast(str, agg_func))
+                    agg_expr = (
+                        ex.UnaryAggregation(agg_op, ex.deref(col_id))
+                        if isinstance(agg_op, agg_ops.UnaryAggregateOp)
+                        else ex.NullaryAggregation(agg_op)
+                    )
+                    aggs.append(agg_expr)
+                    labels.append(col_label)
+                    funcnames.append(agg_func)
+
+            # if any list in dict values, format output differently
+            if any(utils.is_list_like(v) for v in func.values()):
+                new_index, _ = self.columns.reindex(labels)
+                new_index = utils.combine_indices(new_index, pandas.Index(funcnames))
+                agg_block, _ = self._block.aggregate(
+                    aggregations=aggs, column_labels=new_index
+                )
+                return DataFrame(agg_block).stack().droplevel(0, axis="index")
+            else:
+                new_index, _ = self.columns.reindex(labels)
+                agg_block, _ = self._block.aggregate(
+                    aggregations=aggs, column_labels=new_index
+                )
+                return bigframes.series.Series(
+                    agg_block.transpose(
+                        single_row_mode=True, original_row_index=pandas.Index([None])
+                    )
+                )
         elif utils.is_list_like(func):
             aggregations = [agg_ops.lookup_agg_func(f) for f in func]
 
@@ -3027,7 +3057,7 @@ def agg(
                 )
             )
 
-        else:
+        else:  # function name string
             return bigframes.series.Series(
                 self._block.aggregate_all_and_stack(
                     agg_ops.lookup_agg_func(typing.cast(str, func))
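Usage-wise, the rewritten dict branch means DataFrame.agg with a dict now runs as a single block aggregation instead of concatenating per-column Series, and the result shape depends on whether any dict value is list-like. A hedged sketch of the two output shapes, assuming a configured BigQuery session; the toy data is illustrative, while the column names mirror the new tests:

import bigframes.pandas as bpd

df = bpd.DataFrame({"int64_too": [1, 2, 3], "int64_col": [10, 20, 30]})

# Any list-like value in the dict -> DataFrame, with the requested
# function names becoming the row index after the internal stack().
frame_result = df.agg({"int64_too": ["min", "max"], "int64_col": "sum"})

# All plain string values -> Series indexed by the column labels.
series_result = df.agg({"int64_too": "min", "int64_col": "sum"})

print(frame_result.to_pandas())
print(series_result.to_pandas())

A dict key that is not an existing column label now fails fast with KeyError, which is what test_agg_with_dict_containing_non_existing_col_raise_key_error continues to exercise.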
34 changes: 33 additions & 1 deletion tests/system/small/test_dataframe.py
@@ -5538,7 +5538,7 @@ def test_astype_invalid_type_fail(scalars_dfs):
         bf_df.astype(123)
 
 
-def test_agg_with_dict(scalars_dfs):
+def test_agg_with_dict_lists(scalars_dfs):
     bf_df, pd_df = scalars_dfs
     agg_funcs = {
         "int64_too": ["min", "max"],
@@ -5553,6 +5553,38 @@ def test_agg_with_dict(scalars_dfs):
     )
 
 
+def test_agg_with_dict_list_and_str(scalars_dfs):
+    bf_df, pd_df = scalars_dfs
+    agg_funcs = {
+        "int64_too": ["min", "max"],
+        "int64_col": "sum",
+    }
+
+    bf_result = bf_df.agg(agg_funcs).to_pandas()
+    pd_result = pd_df.agg(agg_funcs)
+
+    pd.testing.assert_frame_equal(
+        bf_result, pd_result, check_dtype=False, check_index_type=False
+    )
+
+
+def test_agg_with_dict_strs(scalars_dfs):
+    bf_df, pd_df = scalars_dfs
+    agg_funcs = {
+        "int64_too": "min",
+        "int64_col": "sum",
+        "float64_col": "max",
+    }
+
+    bf_result = bf_df.agg(agg_funcs).to_pandas()
+    pd_result = pd_df.agg(agg_funcs)
+    pd_result.index = pd_result.index.astype("string[pyarrow]")
+
+    pd.testing.assert_series_equal(
+        bf_result, pd_result, check_dtype=False, check_index_type=False
+    )
+
+
 def test_agg_with_dict_containing_non_existing_col_raise_key_error(scalars_dfs):
     bf_df, _ = scalars_dfs
     agg_funcs = {
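For orientation, a sketch of the pandas behavior these new tests compare against, run on a toy frame rather than the scalars_dfs fixture (the fixture columns are only borrowed for the names):

import pandas as pd

pd_df = pd.DataFrame(
    {"int64_too": [1, 2], "int64_col": [3, 4], "float64_col": [1.5, 2.5]}
)

# Mixed list/str values: pandas returns a DataFrame whose index holds the
# function names, with NaN where a function was not requested for a column.
mixed = pd_df.agg({"int64_too": ["min", "max"], "int64_col": "sum"})

# All-string values: pandas returns a Series indexed by column label; the
# new test casts that index to "string[pyarrow]", presumably so it lines up
# with the BigQuery DataFrames index dtype before comparing.
flat = pd_df.agg({"int64_too": "min", "int64_col": "sum", "float64_col": "max"})

print(mixed)
print(flat)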