Skip to content

Commit fb87e86

Browse files
committed
final polish of the helper function
1 parent: 96b4eba; commit: fb87e86

File tree

1 file changed

+43
-26
lines changed

1 file changed

+43
-26
lines changed

bigframes/core/indexes/base.py

Lines changed: 43 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -252,32 +252,33 @@ def query_job(self) -> bigquery.QueryJob:
252252
self._query_job = query_job
253253
return self._query_job
254254

255-
def get_loc(
256-
self, key: typing.Any
257-
) -> typing.Union[int, slice, "bigframes.series.Series"]:
255+
def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
258256
"""Get integer location, slice or boolean mask for requested label.
257+
259258
Args:
260-
key: The label to search for in the index.
259+
key:
260+
The label to search for in the index.
261+
261262
Returns:
262263
An integer, slice, or boolean mask representing the location(s) of the key.
264+
263265
Raises:
264266
NotImplementedError: If the index has more than one level.
265267
KeyError: If the key is not found in the index.
266268
"""
267-
268269
if self.nlevels != 1:
269270
raise NotImplementedError("get_loc only supports single-level indexes")
270271

271272
# Get the index column from the block
272273
index_column = self._block.index_columns[0]
273274

274-
# Apply row numbering to the original data - inline single-use variables
275-
row_num_col_id = ids.ColumnId.unique()
275+
# Apply row numbering to the original data
276+
row_number_column_id = ids.ColumnId.unique()
276277
window_node = nodes.WindowOpNode(
277278
child=self._block._expr.node,
278279
expression=ex.NullaryAggregation(agg_ops.RowNumberOp()),
279280
window_spec=window_spec.unbound(),
280-
output_name=row_num_col_id,
281+
output_name=row_number_column_id,
281282
never_skip_nulls=True,
282283
)
283284

@@ -299,7 +300,9 @@ def get_loc(
299300
filtered_block = windowed_block.filter_by_id(match_col_id)
300301

301302
# Check if key exists at all by counting on the filtered block
302-
count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(row_num_col_id.name))
303+
count_agg = ex.UnaryAggregation(
304+
agg_ops.count_op, ex.deref(row_number_column_id.name)
305+
)
303306
count_result = filtered_block._expr.aggregate([(count_agg, "count")])
304307
count_scalar = self._block.session._executor.execute(
305308
count_result
@@ -310,38 +313,52 @@ def get_loc(
310313

311314
# If only one match, return integer position
312315
if count_scalar == 1:
313-
min_agg = ex.UnaryAggregation(agg_ops.min_op, ex.deref(row_num_col_id.name))
316+
min_agg = ex.UnaryAggregation(
317+
agg_ops.min_op, ex.deref(row_number_column_id.name)
318+
)
314319
position_result = filtered_block._expr.aggregate([(min_agg, "position")])
315320
position_scalar = self._block.session._executor.execute(
316321
position_result
317322
).to_py_scalar()
318323
return int(position_scalar)
319324

320-
# Multiple matches - need to determine if monotonic or not
325+
# Handle multiple matches based on index monotonicity
321326
is_monotonic = self.is_monotonic_increasing or self.is_monotonic_decreasing
322327
if is_monotonic:
323-
return self._get_monotonic_slice(filtered_block, row_num_col_id)
328+
return self._get_monotonic_slice(filtered_block, row_number_column_id)
324329
else:
325330
# Return boolean mask for non-monotonic duplicates
326331
mask_block = windowed_block.select_columns([match_col_id])
327332
return bigframes.series.Series(mask_block)
328333

329-
def _get_monotonic_slice(self, filtered_block, row_num_col_id):
330-
"""Helper method to get slice for monotonic duplicates with optimized query."""
331-
# Combine min and max aggregations into single query using to_pandas()
332-
min_agg = ex.UnaryAggregation(agg_ops.min_op, ex.deref(row_num_col_id.name))
333-
max_agg = ex.UnaryAggregation(agg_ops.max_op, ex.deref(row_num_col_id.name))
334-
combined_result = filtered_block._expr.aggregate(
335-
[(min_agg, "min_pos"), (max_agg, "max_pos")]
336-
)
334+
def _get_monotonic_slice(
335+
self, filtered_block, row_number_column_id: "ids.ColumnId"
336+
) -> slice:
337+
"""Helper method to get a slice for monotonic duplicates with an optimized query."""
338+
# Combine min and max aggregations into a single query for efficiency
339+
min_max_aggs = [
340+
(
341+
ex.UnaryAggregation(
342+
agg_ops.min_op, ex.deref(row_number_column_id.name)
343+
),
344+
"min_pos",
345+
),
346+
(
347+
ex.UnaryAggregation(
348+
agg_ops.max_op, ex.deref(row_number_column_id.name)
349+
),
350+
"max_pos",
351+
),
352+
]
353+
combined_result = filtered_block._expr.aggregate(min_max_aggs)
354+
355+
# Execute query and extract positions
337356
result_df = self._block.session._executor.execute(combined_result).to_pandas()
338-
min_pos = result_df["min_pos"].iloc[0]
339-
max_pos = result_df["max_pos"].iloc[0]
357+
min_pos = int(result_df["min_pos"].iloc[0])
358+
max_pos = int(result_df["max_pos"].iloc[0])
340359

341-
# Create slice
342-
start = int(min_pos)
343-
stop = int(max_pos) + 1 # exclusive
344-
return slice(start, stop, None)
360+
# Create slice (stop is exclusive)
361+
return slice(min_pos, max_pos + 1)
345362

346363
def __repr__(self) -> str:
347364
# Protect against errors with uninitialized Series. See:

0 commit comments

Comments
 (0)