|
13 | 13 | List, |
14 | 14 | Optional, |
15 | 15 | Any, |
16 | | - Callable, |
17 | 16 | cast, |
18 | 17 | TYPE_CHECKING, |
19 | 18 | ) |
@@ -131,31 +130,6 @@ def _validate_column_index(result_set: SeaResultSet, column_index: int) -> str: |
131 | 130 | raise ValueError(f"Column index {column_index} is out of bounds") |
132 | 131 | return result_set.description[column_index][0] |
133 | 132 |
|
134 | | - @staticmethod |
135 | | - def _filter_json_table( |
136 | | - result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool] |
137 | | - ) -> SeaResultSet: |
138 | | - """ |
139 | | - Filter a SEA result set using the provided filter function. |
140 | | -
|
141 | | - Args: |
142 | | - result_set: The SEA result set to filter |
143 | | - filter_func: Function that takes a row and returns True if the row should be included |
144 | | -
|
145 | | - Returns: |
146 | | - A filtered SEA result set |
147 | | - """ |
148 | | - # Get all remaining rows and filter them |
149 | | - all_rows = result_set.results.remaining_rows() |
150 | | - filtered_rows = [row for row in all_rows if filter_func(row)] |
151 | | - |
152 | | - # Create ResultData with filtered rows |
153 | | - result_data = ResultData(data=filtered_rows, external_links=None) |
154 | | - |
155 | | - return ResultSetFilter._create_filtered_result_set( |
156 | | - result_set, result_data, len(filtered_rows) |
157 | | - ) |
158 | | - |
159 | 133 | @staticmethod |
160 | 134 | def _filter_arrow_table( |
161 | 135 | table: Any, # pyarrow.Table |
@@ -248,22 +222,36 @@ def _filter_json_result_set( |
248 | 222 | Returns: |
249 | 223 | A filtered result set |
250 | 224 | """ |
| 225 | + # Validate column index (optional - not in arrow version but good practice) |
| 226 | + if column_index >= len(result_set.description): |
| 227 | + raise ValueError(f"Column index {column_index} is out of bounds") |
| 228 | + |
| 229 | + # Extract rows |
| 230 | + all_rows = result_set.results.remaining_rows() |
251 | 231 |
|
252 | | - # Convert to uppercase for case-insensitive comparison if needed |
| 232 | + # Convert allowed values if case-insensitive |
253 | 233 | if not case_sensitive: |
254 | 234 | allowed_values = [v.upper() for v in allowed_values] |
| 235 | + # Helper lambda to get column value based on case sensitivity |
| 236 | + get_column_value = ( |
| 237 | + lambda row: row[column_index].upper() |
| 238 | + if not case_sensitive |
| 239 | + else row[column_index] |
| 240 | + ) |
| 241 | + |
| 242 | + # Filter rows based on allowed values |
| 243 | + filtered_rows = [ |
| 244 | + row |
| 245 | + for row in all_rows |
| 246 | + if len(row) > column_index and get_column_value(row) in allowed_values |
| 247 | + ] |
| 248 | + |
| 249 | + # Create filtered result set |
| 250 | + result_data = ResultData(data=filtered_rows, external_links=None) |
255 | 251 |
|
256 | | - return ResultSetFilter._filter_json_table( |
257 | | - result_set, |
258 | | - lambda row: ( |
259 | | - len(row) > column_index |
260 | | - and ( |
261 | | - row[column_index].upper() |
262 | | - if not case_sensitive |
263 | | - else row[column_index] |
264 | | - ) |
265 | | - in allowed_values |
266 | | - ), |
| 252 | + # Return |
| 253 | + return ResultSetFilter._create_filtered_result_set( |
| 254 | + result_set, result_data, len(filtered_rows) |
267 | 255 | ) |
268 | 256 |
|
269 | 257 | @staticmethod |
|
0 commit comments