@@ -252,32 +252,33 @@ def query_job(self) -> bigquery.QueryJob:
252252 self ._query_job = query_job
253253 return self ._query_job
254254
255- def get_loc (
256- self , key : typing .Any
257- ) -> typing .Union [int , slice , "bigframes.series.Series" ]:
255+ def get_loc (self , key ) -> typing .Union [int , slice , "bigframes.series.Series" ]:
258256 """Get integer location, slice or boolean mask for requested label.
257+
259258 Args:
260- key: The label to search for in the index.
259+ key:
260+ The label to search for in the index.
261+
261262 Returns:
262263 An integer, slice, or boolean mask representing the location(s) of the key.
264+
263265 Raises:
264266 NotImplementedError: If the index has more than one level.
265267 KeyError: If the key is not found in the index.
266268 """
267-
268269 if self .nlevels != 1 :
269270 raise NotImplementedError ("get_loc only supports single-level indexes" )
270271
271272 # Get the index column from the block
272273 index_column = self ._block .index_columns [0 ]
273274
274- # Apply row numbering to the original data - inline single-use variables
275- row_num_col_id = ids .ColumnId .unique ()
275+ # Apply row numbering to the original data
276+ row_number_column_id = ids .ColumnId .unique ()
276277 window_node = nodes .WindowOpNode (
277278 child = self ._block ._expr .node ,
278279 expression = ex .NullaryAggregation (agg_ops .RowNumberOp ()),
279280 window_spec = window_spec .unbound (),
280- output_name = row_num_col_id ,
281+ output_name = row_number_column_id ,
281282 never_skip_nulls = True ,
282283 )
283284
@@ -299,7 +300,9 @@ def get_loc(
299300 filtered_block = windowed_block .filter_by_id (match_col_id )
300301
301302 # Check if key exists at all by counting on the filtered block
302- count_agg = ex .UnaryAggregation (agg_ops .count_op , ex .deref (row_num_col_id .name ))
303+ count_agg = ex .UnaryAggregation (
304+ agg_ops .count_op , ex .deref (row_number_column_id .name )
305+ )
303306 count_result = filtered_block ._expr .aggregate ([(count_agg , "count" )])
304307 count_scalar = self ._block .session ._executor .execute (
305308 count_result
@@ -310,38 +313,52 @@ def get_loc(
310313
311314 # If only one match, return integer position
312315 if count_scalar == 1 :
313- min_agg = ex .UnaryAggregation (agg_ops .min_op , ex .deref (row_num_col_id .name ))
316+ min_agg = ex .UnaryAggregation (
317+ agg_ops .min_op , ex .deref (row_number_column_id .name )
318+ )
314319 position_result = filtered_block ._expr .aggregate ([(min_agg , "position" )])
315320 position_scalar = self ._block .session ._executor .execute (
316321 position_result
317322 ).to_py_scalar ()
318323 return int (position_scalar )
319324
320- # Multiple matches - need to determine if monotonic or not
325+ # Handle multiple matches based on index monotonicity
321326 is_monotonic = self .is_monotonic_increasing or self .is_monotonic_decreasing
322327 if is_monotonic :
323- return self ._get_monotonic_slice (filtered_block , row_num_col_id )
328+ return self ._get_monotonic_slice (filtered_block , row_number_column_id )
324329 else :
325330 # Return boolean mask for non-monotonic duplicates
326331 mask_block = windowed_block .select_columns ([match_col_id ])
327332 return bigframes .series .Series (mask_block )
328333
329- def _get_monotonic_slice (self , filtered_block , row_num_col_id ):
330- """Helper method to get slice for monotonic duplicates with optimized query."""
331- # Combine min and max aggregations into single query using to_pandas()
332- min_agg = ex .UnaryAggregation (agg_ops .min_op , ex .deref (row_num_col_id .name ))
333- max_agg = ex .UnaryAggregation (agg_ops .max_op , ex .deref (row_num_col_id .name ))
334- combined_result = filtered_block ._expr .aggregate (
335- [(min_agg , "min_pos" ), (max_agg , "max_pos" )]
336- )
334+ def _get_monotonic_slice (
335+ self , filtered_block , row_number_column_id : "ids.ColumnId"
336+ ) -> slice :
337+ """Helper method to get a slice for monotonic duplicates with an optimized query."""
338+ # Combine min and max aggregations into a single query for efficiency
339+ min_max_aggs = [
340+ (
341+ ex .UnaryAggregation (
342+ agg_ops .min_op , ex .deref (row_number_column_id .name )
343+ ),
344+ "min_pos" ,
345+ ),
346+ (
347+ ex .UnaryAggregation (
348+ agg_ops .max_op , ex .deref (row_number_column_id .name )
349+ ),
350+ "max_pos" ,
351+ ),
352+ ]
353+ combined_result = filtered_block ._expr .aggregate (min_max_aggs )
354+
355+ # Execute query and extract positions
337356 result_df = self ._block .session ._executor .execute (combined_result ).to_pandas ()
338- min_pos = result_df ["min_pos" ].iloc [0 ]
339- max_pos = result_df ["max_pos" ].iloc [0 ]
357+ min_pos = int ( result_df ["min_pos" ].iloc [0 ])
358+ max_pos = int ( result_df ["max_pos" ].iloc [0 ])
340359
341- # Create slice
342- start = int (min_pos )
343- stop = int (max_pos ) + 1 # exclusive
344- return slice (start , stop , None )
360+ # Create slice (stop is exclusive)
361+ return slice (min_pos , max_pos + 1 )
345362
346363 def __repr__ (self ) -> str :
347364 # Protect against errors with uninitialized Series. See: