Skip to content

Commit 460075e

Browse files
fix: Fix bug with DataFrame.agg for string values
1 parent 942e66c commit 460075e

File tree

3 files changed

+81
-18
lines changed

3 files changed

+81
-18
lines changed

bigframes/core/blocks.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def resolve_label_exact_or_error(self, label: Label) -> str:
278278
raises an error. If there is no such a column, raises an error too."""
279279
col_id = self.resolve_label_exact(label)
280280
if col_id is None:
281-
raise ValueError(f"Label {label} not found. {constants.FEEDBACK_LINK}")
281+
raise KeyError(f"Label {label} not found. {constants.FEEDBACK_LINK}")
282282
return col_id
283283

284284
@functools.cached_property
@@ -1996,7 +1996,7 @@ def _generate_resample_label(
19961996
return block.set_index([resample_label_id])
19971997

19981998
def _create_stack_column(self, col_label: typing.Tuple, stack_labels: pd.Index):
1999-
dtype = None
1999+
input_dtypes = []
20002000
input_columns: list[Optional[str]] = []
20012001
for uvalue in utils.index_as_tuples(stack_labels):
20022002
label_to_match = (*col_label, *uvalue)
@@ -2006,15 +2006,18 @@ def _create_stack_column(self, col_label: typing.Tuple, stack_labels: pd.Index):
20062006
matching_ids = self.label_to_col_id.get(label_to_match, [])
20072007
input_id = matching_ids[0] if len(matching_ids) > 0 else None
20082008
if input_id:
2009-
if dtype and dtype != self._column_type(input_id):
2010-
raise NotImplementedError(
2011-
"Cannot stack columns with non-matching dtypes."
2012-
)
2013-
else:
2014-
dtype = self._column_type(input_id)
2009+
input_dtypes.append(self._column_type(input_id))
20152010
input_columns.append(input_id)
20162011
# Input column i is the first one that
2017-
return tuple(input_columns), dtype or pd.Float64Dtype()
2012+
if len(input_dtypes) > 0:
2013+
output_dtype = bigframes.dtypes.lcd_type(*input_dtypes)
2014+
if output_dtype is None:
2015+
raise NotImplementedError(
2016+
"Cannot stack columns with non-matching dtypes."
2017+
)
2018+
else:
2019+
output_dtype = pd.Float64Dtype()
2020+
return tuple(input_columns), output_dtype
20182021

20192022
def _column_type(self, col_id: str) -> bigframes.dtypes.Dtype:
20202023
col_offset = self.value_columns.index(col_id)

bigframes/dataframe.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3004,14 +3004,42 @@ def agg(
30043004
if utils.is_dict_like(func):
30053005
# Must check dict-like first because dictionaries are list-like
30063006
# according to Pandas.
3007-
agg_cols = []
3008-
for col_label, agg_func in func.items():
3009-
agg_cols.append(self[col_label].agg(agg_func))
3010-
3011-
from bigframes.core.reshape import api as reshape
3012-
3013-
return reshape.concat(agg_cols, axis=1)
30143007

3008+
aggs = []
3009+
labels = []
3010+
funcnames = []
3011+
for col_label, agg_func in func.items():
3012+
agg_func_list = agg_func if utils.is_list_like(agg_func) else [agg_func]
3013+
col_id = self._block.resolve_label_exact_or_error(col_label)
3014+
for agg_func in agg_func_list:
3015+
agg_op = agg_ops.lookup_agg_func(typing.cast(str, agg_func))
3016+
agg_expr = (
3017+
ex.UnaryAggregation(agg_op, ex.deref(col_id))
3018+
if isinstance(agg_op, agg_ops.UnaryAggregateOp)
3019+
else ex.NullaryAggregation(agg_op)
3020+
)
3021+
aggs.append(agg_expr)
3022+
labels.append(col_label)
3023+
funcnames.append(agg_func)
3024+
3025+
# if any list in dict values, format output differently
3026+
if any(utils.is_list_like(v) for v in func.values()):
3027+
new_index, _ = self.columns.reindex(labels)
3028+
new_index = utils.combine_indices(new_index, pandas.Index(funcnames))
3029+
agg_block, _ = self._block.aggregate(
3030+
aggregations=aggs, column_labels=new_index
3031+
)
3032+
return DataFrame(agg_block).stack().droplevel(0, axis="index")
3033+
else:
3034+
new_index, _ = self.columns.reindex(labels)
3035+
agg_block, _ = self._block.aggregate(
3036+
aggregations=aggs, column_labels=new_index
3037+
)
3038+
return bigframes.series.Series(
3039+
agg_block.transpose(
3040+
single_row_mode=True, original_row_index=pandas.Index([None])
3041+
)
3042+
)
30153043
elif utils.is_list_like(func):
30163044
aggregations = [agg_ops.lookup_agg_func(f) for f in func]
30173045

@@ -3027,7 +3055,7 @@ def agg(
30273055
)
30283056
)
30293057

3030-
else:
3058+
else: # function name string
30313059
return bigframes.series.Series(
30323060
self._block.aggregate_all_and_stack(
30333061
agg_ops.lookup_agg_func(typing.cast(str, func))

tests/system/small/test_dataframe.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5529,7 +5529,7 @@ def test_astype_invalid_type_fail(scalars_dfs):
55295529
bf_df.astype(123)
55305530

55315531

5532-
def test_agg_with_dict(scalars_dfs):
5532+
def test_agg_with_dict_lists(scalars_dfs):
55335533
bf_df, pd_df = scalars_dfs
55345534
agg_funcs = {
55355535
"int64_too": ["min", "max"],
@@ -5544,6 +5544,38 @@ def test_agg_with_dict(scalars_dfs):
55445544
)
55455545

55465546

5547+
def test_agg_with_dict_list_and_str(scalars_dfs):
5548+
bf_df, pd_df = scalars_dfs
5549+
agg_funcs = {
5550+
"int64_too": ["min", "max"],
5551+
"int64_col": "sum",
5552+
}
5553+
5554+
bf_result = bf_df.agg(agg_funcs).to_pandas()
5555+
pd_result = pd_df.agg(agg_funcs)
5556+
5557+
pd.testing.assert_frame_equal(
5558+
bf_result, pd_result, check_dtype=False, check_index_type=False
5559+
)
5560+
5561+
5562+
def test_agg_with_dict_strs(scalars_dfs):
5563+
bf_df, pd_df = scalars_dfs
5564+
agg_funcs = {
5565+
"int64_too": "min",
5566+
"int64_col": "sum",
5567+
"float64_col": "max",
5568+
}
5569+
5570+
bf_result = bf_df.agg(agg_funcs).to_pandas()
5571+
pd_result = pd_df.agg(agg_funcs)
5572+
pd_result.index = pd_result.index.astype("string[pyarrow]")
5573+
5574+
pd.testing.assert_series_equal(
5575+
bf_result, pd_result, check_dtype=False, check_index_type=False
5576+
)
5577+
5578+
55475579
def test_agg_with_dict_containing_non_existing_col_raise_key_error(scalars_dfs):
55485580
bf_df, _ = scalars_dfs
55495581
agg_funcs = {

0 commit comments

Comments
 (0)