@@ -6,7 +6,9 @@

 from __future__ import annotations

+import io
 import logging
+from copy import deepcopy
 from typing import (
     List,
     Optional,
@@ -20,6 +22,9 @@
     from databricks.sql.backend.sea.result_set import SeaResultSet

 from databricks.sql.backend.types import ExecuteResponse
+from databricks.sql.backend.sea.models.base import ResultData
+from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.utils import CloudFetchQueue, ArrowQueue

 try:
     import pyarrow
@@ -37,32 +42,18 @@ class ResultSetFilter:
     """

     @staticmethod
-    def _filter_sea_result_set(
-        result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
-    ) -> SeaResultSet:
+    def _create_execute_response(result_set: SeaResultSet) -> ExecuteResponse:
         """
-        Filter a SEA result set using the provided filter function.
+        Create an ExecuteResponse with parameters from the original result set.

         Args:
-            result_set: The SEA result set to filter
-            filter_func: Function that takes a row and returns True if the row should be included
+            result_set: Original result set to copy parameters from

         Returns:
-            A filtered SEA result set
+            ExecuteResponse: New execute response object
         """
-
-        # Get all remaining rows
-        all_rows = result_set.results.remaining_rows()
-
-        # Filter rows
-        filtered_rows = [row for row in all_rows if filter_func(row)]
-
-        # Reuse the command_id from the original result set
-        command_id = result_set.command_id
-
-        # Create an ExecuteResponse for the filtered data
-        execute_response = ExecuteResponse(
-            command_id=command_id,
+        return ExecuteResponse(
+            command_id=result_set.command_id,
             status=result_set.status,
             description=result_set.description,
             has_been_closed_server_side=result_set.has_been_closed_server_side,
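Review note: the first half of this change extracts the duplicated `ExecuteResponse` construction into `_create_execute_response`, shared by the JSON and Arrow paths. A minimal runnable sketch of the pattern, using a hypothetical stub in place of the real `ExecuteResponse`/`SeaResultSet` types:

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any


# Stand-in for the real ExecuteResponse; field names mirror the diff.
@dataclass
class ExecuteResponseStub:
    command_id: Any
    status: Any
    is_staging_operation: bool


def create_execute_response(result_set: Any) -> ExecuteResponseStub:
    # One constructor shared by both filter paths, instead of two
    # hand-copied ExecuteResponse(...) blocks that can drift apart.
    return ExecuteResponseStub(
        command_id=result_set.command_id,
        status=result_set.status,
        is_staging_operation=False,  # filtered result sets never stage files
    )


fake = SimpleNamespace(command_id="cmd-1", status="SUCCEEDED")
print(create_execute_response(fake))
```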
@@ -71,29 +62,99 @@ def _filter_sea_result_set(
             is_staging_operation=False,
         )

-        # Create a new ResultData object with filtered data
-        from databricks.sql.backend.sea.models.base import ResultData
+    @staticmethod
+    def _create_filtered_manifest(result_set: SeaResultSet, new_row_count: int):
+        """
+        Create a copy of the manifest with updated row count.
+
+        Args:
+            result_set: Original result set to copy manifest from
+            new_row_count: New total row count for filtered data

-        result_data = ResultData(data=filtered_rows, external_links=None)
+        Returns:
+            Updated manifest copy
+        """
+        filtered_manifest = deepcopy(result_set.manifest)
+        filtered_manifest.total_row_count = new_row_count
+        return filtered_manifest
+
+    @staticmethod
+    def _create_filtered_result_set(
+        result_set: SeaResultSet,
+        result_data: ResultData,
+        row_count: int,
+    ) -> "SeaResultSet":
+        """
+        Create a new filtered SeaResultSet with the provided data.
+
+        Args:
+            result_set: Original result set to copy parameters from
+            result_data: New result data for the filtered set
+            row_count: Number of rows in the filtered data

-        from databricks.sql.backend.sea.backend import SeaDatabricksClient
+        Returns:
+            New filtered SeaResultSet
+        """
         from databricks.sql.backend.sea.result_set import SeaResultSet

-        # Create a new SeaResultSet with the filtered data
-        manifest = result_set.manifest
-        manifest.total_row_count = len(filtered_rows)
+        execute_response = ResultSetFilter._create_execute_response(result_set)
+        filtered_manifest = ResultSetFilter._create_filtered_manifest(
+            result_set, row_count
+        )

-        filtered_result_set = SeaResultSet(
+        return SeaResultSet(
             connection=result_set.connection,
             execute_response=execute_response,
             sea_client=cast(SeaDatabricksClient, result_set.backend),
             result_data=result_data,
-            manifest=manifest,
+            manifest=filtered_manifest,
             buffer_size_bytes=result_set.buffer_size_bytes,
             arraysize=result_set.arraysize,
         )

-        return filtered_result_set
+    @staticmethod
+    def _validate_column_index(result_set: SeaResultSet, column_index: int) -> str:
+        """
+        Validate column index and return the column name.
+
+        Args:
+            result_set: Result set to validate against
+            column_index: Index of the column to validate
+
+        Returns:
+            str: Column name at the specified index
+
+        Raises:
+            ValueError: If column index is out of bounds
+        """
+        if column_index >= len(result_set.description):
+            raise ValueError(f"Column index {column_index} is out of bounds")
+        return result_set.description[column_index][0]
+
+    @staticmethod
+    def _filter_json_table(
+        result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
+    ) -> SeaResultSet:
+        """
+        Filter a SEA result set using the provided filter function.
+
+        Args:
+            result_set: The SEA result set to filter
+            filter_func: Function that takes a row and returns True if the row should be included
+
+        Returns:
+            A filtered SEA result set
+        """
+        # Get all remaining rows and filter them
+        all_rows = result_set.results.remaining_rows()
+        filtered_rows = [row for row in all_rows if filter_func(row)]
+
+        # Create ResultData with filtered rows
+        result_data = ResultData(data=filtered_rows, external_links=None)
+
+        return ResultSetFilter._create_filtered_result_set(
+            result_set, result_data, len(filtered_rows)
+        )

     @staticmethod
     def _filter_arrow_table(
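Review note: `_filter_json_table` accepts an arbitrary row predicate, so the JSON path boils down to a list comprehension over the materialized rows. A self-contained illustration of that core (the sample rows are made up, not real metadata output):

```python
from typing import Any, Callable, List

# Rows as the JSON queue yields them: plain lists of column values.
rows: List[List[Any]] = [
    ["catalog", "schema", "orders", "TABLE"],
    ["catalog", "schema", "orders_v", "VIEW"],
    ["catalog", "schema", "tmp_idx", "INDEX"],
]

# Predicate: keep rows whose type column is in the allowed set.
filter_func: Callable[[List[Any]], bool] = lambda row: row[3] in ("TABLE", "VIEW")

filtered_rows = [row for row in rows if filter_func(row)]
assert len(filtered_rows) == 2  # the INDEX row is dropped
```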
@@ -112,7 +173,6 @@ def _filter_arrow_table(
         Returns:
             A filtered PyArrow table
         """
-
         if not pyarrow:
             raise ImportError("PyArrow is required for Arrow table filtering")
@@ -143,78 +203,34 @@ def _filter_arrow_result_set(
         Returns:
             A filtered SEA result set
         """
+        # Validate column index and get column name
+        column_name = ResultSetFilter._validate_column_index(result_set, column_index)

-        # Get all remaining rows as Arrow table
+        # Get all remaining rows as Arrow table and filter it
         arrow_table = result_set.results.remaining_rows()
-
-        # Get the column name from the description
-        if column_index >= len(result_set.description):
-            raise ValueError(f"Column index {column_index} is out of bounds")
-
-        column_name = result_set.description[column_index][0]
-
-        # Filter the Arrow table
         filtered_table = ResultSetFilter._filter_arrow_table(
             arrow_table, column_name, allowed_values
         )

-        # Create a new result set with filtered data
-        command_id = result_set.command_id
-
-        # Create an ExecuteResponse for the filtered data
-        execute_response = ExecuteResponse(
-            command_id=command_id,
-            status=result_set.status,
-            description=result_set.description,
-            has_been_closed_server_side=result_set.has_been_closed_server_side,
-            lz4_compressed=result_set.lz4_compressed,
-            arrow_schema_bytes=result_set._arrow_schema_bytes,
-            is_staging_operation=False,
-        )
-
-        # Create ResultData with the filtered arrow table as attachment
-        # This mimics the hybrid disposition flow in build_queue
-        from databricks.sql.backend.sea.models.base import ResultData
-        from databricks.sql.backend.sea.result_set import SeaResultSet
-        from databricks.sql.backend.sea.backend import SeaDatabricksClient
-        import io
-
-        # Convert the filtered table to Arrow stream format
+        # Convert the filtered table to Arrow stream format for ResultData
         sink = io.BytesIO()
         with pyarrow.ipc.new_stream(sink, filtered_table.schema) as writer:
             writer.write_table(filtered_table)
         arrow_stream_bytes = sink.getvalue()

         # Create ResultData with attachment containing the filtered data
-        filtered_result_data = ResultData(
+        result_data = ResultData(
             data=None,  # No JSON data
             external_links=None,  # No external links
             attachment=arrow_stream_bytes,  # Arrow data as attachment
         )

-        # Update manifest to reflect new row count
-        manifest = result_set.manifest
-        # Create a copy of the manifest to avoid modifying the original
-        from copy import deepcopy
-
-        filtered_manifest = deepcopy(manifest)
-        filtered_manifest.total_row_count = filtered_table.num_rows
-
-        # Create a new SeaResultSet with the filtered data
-        filtered_result_set = SeaResultSet(
-            connection=result_set.connection,
-            execute_response=execute_response,
-            sea_client=cast(SeaDatabricksClient, result_set.backend),
-            result_data=filtered_result_data,
-            manifest=filtered_manifest,
-            buffer_size_bytes=result_set.buffer_size_bytes,
-            arraysize=result_set.arraysize,
+        return ResultSetFilter._create_filtered_result_set(
+            result_set, result_data, filtered_table.num_rows
         )

-        return filtered_result_set
-
     @staticmethod
-    def filter_by_column_values(
+    def _filter_json_result_set(
         result_set: SeaResultSet,
         column_index: int,
         allowed_values: List[str],
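Review note: the attachment bytes use Arrow's IPC stream format, mirroring the hybrid disposition flow in `build_queue`. A self-contained round-trip showing that a consumer can rebuild the identical table from those bytes (standalone example, independent of the connector):

```python
import io

import pyarrow as pa

table = pa.table({"col": [1, 2, 3]})

# Serialize the table into IPC stream bytes, as the filter does
# before attaching the data to ResultData.
sink = io.BytesIO()
with pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)
stream_bytes = sink.getvalue()

# The consumer side can rebuild the same table from the raw bytes.
restored = pa.ipc.open_stream(io.BytesIO(stream_bytes)).read_all()
assert restored.equals(table)
```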
@@ -237,7 +253,7 @@ def filter_by_column_values(
         if not case_sensitive:
             allowed_values = [v.upper() for v in allowed_values]

-        return ResultSetFilter._filter_sea_result_set(
+        return ResultSetFilter._filter_json_table(
             result_set,
             lambda row: (
                 len(row) > column_index
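Review note: case-insensitive mode upper-cases the allowed values here, and the predicate applies the same normalization to each row value before the membership test. The idiom in isolation, with illustrative values:

```python
allowed_values = ["table", "View"]
case_sensitive = False

# Normalize the allowed set once, up front.
if not case_sensitive:
    allowed_values = [v.upper() for v in allowed_values]

# Each row value gets the same normalization before comparing.
row_value = "TABLE"
matches = (row_value.upper() if not case_sensitive else row_value) in allowed_values
assert matches
```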
@@ -268,24 +284,19 @@ def filter_tables_by_type(
         Returns:
             A filtered result set containing only tables of the specified types
         """
-
         # Default table types if none specified
         DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
-        valid_types = (
-            table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
-        )
+        valid_types = table_types if table_types else DEFAULT_TABLE_TYPES

         # Check if we have an Arrow table (cloud fetch) or JSON data
-        from databricks.sql.utils import CloudFetchQueue, ArrowQueue
-
+        # Table type is the 6th column (index 5)
         if isinstance(result_set.results, (CloudFetchQueue, ArrowQueue)):
             # For Arrow tables, we need to handle filtering differently
             return ResultSetFilter._filter_arrow_result_set(
                 result_set, column_index=5, allowed_values=valid_types
             )
         else:
             # For JSON data, use the existing filter method
-            # Table type is the 6th column (index 5)
-            return ResultSetFilter.filter_by_column_values(
+            return ResultSetFilter._filter_json_result_set(
                 result_set, 5, valid_types, case_sensitive=True
             )
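Review note: the simplified `valid_types` expression relies on empty lists being falsy, so `None` and `[]` both fall back to the defaults; the old `len(table_types) > 0` check was redundant. A quick check of that behavior:

```python
from typing import List, Optional

DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]


def resolve_table_types(table_types: Optional[List[str]]) -> List[str]:
    # `table_types if table_types else ...` treats both None and []
    # as "use the defaults", matching the simplified expression above.
    return table_types if table_types else DEFAULT_TABLE_TYPES


assert resolve_table_types(None) == DEFAULT_TABLE_TYPES
assert resolve_table_types([]) == DEFAULT_TABLE_TYPES
assert resolve_table_types(["VIEW"]) == ["VIEW"]
```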