
Commit 89de17a

preliminary large metadata results
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 83e45ae commit 89de17a

4 files changed: +257 -37 lines changed


src/databricks/sql/backend/sea/backend.py

Lines changed: 5 additions & 4 deletions

@@ -157,6 +157,7 @@ def __init__(
         )
 
         self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True)
+        self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
 
         # Extract warehouse ID from http_path
         self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -688,7 +689,7 @@ def get_catalogs(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
@@ -721,7 +722,7 @@ def get_schemas(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
@@ -762,7 +763,7 @@ def get_tables(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
@@ -809,7 +810,7 @@ def get_columns(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
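
With this change, the SEA metadata commands (get_catalogs, get_schemas, get_tables, get_columns) reuse the connection-level use_cloud_fetch setting instead of hard-coding False. A minimal usage sketch, assuming the standard databricks-sql-connector entry points; the workspace coordinates are placeholders, and routing to the SEA backend depends on connection configuration not shown here:

from databricks import sql

with sql.connect(
    server_hostname="<workspace-host>",   # placeholder
    http_path="/sql/warehouses/abc123",   # placeholder
    access_token="<token>",               # placeholder
    use_cloud_fetch=True,  # read by SeaDatabricksClient.__init__ via kwargs
) as connection:
    with connection.cursor() as cursor:
        # After this commit, the metadata command below is issued with
        # use_cloud_fetch=True instead of the previously hard-coded False.
        cursor.tables(catalog_name="main")
        print(cursor.fetchmany(10))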

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 136 additions & 4 deletions

@@ -21,6 +21,13 @@
 
 from databricks.sql.backend.types import ExecuteResponse
 
+try:
+    import pyarrow
+    import pyarrow.compute as pc
+except ImportError:
+    pyarrow = None
+    pc = None
+
 logger = logging.getLogger(__name__)
 
 
@@ -88,6 +95,121 @@ def _filter_sea_result_set(
 
         return filtered_result_set
 
+    @staticmethod
+    def _filter_arrow_table(
+        table: Any,  # pyarrow.Table
+        column_name: str,
+        allowed_values: List[str],
+    ) -> Any:  # returns pyarrow.Table
+        """
+        Filter a PyArrow table by column values.
+
+        Args:
+            table: The PyArrow table to filter
+            column_name: The name of the column to filter on
+            allowed_values: List of allowed values for the column
+
+        Returns:
+            A filtered PyArrow table
+        """
+
+        if not pyarrow:
+            raise ImportError("PyArrow is required for Arrow table filtering")
+
+        # Convert allowed_values to PyArrow Array for better performance
+        allowed_array = pyarrow.array(allowed_values)
+
+        # Construct a boolean mask: True where column is in allowed_list
+        mask = pc.is_in(table[column_name], value_set=allowed_array)
+        return table.filter(mask)
+
+    @staticmethod
+    def _filter_arrow_result_set(
+        result_set: SeaResultSet,
+        column_index: int,
+        allowed_values: List[str],
+    ) -> SeaResultSet:
+        """
+        Filter a SEA result set that contains Arrow tables.
+
+        Args:
+            result_set: The SEA result set to filter (containing Arrow data)
+            column_index: The index of the column to filter on
+            allowed_values: List of allowed values for the column
+
+        Returns:
+            A filtered SEA result set
+        """
+
+        # Get all remaining rows as Arrow table
+        arrow_table = result_set.results.remaining_rows()
+
+        # Get the column name from the description
+        if column_index >= len(result_set.description):
+            raise ValueError(f"Column index {column_index} is out of bounds")
+
+        column_name = result_set.description[column_index][0]
+
+        # Filter the Arrow table
+        filtered_table = ResultSetFilter._filter_arrow_table(
+            arrow_table, column_name, allowed_values
+        )
+
+        # Create a new result set with filtered data
+        command_id = result_set.command_id
+
+        # Create an ExecuteResponse for the filtered data
+        execute_response = ExecuteResponse(
+            command_id=command_id,
+            status=result_set.status,
+            description=result_set.description,
+            has_been_closed_server_side=result_set.has_been_closed_server_side,
+            lz4_compressed=result_set.lz4_compressed,
+            arrow_schema_bytes=result_set._arrow_schema_bytes,
+            is_staging_operation=False,
+        )
+
+        # Create ResultData with the filtered arrow table as attachment.
+        # This mimics the hybrid disposition flow in build_queue.
+        from databricks.sql.backend.sea.models.base import ResultData
+        from databricks.sql.backend.sea.result_set import SeaResultSet
+        from databricks.sql.backend.sea.backend import SeaDatabricksClient
+        import io
+
+        # Convert the filtered table to Arrow stream format
+        sink = io.BytesIO()
+        with pyarrow.ipc.new_stream(sink, filtered_table.schema) as writer:
+            writer.write_table(filtered_table)
+        arrow_stream_bytes = sink.getvalue()
+
+        # Create ResultData with attachment containing the filtered data
+        filtered_result_data = ResultData(
+            data=None,  # No JSON data
+            external_links=None,  # No external links
+            attachment=arrow_stream_bytes,  # Arrow data as attachment
+        )
+
+        # Update manifest to reflect new row count.
+        # Create a copy of the manifest to avoid modifying the original.
+        manifest = result_set.manifest
+        from copy import deepcopy
+
+        filtered_manifest = deepcopy(manifest)
+        filtered_manifest.total_row_count = filtered_table.num_rows
+
+        # Create a new SeaResultSet with the filtered data
+        filtered_result_set = SeaResultSet(
+            connection=result_set.connection,
+            execute_response=execute_response,
+            sea_client=cast(SeaDatabricksClient, result_set.backend),
+            result_data=filtered_result_data,
+            manifest=filtered_manifest,
+            buffer_size_bytes=result_set.buffer_size_bytes,
+            arraysize=result_set.arraysize,
+        )
+
+        return filtered_result_set
+
     @staticmethod
     def filter_by_column_values(
         result_set: SeaResultSet,
@@ -150,7 +272,17 @@ def filter_tables_by_type(
             table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
         )
 
-        # Table type is the 6th column (index 5)
-        return ResultSetFilter.filter_by_column_values(
-            result_set, 5, valid_types, case_sensitive=True
-        )
+        # Check if we have an Arrow table (cloud fetch) or JSON data
+        from databricks.sql.utils import CloudFetchQueue, ArrowQueue
+
+        if isinstance(result_set.results, (CloudFetchQueue, ArrowQueue)):
+            # For Arrow tables, we need to handle filtering differently
+            return ResultSetFilter._filter_arrow_result_set(
+                result_set, column_index=5, allowed_values=valid_types
+            )
+        else:
+            # For JSON data, use the existing filter method.
+            # Table type is the 6th column (index 5).
+            return ResultSetFilter.filter_by_column_values(
+                result_set, 5, valid_types, case_sensitive=True
+            )
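
The two PyArrow building blocks introduced above, masking with pyarrow.compute.is_in and re-serializing the filtered table to Arrow IPC stream bytes for the attachment, can be exercised in isolation. A self-contained sketch (the table contents are made up for illustration):

import io

import pyarrow as pa
import pyarrow.compute as pc

table = pa.table(
    {
        "TABLE_NAME": ["t1", "v1", "s1"],
        "TABLE_TYPE": ["TABLE", "VIEW", "SYSTEM TABLE"],
    }
)

# Boolean mask: True where TABLE_TYPE is one of the allowed values
mask = pc.is_in(table["TABLE_TYPE"], value_set=pa.array(["TABLE", "VIEW"]))
filtered = table.filter(mask)
assert filtered.num_rows == 2

# Round-trip through the Arrow IPC stream format, mirroring how the filtered
# table is stored as a ResultData attachment
sink = io.BytesIO()
with pa.ipc.new_stream(sink, filtered.schema) as writer:
    writer.write_table(filtered)
sink.seek(0)
restored = pa.ipc.open_stream(sink).read_all()
assert restored.equals(filtered)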

tests/unit/test_filters.py

Lines changed: 22 additions & 29 deletions

@@ -123,37 +123,30 @@ def test_filter_tables_by_type(self):
         # Case 1: Specific table types
         table_types = ["TABLE", "VIEW"]
 
-        with patch(
-            "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True
-        ):
-            with patch.object(
-                ResultSetFilter, "filter_by_column_values"
-            ) as mock_filter:
-                ResultSetFilter.filter_tables_by_type(
-                    self.mock_sea_result_set, table_types
-                )
-                args, kwargs = mock_filter.call_args
-                self.assertEqual(args[0], self.mock_sea_result_set)
-                self.assertEqual(args[1], 5)  # Table type column index
-                self.assertEqual(args[2], table_types)
-                self.assertEqual(kwargs.get("case_sensitive"), True)
+        # Mock results as JsonQueue (not CloudFetchQueue or ArrowQueue)
+        from databricks.sql.backend.sea.queue import JsonQueue
+
+        self.mock_sea_result_set.results = JsonQueue([])
+
+        with patch.object(ResultSetFilter, "filter_by_column_values") as mock_filter:
+            ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, table_types)
+            args, kwargs = mock_filter.call_args
+            self.assertEqual(args[0], self.mock_sea_result_set)
+            self.assertEqual(args[1], 5)  # Table type column index
+            self.assertEqual(args[2], table_types)
+            self.assertEqual(kwargs.get("case_sensitive"), True)
 
         # Case 2: Default table types (None or empty list)
-        with patch(
-            "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True
-        ):
-            with patch.object(
-                ResultSetFilter, "filter_by_column_values"
-            ) as mock_filter:
-                # Test with None
-                ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None)
-                args, kwargs = mock_filter.call_args
-                self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"])
-
-                # Test with empty list
-                ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, [])
-                args, kwargs = mock_filter.call_args
-                self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"])
+        with patch.object(ResultSetFilter, "filter_by_column_values") as mock_filter:
+            # Test with None
+            ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None)
+            args, kwargs = mock_filter.call_args
+            self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"])
+
+            # Test with empty list
+            ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, [])
+            args, kwargs = mock_filter.call_args
+            self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"])
 
 
 if __name__ == "__main__":
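
The rewritten test exercises only the JSON (JsonQueue) branch of filter_tables_by_type. A possible companion test for the Arrow branch, sketched under the assumptions that it lives in the same test class, that Mock and patch come from unittest.mock, that ArrowQueue is importable from databricks.sql.utils as in the diff above, and that a Mock(spec=ArrowQueue) is enough to satisfy the isinstance dispatch (the test name is hypothetical):

def test_filter_tables_by_type_arrow_path(self):
    # A spec'd Mock passes isinstance(obj, ArrowQueue), so dispatch goes to
    # _filter_arrow_result_set instead of filter_by_column_values.
    self.mock_sea_result_set.results = Mock(spec=ArrowQueue)
    with patch.object(ResultSetFilter, "_filter_arrow_result_set") as mock_arrow:
        ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, ["TABLE"])
        args, kwargs = mock_arrow.call_args
        self.assertEqual(args[0], self.mock_sea_result_set)
        self.assertEqual(kwargs.get("column_index"), 5)
        self.assertEqual(kwargs.get("allowed_values"), ["TABLE"])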

tests/unit/test_sea_backend.py

Lines changed: 94 additions & 0 deletions

@@ -56,6 +56,29 @@ def sea_client(self, mock_http_client):
             http_headers=http_headers,
             auth_provider=auth_provider,
             ssl_options=ssl_options,
+            use_cloud_fetch=False,
+        )
+
+        return client
+
+    @pytest.fixture
+    def sea_client_cloud_fetch(self, mock_http_client):
+        """Create a SeaDatabricksClient instance with cloud fetch enabled."""
+        server_hostname = "test-server.databricks.com"
+        port = 443
+        http_path = "/sql/warehouses/abc123"
+        http_headers = [("header1", "value1"), ("header2", "value2")]
+        auth_provider = AuthProvider()
+        ssl_options = SSLOptions()
+
+        client = SeaDatabricksClient(
+            server_hostname=server_hostname,
+            port=port,
+            http_path=http_path,
+            http_headers=http_headers,
+            auth_provider=auth_provider,
+            ssl_options=ssl_options,
+            use_cloud_fetch=True,
         )
 
         return client
@@ -884,3 +907,74 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
                 cursor=mock_cursor,
             )
         assert "Catalog name is required for get_columns" in str(excinfo.value)
+
+    def test_get_tables_with_cloud_fetch(
+        self, sea_client_cloud_fetch, sea_session_id, mock_cursor
+    ):
+        """Test the get_tables method with cloud fetch enabled."""
+        # Mock the execute_command method and ResultSetFilter
+        mock_result_set = Mock()
+
+        with patch.object(
+            sea_client_cloud_fetch, "execute_command", return_value=mock_result_set
+        ) as mock_execute:
+            with patch(
+                "databricks.sql.backend.sea.utils.filters.ResultSetFilter"
+            ) as mock_filter:
+                mock_filter.filter_tables_by_type.return_value = mock_result_set
+
+                # Call get_tables
+                result = sea_client_cloud_fetch.get_tables(
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    cursor=mock_cursor,
+                    catalog_name="test_catalog",
+                )
+
+                # Verify execute_command was called with use_cloud_fetch=True
+                mock_execute.assert_called_with(
+                    operation="SHOW TABLES IN CATALOG test_catalog",
+                    session_id=sea_session_id,
+                    max_rows=100,
+                    max_bytes=1000,
+                    lz4_compression=False,
+                    cursor=mock_cursor,
+                    use_cloud_fetch=True,  # Should use True since client was created with use_cloud_fetch=True
+                    parameters=[],
+                    async_op=False,
+                    enforce_embedded_schema_correctness=False,
+                )
+                assert result == mock_result_set
+
+    def test_get_schemas_with_cloud_fetch(
+        self, sea_client_cloud_fetch, sea_session_id, mock_cursor
+    ):
+        """Test the get_schemas method with cloud fetch enabled."""
+        # Mock the execute_command method
+        mock_result_set = Mock()
+        with patch.object(
+            sea_client_cloud_fetch, "execute_command", return_value=mock_result_set
+        ) as mock_execute:
+            # Test with catalog name
+            result = sea_client_cloud_fetch.get_schemas(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+                catalog_name="test_catalog",
+            )
+
+            mock_execute.assert_called_with(
+                operation="SHOW SCHEMAS IN test_catalog",
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                lz4_compression=False,
+                cursor=mock_cursor,
+                use_cloud_fetch=True,  # Should use True since client was created with use_cloud_fetch=True
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )
+            assert result == mock_result_set
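
The sea_client and sea_client_cloud_fetch fixtures above differ only in the use_cloud_fetch flag. If more variants become necessary, one option is a single parametrized fixture in the same test class rather than further duplication; a sketch (the fixture name is hypothetical), reusing the imports already present in this module:

@pytest.fixture(params=[False, True], ids=["no_cloud_fetch", "cloud_fetch"])
def sea_client_any_fetch(self, request, mock_http_client):
    """Create a SeaDatabricksClient once with and once without cloud fetch."""
    return SeaDatabricksClient(
        server_hostname="test-server.databricks.com",
        port=443,
        http_path="/sql/warehouses/abc123",
        http_headers=[("header1", "value1"), ("header2", "value2")],
        auth_provider=AuthProvider(),
        ssl_options=SSLOptions(),
        use_cloud_fetch=request.param,  # each dependent test runs for False and True
    )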
