Skip to content

Commit a5decfc

Browse files
introduce session open and close with SEA client
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent cf8a629 commit a5decfc

File tree

5 files changed

+732
-10
lines changed

5 files changed

+732
-10
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import os
2+
import sys
3+
import logging
4+
from databricks.sql.client import Connection
5+
6+
logging.basicConfig(level=logging.DEBUG)
7+
logger = logging.getLogger(__name__)
8+
9+
def test_sea_session():
    """
    Test opening and closing a SEA session using the connector.

    Connects to a Databricks SQL endpoint using the SEA backend,
    opens a session, and then closes it.

    Required environment variables:
    - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname
    - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint
    - DATABRICKS_TOKEN: Personal access token for authentication

    Optional environment variables:
    - DATABRICKS_CATALOG: initial catalog for the session

    Exits the process with status 1 if configuration is missing or the
    session cannot be opened/closed.
    """
    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
    access_token = os.environ.get("DATABRICKS_TOKEN")
    catalog = os.environ.get("DATABRICKS_CATALOG")

    if not all([server_hostname, http_path, access_token]):
        logger.error("Missing required environment variables.")
        logger.error(
            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
        )
        sys.exit(1)

    # Lazy %-style args: the message is only built if the level is enabled.
    logger.info("Connecting to %s", server_hostname)
    logger.info("HTTP Path: %s", http_path)
    if catalog:
        logger.info("Using catalog: %s", catalog)

    try:
        logger.info("Creating connection with SEA backend...")
        connection = Connection(
            server_hostname=server_hostname,
            http_path=http_path,
            access_token=access_token,
            catalog=catalog,
            schema="default",
            use_sea=True,
            user_agent_entry="SEA-Test-Client",  # add custom user agent
        )

        logger.info(
            "Successfully opened SEA session with ID: %s",
            connection.get_session_id_hex(),
        )
        logger.info("backend type: %s", type(connection.session.backend))

        # Close the connection
        logger.info("Closing the SEA session...")
        connection.close()
        logger.info("Successfully closed SEA session")

    except Exception:
        # logger.exception logs at ERROR level and appends the full
        # traceback automatically — no manual traceback import/formatting.
        logger.exception("Error testing SEA session")
        sys.exit(1)

    logger.info("SEA session test completed successfully")
64+
if __name__ == "__main__":
65+
test_sea_session()
Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
import logging
2+
import uuid
3+
from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from databricks.sql.client import Cursor
7+
8+
from databricks.sql.backend.databricks_client import DatabricksClient
9+
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
10+
from databricks.sql.exc import Error, NotSupportedError
11+
from databricks.sql.backend.utils.http_client import CustomHttpClient
12+
from databricks.sql.thrift_api.TCLIService import ttypes
13+
from databricks.sql.types import SSLOptions
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class SeaDatabricksClient(DatabricksClient):
    """
    Statement Execution API (SEA) implementation of the DatabricksClient interface.

    This implementation provides session management functionality for SEA,
    while other operations raise NotSupportedError.
    """

    # SEA API paths
    BASE_PATH = "/api/2.0/sql/"
    SESSION_PATH = BASE_PATH + "sessions"
    SESSION_PATH_WITH_ID = SESSION_PATH + "/{}"
    STATEMENT_PATH = BASE_PATH + "statements"
    STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
    CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"

    def __init__(
        self,
        server_hostname: str,
        port: int,
        http_path: str,
        http_headers: List[Tuple[str, str]],
        auth_provider,
        ssl_options: SSLOptions,
        staging_allowed_local_path: Union[None, str, List[str]] = None,
        **kwargs,
    ):
        """
        Initialize the SEA backend client.

        Args:
            server_hostname: Hostname of the Databricks server
            port: Port number for the connection
            http_path: HTTP path for the connection
            http_headers: List of HTTP headers to include in requests
            auth_provider: Authentication provider
            ssl_options: SSL configuration options
            staging_allowed_local_path: Allowed local paths for staging operations
            **kwargs: Additional keyword arguments

        Raises:
            ValueError: If a warehouse ID cannot be extracted from http_path
        """
        logger.debug(
            "SeaDatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)",
            server_hostname,
            port,
            http_path,
        )

        self._staging_allowed_local_path = staging_allowed_local_path
        self._ssl_options = ssl_options
        self._max_download_threads = kwargs.get("max_download_threads", 10)

        # Extract warehouse ID from http_path
        self.warehouse_id = self._extract_warehouse_id(http_path)

        # Initialize HTTP client
        self.http_client = CustomHttpClient(
            server_hostname=server_hostname,
            port=port,
            http_path=http_path,
            http_headers=http_headers,
            auth_provider=auth_provider,
            ssl_options=ssl_options,
            **kwargs,
        )

    def _extract_warehouse_id(self, http_path: str) -> str:
        """
        Extract the warehouse ID from the HTTP path.

        The warehouse ID is expected to be the last segment of the path when the
        second-to-last segment is either 'warehouses' or 'endpoints'.
        This matches the JDBC implementation which supports both formats.

        Args:
            http_path: The HTTP path from which to extract the warehouse ID

        Returns:
            The extracted warehouse ID

        Raises:
            ValueError: If the warehouse ID cannot be extracted from the path
        """
        path_parts = http_path.strip("/").split("/")

        # A valid path ends in .../warehouses/<id> or .../endpoints/<id>.
        if len(path_parts) >= 3 and path_parts[-2] in ("warehouses", "endpoints"):
            warehouse_id = path_parts[-1]
            logger.debug(
                "Extracted warehouse ID: %s from path: %s", warehouse_id, http_path
            )
            return warehouse_id

        error_message = (
            f"Could not extract warehouse ID from http_path: {http_path}. "
            f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
            f"/path/to/endpoints/{{warehouse_id}}"
        )
        logger.error(error_message)
        raise ValueError(error_message)

    @property
    def staging_allowed_local_path(self) -> Union[None, str, List[str]]:
        """Get the allowed local paths for staging operations."""
        return self._staging_allowed_local_path

    @property
    def ssl_options(self) -> SSLOptions:
        """Get the SSL options for this client."""
        return self._ssl_options

    @property
    def max_download_threads(self) -> int:
        """Get the maximum number of download threads for cloud fetch operations."""
        return self._max_download_threads

    def open_session(
        self,
        session_configuration: Optional[Dict[str, str]],
        catalog: Optional[str],
        schema: Optional[str],
    ) -> SessionId:
        """
        Opens a new session with the Databricks SQL service using SEA.

        Args:
            session_configuration: Optional dictionary of configuration parameters for the session
            catalog: Optional catalog name to use as the initial catalog for the session
            schema: Optional schema name to use as the initial schema for the session

        Returns:
            SessionId: A session identifier object that can be used for subsequent operations

        Raises:
            Error: If the session configuration is invalid
            OperationalError: If there's an error establishing the session
        """
        logger.debug(
            "SeaDatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)",
            session_configuration,
            catalog,
            schema,
        )

        # Only include optional fields the caller actually supplied.
        request_data: Dict[str, Any] = {"warehouse_id": self.warehouse_id}
        if session_configuration:
            request_data["session_confs"] = session_configuration
        if catalog:
            request_data["catalog"] = catalog
        if schema:
            request_data["schema"] = schema

        response = self.http_client._make_request(
            method="POST", path=self.SESSION_PATH, data=request_data
        )

        session_id = response.get("session_id")
        if not session_id:
            raise Error("Failed to create session: No session ID returned")

        return SessionId.from_sea_session_id(session_id)

    def close_session(self, session_id: SessionId) -> None:
        """
        Closes an existing session with the Databricks SQL service.

        Args:
            session_id: The session identifier returned by open_session()

        Raises:
            ValueError: If the session ID is invalid
            OperationalError: If there's an error closing the session
        """
        logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id)

        # Guard against mixing session IDs across backends (e.g. Thrift).
        if session_id.backend_type != BackendType.SEA:
            raise ValueError("Not a valid SEA session ID")
        sea_session_id = session_id.to_sea_session_id()

        request_data = {"warehouse_id": self.warehouse_id}

        self.http_client._make_request(
            method="DELETE",
            path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
            data=request_data,
        )

    # == Not Implemented Operations ==
    # These methods will be implemented in future iterations

    def execute_command(
        self,
        operation: str,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        lz4_compression: bool,
        cursor: "Cursor",
        use_cloud_fetch: bool,
        parameters: List[ttypes.TSparkParameter],
        async_op: bool,
        enforce_embedded_schema_correctness: bool,
    ):
        """Not implemented yet."""
        raise NotSupportedError(
            "execute_command is not yet implemented for SEA backend"
        )

    def cancel_command(self, command_id: CommandId) -> None:
        """Not implemented yet."""
        raise NotSupportedError("cancel_command is not yet implemented for SEA backend")

    def close_command(self, command_id: CommandId) -> None:
        """Not implemented yet."""
        raise NotSupportedError("close_command is not yet implemented for SEA backend")

    def get_query_state(self, command_id: CommandId) -> CommandState:
        """Not implemented yet."""
        raise NotSupportedError(
            "get_query_state is not yet implemented for SEA backend"
        )

    def get_execution_result(
        self,
        command_id: CommandId,
        cursor: "Cursor",
    ):
        """Not implemented yet."""
        raise NotSupportedError(
            "get_execution_result is not yet implemented for SEA backend"
        )

    # == Metadata Operations ==

    def get_catalogs(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
    ):
        """Not implemented yet."""
        raise NotSupportedError("get_catalogs is not yet implemented for SEA backend")

    def get_schemas(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
    ):
        """Not implemented yet."""
        raise NotSupportedError("get_schemas is not yet implemented for SEA backend")

    def get_tables(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        table_name: Optional[str] = None,
        table_types: Optional[List[str]] = None,
    ):
        """Not implemented yet."""
        raise NotSupportedError("get_tables is not yet implemented for SEA backend")

    def get_columns(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        table_name: Optional[str] = None,
        column_name: Optional[str] = None,
    ):
        """Not implemented yet."""
        raise NotSupportedError("get_columns is not yet implemented for SEA backend")

0 commit comments

Comments
 (0)