diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index 1426459912..95dadfc987 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -1999,7 +1999,7 @@ def _generate_resample_label( return block.set_index([resample_label_id]) def _create_stack_column(self, col_label: typing.Tuple, stack_labels: pd.Index): - dtype = None + input_dtypes = [] input_columns: list[Optional[str]] = [] for uvalue in utils.index_as_tuples(stack_labels): label_to_match = (*col_label, *uvalue) @@ -2009,15 +2009,18 @@ def _create_stack_column(self, col_label: typing.Tuple, stack_labels: pd.Index): matching_ids = self.label_to_col_id.get(label_to_match, []) input_id = matching_ids[0] if len(matching_ids) > 0 else None if input_id: - if dtype and dtype != self._column_type(input_id): - raise NotImplementedError( - "Cannot stack columns with non-matching dtypes." - ) - else: - dtype = self._column_type(input_id) + input_dtypes.append(self._column_type(input_id)) input_columns.append(input_id) # Input column i is the first one that - return tuple(input_columns), dtype or pd.Float64Dtype() + if len(input_dtypes) > 0: + output_dtype = bigframes.dtypes.lcd_type(*input_dtypes) + if output_dtype is None: + raise NotImplementedError( + "Cannot stack columns with non-matching dtypes." + ) + else: + output_dtype = pd.Float64Dtype() + return tuple(input_columns), output_dtype def _column_type(self, col_id: str) -> bigframes.dtypes.Dtype: col_offset = self.value_columns.index(col_id) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 495e242f43..1ca5b8b035 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -3004,14 +3004,44 @@ def agg( if utils.is_dict_like(func): # Must check dict-like first because dictionaries are list-like # according to Pandas. - agg_cols = [] - for col_label, agg_func in func.items(): - agg_cols.append(self[col_label].agg(agg_func)) - - from bigframes.core.reshape import api as reshape - - return reshape.concat(agg_cols, axis=1) + aggs = [] + labels = [] + funcnames = [] + for col_label, agg_func in func.items(): + agg_func_list = agg_func if utils.is_list_like(agg_func) else [agg_func] + col_id = self._block.resolve_label_exact(col_label) + if col_id is None: + raise KeyError(f"Column {col_label} does not exist") + for agg_func in agg_func_list: + agg_op = agg_ops.lookup_agg_func(typing.cast(str, agg_func)) + agg_expr = ( + ex.UnaryAggregation(agg_op, ex.deref(col_id)) + if isinstance(agg_op, agg_ops.UnaryAggregateOp) + else ex.NullaryAggregation(agg_op) + ) + aggs.append(agg_expr) + labels.append(col_label) + funcnames.append(agg_func) + + # if any list in dict values, format output differently + if any(utils.is_list_like(v) for v in func.values()): + new_index, _ = self.columns.reindex(labels) + new_index = utils.combine_indices(new_index, pandas.Index(funcnames)) + agg_block, _ = self._block.aggregate( + aggregations=aggs, column_labels=new_index + ) + return DataFrame(agg_block).stack().droplevel(0, axis="index") + else: + new_index, _ = self.columns.reindex(labels) + agg_block, _ = self._block.aggregate( + aggregations=aggs, column_labels=new_index + ) + return bigframes.series.Series( + agg_block.transpose( + single_row_mode=True, original_row_index=pandas.Index([None]) + ) + ) elif utils.is_list_like(func): aggregations = [agg_ops.lookup_agg_func(f) for f in func] @@ -3027,7 +3057,7 @@ def agg( ) ) - else: + else: # function name string return bigframes.series.Series( self._block.aggregate_all_and_stack( agg_ops.lookup_agg_func(typing.cast(str, func)) diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index d5446efcd0..e8d156538f 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -5538,7 +5538,7 @@ def test_astype_invalid_type_fail(scalars_dfs): bf_df.astype(123) -def test_agg_with_dict(scalars_dfs): +def test_agg_with_dict_lists(scalars_dfs): bf_df, pd_df = scalars_dfs agg_funcs = { "int64_too": ["min", "max"], @@ -5553,6 +5553,38 @@ def test_agg_with_dict(scalars_dfs): ) +def test_agg_with_dict_list_and_str(scalars_dfs): + bf_df, pd_df = scalars_dfs + agg_funcs = { + "int64_too": ["min", "max"], + "int64_col": "sum", + } + + bf_result = bf_df.agg(agg_funcs).to_pandas() + pd_result = pd_df.agg(agg_funcs) + + pd.testing.assert_frame_equal( + bf_result, pd_result, check_dtype=False, check_index_type=False + ) + + +def test_agg_with_dict_strs(scalars_dfs): + bf_df, pd_df = scalars_dfs + agg_funcs = { + "int64_too": "min", + "int64_col": "sum", + "float64_col": "max", + } + + bf_result = bf_df.agg(agg_funcs).to_pandas() + pd_result = pd_df.agg(agg_funcs) + pd_result.index = pd_result.index.astype("string[pyarrow]") + + pd.testing.assert_series_equal( + bf_result, pd_result, check_dtype=False, check_index_type=False + ) + + def test_agg_with_dict_containing_non_existing_col_raise_key_error(scalars_dfs): bf_df, _ = scalars_dfs agg_funcs = {