1+ import base64
12import logging
3+ import re
4+ import time
25from typing import Dict , Tuple , List
3- from collections import deque
46
57import grpc
68import pyarrow
79
8- from .errors import OperationalError , InterfaceError , DatabaseError , Error
9- from .api import messages_pb2
10- from .api .sql_cmd_service_pb2_grpc import SqlCommandServiceStub
11-
12- import time
10+ from databricks .sql .errors import OperationalError , InterfaceError , DatabaseError , Error
11+ from databricks .sql .api import messages_pb2
12+ from databricks .sql .api .sql_cmd_service_pb2_grpc import SqlCommandServiceStub
13+ from databricks .sql import USER_AGENT_NAME , __version__
1314
1415logger = logging .getLogger (__name__ )
1516
1617DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
1718
1819
1920class Connection :
20- def __init__ (self , ** kwargs ):
21- try :
22- self .host = kwargs ["HOST" ]
23- self .port = kwargs ["PORT" ]
24- except KeyError :
25- raise InterfaceError ("Please include arguments HOST and PORT in kwargs for Connection" )
21+ def __init__ (self , server_hostname , http_path , access_token , metadata = None , ** kwargs ):
22+ """Connect to a Databricks SQL endpoint or a Databricks cluster.
23+
24+ :param server_hostname: Databricks instance host name.
25+ :param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
26+ or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
27+ :param access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
28+ :param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request
29+ """
30+
31+ # Internal arguments in **kwargs:
32+ # _user_agent_entry
33+ # Tag to add to User-Agent header. For use by partners.
34+ # _username, _password
35+ # Username and password Basic authentication (no official support)
36+ # _enable_ssl
37+ # Connect over HTTP instead of HTTPS
38+ # _port
39+ # Which port to connect to
40+ # _skip_routing_headers:
41+ # Don't set routing headers if set to True (for use when connecting directly to server)
42+
43+ self .host = server_hostname
44+ self .port = kwargs .get ("_port" , 443 )
45+
46+ if kwargs .get ("_username" ) and kwargs .get ("_password" ):
47+ auth_credentials = "{username}:{password}" .format (
48+ username = kwargs .get ("_username" ), password = kwargs .get ("_password" )).encode ("UTF-8" )
49+ auth_credentials_base64 = base64 .standard_b64encode (auth_credentials ).decode ("UTF-8" )
50+ authorization_header = "Basic {}" .format (auth_credentials_base64 )
51+ elif access_token :
52+ authorization_header = "Bearer {}" .format (access_token )
53+ else :
54+ raise ValueError ("No valid authentication settings." )
55+
56+ if not kwargs .get ("_user_agent_entry" ):
57+ useragent_header = "{}/{}" .format (USER_AGENT_NAME , __version__ )
58+ else :
59+ useragent_header = "{}/{} ({})" .format (USER_AGENT_NAME , __version__ ,
60+ kwargs .get ("_user_agent_entry" ))
61+
62+ base_headers = [("Authorization" , authorization_header ),
63+ ("X-Databricks-Sqlgateway-CommandService-Mode" , "grpc-thrift" ),
64+ ("User-Agent" , useragent_header )]
65+
66+ if not kwargs .get ("_skip_routing_headers" ):
67+ base_headers .append (self ._http_path_to_routing_header (http_path ))
68+
69+ self .base_client = CmdExecBaseHttpClient (
70+ self .host ,
71+ self .port , (metadata or []) + base_headers ,
72+ enable_ssl = kwargs .get ("_enable_ssl" , True ))
2673
27- self .base_client = CmdExecBaseHttpClient (self .host , self .port , kwargs .get ("metadata" , []))
2874 open_session_request = messages_pb2 .OpenSessionRequest (
2975 configuration = {},
3076 client_session_id = None ,
@@ -44,6 +90,18 @@ def __enter__(self):
4490 def __exit__ (self , exc_type , exc_value , traceback ):
4591 self .close ()
4692
93+ def _http_path_to_routing_header (self , http_path ):
94+ cluster_re = r'/?sql/protocolv1/o/\d+/(\d+-\d+-[a-zA-Z0-9]+)'
95+ endpoint_re = r'/?sql/.*/endpoints/([a-f0-9]+)'
96+ cluster_id_match = re .search (cluster_re , http_path )
97+ endpoint_re_match = re .search (endpoint_re , http_path )
98+ if cluster_id_match :
99+ return "X-Databricks-Cluster-Id" , cluster_id_match .groups ()[0 ]
100+ elif endpoint_re_match :
101+ return "X-Databricks-Sql-Endpoint-Id" , endpoint_re_match .groups ()[0 ]
102+ else :
103+ raise ValueError ("Please provide a valid http_path" )
104+
47105 def cursor (self , buffer_size_bytes = DEFAULT_RESULT_BUFFER_SIZE_BYTES ):
48106 if not self .open :
49107 raise Error ("Cannot create cursor from closed connection" )
@@ -62,10 +120,13 @@ def close(self):
62120
63121
64122class Cursor :
65- def __init__ (self , connection , result_buffer_size_bytes = DEFAULT_RESULT_BUFFER_SIZE_BYTES ):
123+ def __init__ (self ,
124+ connection ,
125+ result_buffer_size_bytes = DEFAULT_RESULT_BUFFER_SIZE_BYTES ,
126+ arraysize = 10000 ):
66127 self .connection = connection
67128 self .rowcount = - 1
68- self .arraysize = 1
129+ self .arraysize = arraysize
69130 self .buffer_size_bytes = result_buffer_size_bytes
70131 self .active_result_set = None
71132 # Note that Cursor closed => active result set closed, but not vice versa
@@ -148,6 +209,7 @@ def execute(self, operation, query_params=None, metadata=None):
148209 conf_overlay = None ,
149210 row_limit = None ,
150211 result_options = messages_pb2 .CommandResultOptions (
212+ max_rows = self .arraysize ,
151213 max_bytes = self .buffer_size_bytes ,
152214 include_metadata = True ,
153215 ))
@@ -208,7 +270,8 @@ def __init__(self,
208270 arrow_ipc_stream = None ,
209271 num_valid_rows = None ,
210272 schema_message = None ,
211- result_buffer_size_bytes = DEFAULT_RESULT_BUFFER_SIZE_BYTES ):
273+ result_buffer_size_bytes = DEFAULT_RESULT_BUFFER_SIZE_BYTES ,
274+ arraysize = 10000 ):
212275 self .connection = connection
213276 self .command_id = command_id
214277 self .status = status
@@ -217,6 +280,7 @@ def __init__(self,
217280 self .buffer_size_bytes = result_buffer_size_bytes
218281 self ._row_index = 0
219282 self .description = None
283+ self .arraysize = arraysize
220284
221285 assert (self .status not in [messages_pb2 .PENDING , messages_pb2 .RUNNING ])
222286
@@ -243,6 +307,7 @@ def _fetch_and_deserialize_results(self):
243307 id = self .command_id ,
244308 options = messages_pb2 .CommandResultOptions (
245309 max_bytes = self .buffer_size_bytes ,
310+ max_rows = self .arraysize ,
246311 row_offset = self ._row_index ,
247312 include_metadata = True ,
248313 ))
@@ -392,11 +457,19 @@ class CmdExecBaseHttpClient:
392457 A thin wrapper around a gRPC channel that takes cares of headers etc.
393458 """
394459
395- def __init__ (self , host : str , port : int , http_headers : List [Tuple [str , str ]]):
460+ def __init__ (self , host : str , port : int , http_headers : List [Tuple [str , str ]], enable_ssl = True ):
396461 self .host_url = host + ":" + str (port )
397462 self .http_headers = [(k .lower (), v ) for (k , v ) in http_headers ]
398- self .channel = grpc .insecure_channel (
399- self .host_url , options = [('grpc.max_receive_message_length' , - 1 )])
463+ if enable_ssl :
464+ self .channel = grpc .secure_channel (
465+ self .host_url ,
466+ options = [('grpc.max_receive_message_length' , - 1 )],
467+ credentials = grpc .ssl_channel_credentials ())
468+ else :
469+ self .channel = grpc .insecure_channel (
470+ self .host_url ,
471+ options = [('grpc.max_receive_message_length' , - 1 )],
472+ )
400473 self .stub = SqlCommandServiceStub (self .channel )
401474
402475 def make_request (self , method , request ):
0 commit comments