Skip to content

Commit 4d40946

Browse files
Niall Egan and susodapop
authored and committed
gRPC client API changes
Author: Niall Egan <niall.egan@databricks.com>
1 parent a5a541d commit 4d40946

File tree

3 files changed

+107
-48
lines changed

3 files changed

+107
-48
lines changed

cmdexec/clients/python/src/databricks/sql/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from .client import Connection
2-
3-
41
class _DBAPITypeObject(object):
52
def __init__(self, *values):
63
self.values = values
@@ -16,6 +13,10 @@ def __eq__(self, other):
1613
DATE = _DBAPITypeObject('date')
1714
ROWID = _DBAPITypeObject()
1815

16+
__version__ = "1.0.0"
17+
USER_AGENT_NAME = "PyDatabricksSqlConnector"
18+
1919

2020
def connect(**kwargs):
    """Create and return a new ``Connection``.

    All keyword arguments are forwarded unchanged to the ``Connection``
    constructor.
    """
    # Resolved at call time rather than at module import time: the client
    # module imports USER_AGENT_NAME and __version__ from this package, so
    # importing it at the top of this module would be circular.
    from . import client

    return client.Connection(**kwargs)

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 92 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,76 @@
1+
import base64
12
import logging
3+
import re
4+
import time
25
from typing import Dict, Tuple, List
3-
from collections import deque
46

57
import grpc
68
import pyarrow
79

8-
from .errors import OperationalError, InterfaceError, DatabaseError, Error
9-
from .api import messages_pb2
10-
from .api.sql_cmd_service_pb2_grpc import SqlCommandServiceStub
11-
12-
import time
10+
from databricks.sql.errors import OperationalError, InterfaceError, DatabaseError, Error
11+
from databricks.sql.api import messages_pb2
12+
from databricks.sql.api.sql_cmd_service_pb2_grpc import SqlCommandServiceStub
13+
from databricks.sql import USER_AGENT_NAME, __version__
1314

1415
logger = logging.getLogger(__name__)
1516

1617
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
1718

1819

1920
class Connection:
20-
def __init__(self, **kwargs):
21-
try:
22-
self.host = kwargs["HOST"]
23-
self.port = kwargs["PORT"]
24-
except KeyError:
25-
raise InterfaceError("Please include arguments HOST and PORT in kwargs for Connection")
21+
def __init__(self, server_hostname, http_path, access_token, metadata=None, **kwargs):
22+
"""Connect to a Databricks SQL endpoint or a Databricks cluster.
23+
24+
:param server_hostname: Databricks instance host name.
25+
:param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
26+
or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
27+
:param access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
28+
:param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request
29+
"""
30+
31+
# Internal arguments in **kwargs:
32+
# _user_agent_entry
33+
# Tag to add to User-Agent header. For use by partners.
34+
# _username, _password
35+
# Username and password Basic authentication (no official support)
36+
# _enable_ssl
37+
# Set to False to connect over HTTP instead of HTTPS (defaults to True)
38+
# _port
39+
# Which port to connect to
40+
# _skip_routing_headers:
41+
# Don't set routing headers if set to True (for use when connecting directly to server)
42+
43+
self.host = server_hostname
44+
self.port = kwargs.get("_port", 443)
45+
46+
if kwargs.get("_username") and kwargs.get("_password"):
47+
auth_credentials = "{username}:{password}".format(
48+
username=kwargs.get("_username"), password=kwargs.get("_password")).encode("UTF-8")
49+
auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode("UTF-8")
50+
authorization_header = "Basic {}".format(auth_credentials_base64)
51+
elif access_token:
52+
authorization_header = "Bearer {}".format(access_token)
53+
else:
54+
raise ValueError("No valid authentication settings.")
55+
56+
if not kwargs.get("_user_agent_entry"):
57+
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)
58+
else:
59+
useragent_header = "{}/{} ({})".format(USER_AGENT_NAME, __version__,
60+
kwargs.get("_user_agent_entry"))
61+
62+
base_headers = [("Authorization", authorization_header),
63+
("X-Databricks-Sqlgateway-CommandService-Mode", "grpc-thrift"),
64+
("User-Agent", useragent_header)]
65+
66+
if not kwargs.get("_skip_routing_headers"):
67+
base_headers.append(self._http_path_to_routing_header(http_path))
68+
69+
self.base_client = CmdExecBaseHttpClient(
70+
self.host,
71+
self.port, (metadata or []) + base_headers,
72+
enable_ssl=kwargs.get("_enable_ssl", True))
2673

27-
self.base_client = CmdExecBaseHttpClient(self.host, self.port, kwargs.get("metadata", []))
2874
open_session_request = messages_pb2.OpenSessionRequest(
2975
configuration={},
3076
client_session_id=None,
@@ -44,6 +90,18 @@ def __enter__(self):
4490
def __exit__(self, exc_type, exc_value, traceback):
4591
self.close()
4692

93+
def _http_path_to_routing_header(self, http_path):
    """Derive the routing header for the compute resource named by ``http_path``.

    :param http_path: Http path either to a DBSQL endpoint
        (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster
        (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
    :return: A ``(header_name, header_value)`` tuple: ``X-Databricks-Cluster-Id`` for a
        cluster path, ``X-Databricks-Sql-Endpoint-Id`` for an endpoint path.
    :raises ValueError: If ``http_path`` is not a string or matches neither format.
    """
    # A missing or non-string path would otherwise surface as an unrelated
    # TypeError from the regex engine; report it as the same invalid-path error.
    if not isinstance(http_path, str):
        raise ValueError("Please provide a valid http_path")

    cluster_re = r'/?sql/protocolv1/o/\d+/(\d+-\d+-[a-zA-Z0-9]+)'
    endpoint_re = r'/?sql/.*/endpoints/([a-f0-9]+)'

    # Cluster paths take precedence over endpoint paths.
    cluster_id_match = re.search(cluster_re, http_path)
    if cluster_id_match:
        return "X-Databricks-Cluster-Id", cluster_id_match.group(1)

    endpoint_id_match = re.search(endpoint_re, http_path)
    if endpoint_id_match:
        return "X-Databricks-Sql-Endpoint-Id", endpoint_id_match.group(1)

    raise ValueError("Please provide a valid http_path")
104+
47105
def cursor(self, buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
48106
if not self.open:
49107
raise Error("Cannot create cursor from closed connection")
@@ -62,10 +120,13 @@ def close(self):
62120

63121

64122
class Cursor:
65-
def __init__(self, connection, result_buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
123+
def __init__(self,
124+
connection,
125+
result_buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES,
126+
arraysize=10000):
66127
self.connection = connection
67128
self.rowcount = -1
68-
self.arraysize = 1
129+
self.arraysize = arraysize
69130
self.buffer_size_bytes = result_buffer_size_bytes
70131
self.active_result_set = None
71132
# Note that Cursor closed => active result set closed, but not vice versa
@@ -148,6 +209,7 @@ def execute(self, operation, query_params=None, metadata=None):
148209
conf_overlay=None,
149210
row_limit=None,
150211
result_options=messages_pb2.CommandResultOptions(
212+
max_rows=self.arraysize,
151213
max_bytes=self.buffer_size_bytes,
152214
include_metadata=True,
153215
))
@@ -208,7 +270,8 @@ def __init__(self,
208270
arrow_ipc_stream=None,
209271
num_valid_rows=None,
210272
schema_message=None,
211-
result_buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
273+
result_buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES,
274+
arraysize=10000):
212275
self.connection = connection
213276
self.command_id = command_id
214277
self.status = status
@@ -217,6 +280,7 @@ def __init__(self,
217280
self.buffer_size_bytes = result_buffer_size_bytes
218281
self._row_index = 0
219282
self.description = None
283+
self.arraysize = arraysize
220284

221285
assert (self.status not in [messages_pb2.PENDING, messages_pb2.RUNNING])
222286

@@ -243,6 +307,7 @@ def _fetch_and_deserialize_results(self):
243307
id=self.command_id,
244308
options=messages_pb2.CommandResultOptions(
245309
max_bytes=self.buffer_size_bytes,
310+
max_rows=self.arraysize,
246311
row_offset=self._row_index,
247312
include_metadata=True,
248313
))
@@ -392,11 +457,19 @@ class CmdExecBaseHttpClient:
392457
A thin wrapper around a gRPC channel that takes care of headers etc.
393458
"""
394459

395-
def __init__(self, host: str, port: int, http_headers: List[Tuple[str, str]], enable_ssl=True):
    """Open a gRPC channel to ``host:port`` and create the command-service stub.

    :param host: Host name to connect to.
    :param port: Port to connect to.
    :param http_headers: (name, value) pairs stored on the client; names are lower-cased.
    :param enable_ssl: When True (the default) a secure channel with default SSL
        credentials is used, otherwise an insecure channel.
    """
    self.host_url = host + ":" + str(port)
    self.http_headers = [(name.lower(), value) for (name, value) in http_headers]

    # Both channel flavours share the same options; -1 disables the limit on
    # received message size.
    channel_options = [('grpc.max_receive_message_length', -1)]
    if enable_ssl:
        self.channel = grpc.secure_channel(
            self.host_url, grpc.ssl_channel_credentials(), options=channel_options)
    else:
        self.channel = grpc.insecure_channel(self.host_url, options=channel_options)

    self.stub = SqlCommandServiceStub(self.channel)
401474

402475
def make_request(self, method, request):

cmdexec/clients/python/tests/tests.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,21 @@ class SimpleTests(unittest.TestCase):
1818
"""
1919

2020
PACKAGE_NAME = "databricks.sql"
21-
22-
def test_missing_params_throws_interface_exception(self):
23-
bad_connection_args = [
24-
{
25-
"HOST": 'host'
26-
},
27-
{
28-
"host": 'host',
29-
"PORT": 1
30-
},
31-
{},
32-
]
33-
34-
for args in bad_connection_args:
35-
with self.assertRaises(InterfaceError) as ie:
36-
databricks.sql.connect(**args)
37-
self.assertIn("HOST and PORT", ie.message)
21+
DUMMY_CONNECTION_ARGS = {
22+
"server_hostname": "foo",
23+
"http_path": None,
24+
"access_token": "tok",
25+
"_skip_routing_headers": True,
26+
}
3827

3928
@patch("%s.client.CmdExecBaseHttpClient" % PACKAGE_NAME)
4029
def test_close_uses_the_correct_session_id(self, mock_client_class):
4130
instance = mock_client_class.return_value
4231
mock_response = MagicMock()
43-
mock_response.id = b'\x22'
4432
instance.make_request.return_value = mock_response
45-
good_connection_args = {"HOST": 1, "PORT": 1}
33+
mock_response.id = b'\x22'
4634

47-
connection = databricks.sql.connect(**good_connection_args)
35+
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
4836
connection.close()
4937

5038
# Check the close session request has an id of x22
@@ -67,8 +55,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class, mock_cl
6755
mock_result_set = Mock()
6856
mock_result_set_class.return_value = mock_result_set
6957

70-
good_connection_args = {"HOST": 1, "PORT": 1}
71-
connection = databricks.sql.connect(**good_connection_args)
58+
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
7259
cursor = connection.cursor()
7360
cursor.execute("SELECT 1;")
7461
connection.close()
@@ -82,8 +69,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class):
8269
mock_response = MagicMock()
8370
mock_response.id = b'\x22'
8471
instance.make_request.return_value = mock_response
85-
good_connection_args = {"HOST": 1, "PORT": 1}
86-
connection = databricks.sql.connect(**good_connection_args)
72+
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
8773
self.assertTrue(connection.open)
8874
connection.close()
8975
self.assertFalse(connection.open)
@@ -203,10 +189,9 @@ def test_context_manager_closes_connection(self, mock_client_class):
203189
mock_response = MagicMock()
204190
mock_response.id = b'\x22'
205191
instance.make_request.return_value = mock_response
206-
good_connection_args = {"HOST": 1, "PORT": 1}
207192
mock_close = Mock()
208193

209-
with databricks.sql.connect(**good_connection_args) as connection:
194+
with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection:
210195
connection.close = mock_close
211196
mock_close.assert_called_once_with()
212197

0 commit comments

Comments
 (0)