Skip to content

Commit e6b256c

Browse files
align flows
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent d5ccf13 commit e6b256c

File tree

2 files changed

+110
-99
lines changed

2 files changed

+110
-99
lines changed

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 103 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
from __future__ import annotations
88

9+
import io
910
import logging
11+
from copy import deepcopy
1012
from typing import (
1113
List,
1214
Optional,
@@ -20,6 +22,9 @@
2022
from databricks.sql.backend.sea.result_set import SeaResultSet
2123

2224
from databricks.sql.backend.types import ExecuteResponse
25+
from databricks.sql.backend.sea.models.base import ResultData
26+
from databricks.sql.backend.sea.backend import SeaDatabricksClient
27+
from databricks.sql.utils import CloudFetchQueue, ArrowQueue
2328

2429
try:
2530
import pyarrow
@@ -37,32 +42,18 @@ class ResultSetFilter:
3742
"""
3843

3944
@staticmethod
40-
def _filter_sea_result_set(
41-
result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
42-
) -> SeaResultSet:
45+
def _create_execute_response(result_set: SeaResultSet) -> ExecuteResponse:
4346
"""
44-
Filter a SEA result set using the provided filter function.
47+
Create an ExecuteResponse with parameters from the original result set.
4548
4649
Args:
47-
result_set: The SEA result set to filter
48-
filter_func: Function that takes a row and returns True if the row should be included
50+
result_set: Original result set to copy parameters from
4951
5052
Returns:
51-
A filtered SEA result set
53+
ExecuteResponse: New execute response object
5254
"""
53-
54-
# Get all remaining rows
55-
all_rows = result_set.results.remaining_rows()
56-
57-
# Filter rows
58-
filtered_rows = [row for row in all_rows if filter_func(row)]
59-
60-
# Reuse the command_id from the original result set
61-
command_id = result_set.command_id
62-
63-
# Create an ExecuteResponse for the filtered data
64-
execute_response = ExecuteResponse(
65-
command_id=command_id,
55+
return ExecuteResponse(
56+
command_id=result_set.command_id,
6657
status=result_set.status,
6758
description=result_set.description,
6859
has_been_closed_server_side=result_set.has_been_closed_server_side,
@@ -71,29 +62,99 @@ def _filter_sea_result_set(
7162
is_staging_operation=False,
7263
)
7364

74-
# Create a new ResultData object with filtered data
75-
from databricks.sql.backend.sea.models.base import ResultData
65+
@staticmethod
66+
def _create_filtered_manifest(result_set: SeaResultSet, new_row_count: int):
67+
"""
68+
Create a copy of the manifest with updated row count.
69+
70+
Args:
71+
result_set: Original result set to copy manifest from
72+
new_row_count: New total row count for filtered data
7673
77-
result_data = ResultData(data=filtered_rows, external_links=None)
74+
Returns:
75+
Updated manifest copy
76+
"""
77+
filtered_manifest = deepcopy(result_set.manifest)
78+
filtered_manifest.total_row_count = new_row_count
79+
return filtered_manifest
80+
81+
@staticmethod
def _create_filtered_result_set(
    result_set: SeaResultSet,
    result_data: ResultData,
    row_count: int,
) -> "SeaResultSet":
    """
    Wrap already-filtered data in a fresh SeaResultSet.

    Args:
        result_set: Original result set; its connection, backend, buffer size
            and array size are carried over to the new instance
        result_data: Result data holding the filtered rows
        row_count: Number of rows in the filtered data, written to the
            copied manifest

    Returns:
        New SeaResultSet built from the filtered data
    """
    from databricks.sql.backend.sea.result_set import SeaResultSet

    response = ResultSetFilter._create_execute_response(result_set)
    manifest_copy = ResultSetFilter._create_filtered_manifest(result_set, row_count)

    # Gather the constructor arguments in one place; everything except the
    # response, data and manifest is taken unchanged from the original set.
    constructor_args = {
        "connection": result_set.connection,
        "execute_response": response,
        "sea_client": cast(SeaDatabricksClient, result_set.backend),
        "result_data": result_data,
        "manifest": manifest_copy,
        "buffer_size_bytes": result_set.buffer_size_bytes,
        "arraysize": result_set.arraysize,
    }
    return SeaResultSet(**constructor_args)
95114

96-
return filtered_result_set
115+
@staticmethod
116+
def _validate_column_index(result_set: SeaResultSet, column_index: int) -> str:
117+
"""
118+
Validate column index and return the column name.
119+
120+
Args:
121+
result_set: Result set to validate against
122+
column_index: Index of the column to validate
123+
124+
Returns:
125+
str: Column name at the specified index
126+
127+
Raises:
128+
ValueError: If column index is out of bounds
129+
"""
130+
if column_index >= len(result_set.description):
131+
raise ValueError(f"Column index {column_index} is out of bounds")
132+
return result_set.description[column_index][0]
133+
134+
@staticmethod
def _filter_json_table(
    result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
) -> SeaResultSet:
    """
    Apply a row predicate to a SEA result set and return the surviving rows
    as a new result set.

    Args:
        result_set: The SEA result set to filter
        filter_func: Called once per row; the row is kept when the result is truthy

    Returns:
        A filtered SEA result set
    """
    # Walk the remaining rows once, keeping only the ones the predicate accepts.
    kept_rows = []
    for row in result_set.results.remaining_rows():
        if filter_func(row):
            kept_rows.append(row)

    # The kept rows travel as inline data; no external links are attached.
    payload = ResultData(data=kept_rows, external_links=None)

    return ResultSetFilter._create_filtered_result_set(
        result_set, payload, len(kept_rows)
    )
97158

98159
@staticmethod
99160
def _filter_arrow_table(
@@ -112,7 +173,6 @@ def _filter_arrow_table(
112173
Returns:
113174
A filtered PyArrow table
114175
"""
115-
116176
if not pyarrow:
117177
raise ImportError("PyArrow is required for Arrow table filtering")
118178

@@ -143,78 +203,34 @@ def _filter_arrow_result_set(
143203
Returns:
144204
A filtered SEA result set
145205
"""
206+
# Validate column index and get column name
207+
column_name = ResultSetFilter._validate_column_index(result_set, column_index)
146208

147-
# Get all remaining rows as Arrow table
209+
# Get all remaining rows as Arrow table and filter it
148210
arrow_table = result_set.results.remaining_rows()
149-
150-
# Get the column name from the description
151-
if column_index >= len(result_set.description):
152-
raise ValueError(f"Column index {column_index} is out of bounds")
153-
154-
column_name = result_set.description[column_index][0]
155-
156-
# Filter the Arrow table
157211
filtered_table = ResultSetFilter._filter_arrow_table(
158212
arrow_table, column_name, allowed_values
159213
)
160214

161-
# Create a new result set with filtered data
162-
command_id = result_set.command_id
163-
164-
# Create an ExecuteResponse for the filtered data
165-
execute_response = ExecuteResponse(
166-
command_id=command_id,
167-
status=result_set.status,
168-
description=result_set.description,
169-
has_been_closed_server_side=result_set.has_been_closed_server_side,
170-
lz4_compressed=result_set.lz4_compressed,
171-
arrow_schema_bytes=result_set._arrow_schema_bytes,
172-
is_staging_operation=False,
173-
)
174-
175-
# Create ResultData with the filtered arrow table as attachment
176-
# This mimics the hybrid disposition flow in build_queue
177-
from databricks.sql.backend.sea.models.base import ResultData
178-
from databricks.sql.backend.sea.result_set import SeaResultSet
179-
from databricks.sql.backend.sea.backend import SeaDatabricksClient
180-
import io
181-
182-
# Convert the filtered table to Arrow stream format
215+
# Convert the filtered table to Arrow stream format for ResultData
183216
sink = io.BytesIO()
184217
with pyarrow.ipc.new_stream(sink, filtered_table.schema) as writer:
185218
writer.write_table(filtered_table)
186219
arrow_stream_bytes = sink.getvalue()
187220

188221
# Create ResultData with attachment containing the filtered data
189-
filtered_result_data = ResultData(
222+
result_data = ResultData(
190223
data=None, # No JSON data
191224
external_links=None, # No external links
192225
attachment=arrow_stream_bytes, # Arrow data as attachment
193226
)
194227

195-
# Update manifest to reflect new row count
196-
manifest = result_set.manifest
197-
# Create a copy of the manifest to avoid modifying the original
198-
from copy import deepcopy
199-
200-
filtered_manifest = deepcopy(manifest)
201-
filtered_manifest.total_row_count = filtered_table.num_rows
202-
203-
# Create a new SeaResultSet with the filtered data
204-
filtered_result_set = SeaResultSet(
205-
connection=result_set.connection,
206-
execute_response=execute_response,
207-
sea_client=cast(SeaDatabricksClient, result_set.backend),
208-
result_data=filtered_result_data,
209-
manifest=filtered_manifest,
210-
buffer_size_bytes=result_set.buffer_size_bytes,
211-
arraysize=result_set.arraysize,
228+
return ResultSetFilter._create_filtered_result_set(
229+
result_set, result_data, filtered_table.num_rows
212230
)
213231

214-
return filtered_result_set
215-
216232
@staticmethod
217-
def filter_by_column_values(
233+
def _filter_json_result_set(
218234
result_set: SeaResultSet,
219235
column_index: int,
220236
allowed_values: List[str],
@@ -237,7 +253,7 @@ def filter_by_column_values(
237253
if not case_sensitive:
238254
allowed_values = [v.upper() for v in allowed_values]
239255

240-
return ResultSetFilter._filter_sea_result_set(
256+
return ResultSetFilter._filter_json_table(
241257
result_set,
242258
lambda row: (
243259
len(row) > column_index
@@ -268,24 +284,19 @@ def filter_tables_by_type(
268284
Returns:
269285
A filtered result set containing only tables of the specified types
270286
"""
271-
272287
# Default table types if none specified
273288
DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
274-
valid_types = (
275-
table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
276-
)
289+
valid_types = table_types if table_types else DEFAULT_TABLE_TYPES
277290

278291
# Check if we have an Arrow table (cloud fetch) or JSON data
279-
from databricks.sql.utils import CloudFetchQueue, ArrowQueue
280-
292+
# Table type is the 6th column (index 5)
281293
if isinstance(result_set.results, (CloudFetchQueue, ArrowQueue)):
282294
# For Arrow tables, we need to handle filtering differently
283295
return ResultSetFilter._filter_arrow_result_set(
284296
result_set, column_index=5, allowed_values=valid_types
285297
)
286298
else:
287299
# For JSON data, use the existing filter method
288-
# Table type is the 6th column (index 5)
289-
return ResultSetFilter.filter_by_column_values(
300+
return ResultSetFilter._filter_json_result_set(
290301
result_set, 5, valid_types, case_sensitive=True
291302
)

tests/unit/test_filters.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def setUp(self):
6868
self.mock_sea_result_set.has_been_closed_server_side = False
6969
self.mock_sea_result_set._arrow_schema_bytes = None
7070

71-
def test_filter_by_column_values(self):
71+
def test__filter_json_result_set(self):
7272
"""Test filtering by column values with various options."""
7373
# Case 1: Case-sensitive filtering
7474
allowed_values = ["table1", "table3"]
@@ -82,8 +82,8 @@ def test_filter_by_column_values(self):
8282
mock_instance = MagicMock()
8383
mock_sea_result_set_class.return_value = mock_instance
8484

85-
# Call filter_by_column_values on the table_name column (index 2)
86-
result = ResultSetFilter.filter_by_column_values(
85+
# Call _filter_json_result_set on the table_name column (index 2)
86+
result = ResultSetFilter._filter_json_result_set(
8787
self.mock_sea_result_set, 2, allowed_values, case_sensitive=True
8888
)
8989

@@ -109,8 +109,8 @@ def test_filter_by_column_values(self):
109109
mock_instance = MagicMock()
110110
mock_sea_result_set_class.return_value = mock_instance
111111

112-
# Call filter_by_column_values with case-insensitive matching
113-
result = ResultSetFilter.filter_by_column_values(
112+
# Call _filter_json_result_set with case-insensitive matching
113+
result = ResultSetFilter._filter_json_result_set(
114114
self.mock_sea_result_set,
115115
2,
116116
["TABLE1", "TABLE3"],
@@ -128,7 +128,7 @@ def test_filter_tables_by_type(self):
128128

129129
self.mock_sea_result_set.results = JsonQueue([])
130130

131-
with patch.object(ResultSetFilter, "filter_by_column_values") as mock_filter:
131+
with patch.object(ResultSetFilter, "_filter_json_result_set") as mock_filter:
132132
ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, table_types)
133133
args, kwargs = mock_filter.call_args
134134
self.assertEqual(args[0], self.mock_sea_result_set)
@@ -137,7 +137,7 @@ def test_filter_tables_by_type(self):
137137
self.assertEqual(kwargs.get("case_sensitive"), True)
138138

139139
# Case 2: Default table types (None or empty list)
140-
with patch.object(ResultSetFilter, "filter_by_column_values") as mock_filter:
140+
with patch.object(ResultSetFilter, "_filter_json_result_set") as mock_filter:
141141
# Test with None
142142
ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None)
143143
args, kwargs = mock_filter.call_args

0 commit comments

Comments
 (0)