@@ -144,14 +144,17 @@ def head(self, n: int = 5) -> df.DataFrame:
         )
 
     def __iter__(self) -> Iterable[Tuple[blocks.Label, pd.DataFrame]]:
-        # TODO: cache original block, clustered by column ids
+        # Cache the original block, clustered by column ids.
+        # To force block.cached() to cluster by our by_col_ids,
+        # we set those columns as the index. This also simplifies
+        # filtering by our groupby key, since there are fewer
+        # cases to worry about (e.g. MultiIndex).
+        original_index_labels = self._block._index_labels
+        by_col_labels = self._block._get_labels_for_columns(self._by_col_ids).to_list()
         block = self._block.set_index(
             self._by_col_ids,
-            # TODO: do we need to keep the original index?
             drop=False,
-            index_labels=self._block._get_labels_for_columns(
-                self._by_col_ids
-            ).to_list(),
+            index_labels=by_col_labels,
         )
         block.cached(force=True)
 
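The index trick in this hunk is easier to see in plain pandas (a minimal sketch of the idea, not the bigframes API): once the group keys are the index, each group becomes a single `.loc` lookup, and `drop=False` keeps the keys available as ordinary columns.

```python
import pandas as pd

pdf = pd.DataFrame({"species": ["cat", "dog", "cat"], "n": [1, 2, 3]})

# Use the group key as the index while keeping it as a column,
# mirroring set_index(..., drop=False) above.
indexed = pdf.set_index("species", drop=False)

# All rows for one group key, via a plain label lookup.
cats = indexed.loc[["cat"]]
```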
@@ -161,14 +164,9 @@ def __iter__(self) -> Iterable[Tuple[blocks.Label, pd.DataFrame]]:
         )
         for batch in keys_block.to_pandas_batches():
             for key in batch.index:
-                # group_block = block
-                # for col in self._by_col_ids:  # TODO: can't loop through key if only one by_col_id.
-
-                #
-                # = block.project_expr(bigframes.core.expression.const(key, dtype=self._block._column_type(self._by_col_ids))
-                #     ops.eq_op( ex.const(key)
-                # )
-                yield key, batch  # TODO: filter clustered block by row
+                yield key, df.DataFrame(block).loc[key].set_index(
+                    original_index_labels, drop=False
+                )
 
     def size(self) -> typing.Union[df.DataFrame, series.Series]:
         agg_block, _ = self._block.aggregate_size(
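Taken together, the two hunks give bigframes groupby iteration the familiar pandas contract: each step yields a `(label, DataFrame)` pair whose rows keep their original index. For reference, the pandas behavior being emulated:

```python
import pandas as pd

pdf = pd.DataFrame({"a": [1, 1, 2], "b": [10, 20, 30]})
for key, group in pdf.groupby("a"):
    # key is the group label; group keeps the original row index.
    print(key, group.index.tolist())
# Output:
# 1 [0, 1]
# 2 [2]
```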