Skip to content

Commit 20c9fbd

Browse files
align flow of json with arrow
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent e6b256c commit 20c9fbd

File tree

1 file changed

+26
-38
lines changed

1 file changed

+26
-38
lines changed

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 26 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
List,
1414
Optional,
1515
Any,
16-
Callable,
1716
cast,
1817
TYPE_CHECKING,
1918
)
@@ -131,31 +130,6 @@ def _validate_column_index(result_set: SeaResultSet, column_index: int) -> str:
131130
raise ValueError(f"Column index {column_index} is out of bounds")
132131
return result_set.description[column_index][0]
133132

134-
@staticmethod
135-
def _filter_json_table(
136-
result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
137-
) -> SeaResultSet:
138-
"""
139-
Filter a SEA result set using the provided filter function.
140-
141-
Args:
142-
result_set: The SEA result set to filter
143-
filter_func: Function that takes a row and returns True if the row should be included
144-
145-
Returns:
146-
A filtered SEA result set
147-
"""
148-
# Get all remaining rows and filter them
149-
all_rows = result_set.results.remaining_rows()
150-
filtered_rows = [row for row in all_rows if filter_func(row)]
151-
152-
# Create ResultData with filtered rows
153-
result_data = ResultData(data=filtered_rows, external_links=None)
154-
155-
return ResultSetFilter._create_filtered_result_set(
156-
result_set, result_data, len(filtered_rows)
157-
)
158-
159133
@staticmethod
160134
def _filter_arrow_table(
161135
table: Any, # pyarrow.Table
@@ -248,22 +222,36 @@ def _filter_json_result_set(
248222
Returns:
249223
A filtered result set
250224
"""
225+
# Validate column index (optional - not in arrow version but good practice)
226+
if column_index >= len(result_set.description):
227+
raise ValueError(f"Column index {column_index} is out of bounds")
228+
229+
# Extract rows
230+
all_rows = result_set.results.remaining_rows()
251231

252-
# Convert to uppercase for case-insensitive comparison if needed
232+
# Convert allowed values if case-insensitive
253233
if not case_sensitive:
254234
allowed_values = [v.upper() for v in allowed_values]
235+
# Helper lambda to get column value based on case sensitivity
236+
get_column_value = (
237+
lambda row: row[column_index].upper()
238+
if not case_sensitive
239+
else row[column_index]
240+
)
241+
242+
# Filter rows based on allowed values
243+
filtered_rows = [
244+
row
245+
for row in all_rows
246+
if len(row) > column_index and get_column_value(row) in allowed_values
247+
]
248+
249+
# Create filtered result set
250+
result_data = ResultData(data=filtered_rows, external_links=None)
255251

256-
return ResultSetFilter._filter_json_table(
257-
result_set,
258-
lambda row: (
259-
len(row) > column_index
260-
and (
261-
row[column_index].upper()
262-
if not case_sensitive
263-
else row[column_index]
264-
)
265-
in allowed_values
266-
),
252+
# Return
253+
return ResultSetFilter._create_filtered_result_set(
254+
result_set, result_data, len(filtered_rows)
267255
)
268256

269257
@staticmethod

0 commit comments

Comments
 (0)