Skip to content

Commit f8a4799

Browse files
NiallEgan, susodapop
authored and committed
API tweaks + query params
This PR adds various small tweaks to make the V2 API compliant with the V1 API. The one API difference is that in V2 the cursor does not have a `next` method. That is because in Python 3 the proper way to iterate is to create an iterator using `i = iter(cursor)` and then call `next(i)`. * New unit tests * Using Python inspect module to find differences in API
1 parent 10d5dc5 commit f8a4799

File tree

4 files changed

+250
-38
lines changed

4 files changed

+250
-38
lines changed
Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
import datetime
2+
13
from databricks.sql.exc import *
24

5+
# PEP 249 module globals
6+
apilevel = '2.0'
7+
threadsafety = 1 # Threads may share the module, but not connections.
8+
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
9+
310

4-
class _DBAPITypeObject(object):
11+
class DBAPITypeObject(object):
512
def __init__(self, *values):
613
self.values = values
714

@@ -12,18 +19,30 @@ def __repr__(self):
1219
return "DBAPITypeObject({})".format(self.values)
1320

1421

15-
STRING = _DBAPITypeObject('string')
16-
BINARY = _DBAPITypeObject('binary')
17-
NUMBER = _DBAPITypeObject('boolean', 'tinyint', 'smallint', 'int', 'bigint', 'float', 'double',
18-
'decimal')
19-
DATETIME = _DBAPITypeObject('timestamp')
20-
DATE = _DBAPITypeObject('date')
21-
ROWID = _DBAPITypeObject()
22+
STRING = DBAPITypeObject('string')
23+
BINARY = DBAPITypeObject('binary')
24+
NUMBER = DBAPITypeObject('boolean', 'tinyint', 'smallint', 'int', 'bigint', 'float', 'double',
25+
'decimal')
26+
DATETIME = DBAPITypeObject('timestamp')
27+
DATE = DBAPITypeObject('date')
28+
ROWID = DBAPITypeObject()
2229

2330
__version__ = "2.0.0rc2"
2431
USER_AGENT_NAME = "PyDatabricksSqlConnector"
2532

33+
# These two functions are pyhive legacy
34+
Date = datetime.date
35+
Timestamp = datetime.datetime
36+
37+
38+
def DateFromTicks(ticks):
39+
return Date(*time.localtime(ticks)[:3])
40+
41+
42+
def TimestampFromTicks(ticks):
43+
return Timestamp(*time.localtime(ticks)[:6])
44+
2645

27-
def connect(**kwargs):
46+
def connect(server_hostname, http_path, access_token, **kwargs):
2847
from .client import Connection
29-
return Connection(**kwargs)
48+
return Connection(server_hostname, http_path, access_token, **kwargs)

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from databricks.sql import USER_AGENT_NAME, __version__
1111
from databricks.sql import *
1212
from databricks.sql.thrift_backend import ThriftBackend
13-
from databricks.sql.utils import ExecuteResponse
13+
from databricks.sql.utils import ExecuteResponse, ParamEscaper
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -129,6 +129,13 @@ def close(self) -> None:
129129
for cursor in self._cursors:
130130
cursor.close()
131131

132+
def commit(self):
133+
"""No-op because Databricks does not support transactions"""
134+
pass
135+
136+
def rollback(self):
137+
raise NotSupportedError("Transactions are not supported on Databricks")
138+
132139

133140
class Cursor:
134141
def __init__(self,
@@ -144,7 +151,7 @@ def __init__(self,
144151
visible by other cursors or connections.
145152
"""
146153
self.connection = connection
147-
self.rowcount = -1
154+
self.rowcount = -1 # Return -1 as this is not supported
148155
self.buffer_size_bytes = result_buffer_size_bytes
149156
self.active_result_set = None
150157
self.arraysize = arraysize
@@ -153,6 +160,8 @@ def __init__(self,
153160
self.executing_command_id = None
154161
self.thrift_backend = thrift_backend
155162
self.active_op_handle = None
163+
self.escaper = ParamEscaper()
164+
self.lastrowid = None
156165

157166
def __enter__(self):
158167
return self
@@ -178,26 +187,23 @@ def _check_not_closed(self):
178187
if not self.open:
179188
raise Error("Attempting operation on closed cursor")
180189

181-
def execute(self,
182-
operation: str,
183-
query_params: Optional[Dict[str, str]] = None,
184-
metadata: Optional[Dict[str, str]] = None) -> "Cursor":
190+
def execute(self, operation: str, parameters: Optional[Dict[str, str]] = None) -> "Cursor":
185191
"""
186192
Execute a query and wait for execution to complete.
187-
193+
Parameters should be given in extended param format style: %(...)<s|d|f>.
194+
For example:
195+
operation = "SELECT * FROM %(table_name)s"
196+
parameters = {"table_name": "my_table_name"}
197+
Will result in the query "SELECT * FROM 'my_table_name'" being sent to the server
188198
:returns self
189199
"""
190-
if query_params is None:
191-
sql = operation
192-
else:
193-
# TODO(https://databricks.atlassian.net/browse/SC-88829) before public release
194-
logger.error("query param substitution currently un-implemented")
195-
sql = operation
200+
if parameters is not None:
201+
operation = operation % self.escaper.escape_args(parameters)
196202

197203
self._check_not_closed()
198204
self._close_and_clear_active_result_set()
199205
execute_response = self.thrift_backend.execute_command(
200-
operation=sql,
206+
operation=operation,
201207
session_handle=self.connection._session_handle,
202208
max_rows=self.arraysize,
203209
max_bytes=self.buffer_size_bytes,
@@ -206,6 +212,19 @@ def execute(self,
206212
self.buffer_size_bytes, self.arraysize)
207213
return self
208214

215+
def executemany(self, operation, seq_of_parameters):
216+
"""
217+
Prepare a database operation (query or command) and then execute it against all parameter
218+
sequences or mappings found in the sequence ``seq_of_parameters``.
219+
220+
Only the final result set is retained.
221+
222+
:returns self
223+
"""
224+
for parameters in seq_of_parameters:
225+
self.execute(operation, parameters)
226+
return self
227+
209228
def catalogs(self) -> "Cursor":
210229
"""
211230
Get all available catalogs.
@@ -327,7 +346,7 @@ def fetchone(self) -> Tuple:
327346
else:
328347
raise Error("There is no active result set")
329348

330-
def fetchmany(self, n_rows: int) -> List[Tuple]:
349+
def fetchmany(self, size: int) -> List[Tuple]:
331350
"""
332351
Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
333352
list of tuples).
@@ -345,7 +364,7 @@ def fetchmany(self, n_rows: int) -> List[Tuple]:
345364
"""
346365
self._check_not_closed()
347366
if self.active_result_set:
348-
return self.active_result_set.fetchmany(n_rows)
367+
return self.active_result_set.fetchmany(size)
349368
else:
350369
raise Error("There is no active result set")
351370

@@ -356,10 +375,10 @@ def fetchall_arrow(self):
356375
else:
357376
raise Error("There is no active result set")
358377

359-
def fetchmany_arrow(self, n_rows):
378+
def fetchmany_arrow(self, size):
360379
self._check_not_closed()
361380
if self.active_result_set:
362-
return self.active_result_set.fetchmany_arrow(n_rows)
381+
return self.active_result_set.fetchmany_arrow(size)
363382
else:
364383
raise Error("There is no active result set")
365384

@@ -407,6 +426,24 @@ def description(self) -> Optional[List[Tuple]]:
407426
else:
408427
return None
409428

429+
@property
430+
def rownumber(self):
431+
"""This read-only attribute should provide the current 0-based index of the cursor in the
432+
result set.
433+
434+
The index can be seen as index of the cursor in a sequence (the result set). The next fetch
435+
operation will fetch the row indexed by ``rownumber`` in that sequence.
436+
"""
437+
return self.active_result_set.rownumber if self.active_result_set else 0
438+
439+
def setinputsizes(self, sizes):
440+
"""Does nothing by default"""
441+
pass
442+
443+
def setoutputsize(self, size, column=None):
444+
"""Does nothing by default"""
445+
pass
446+
410447

411448
class ResultSet:
412449
def __init__(self,
@@ -468,16 +505,20 @@ def _convert_arrow_table(self, table):
468505
for row_index in range(n_rows)]
469506
return list_repr
470507

471-
def fetchmany_arrow(self, n_rows: int) -> pyarrow.Table:
508+
@property
509+
def rownumber(self):
510+
return self._next_row_index
511+
512+
def fetchmany_arrow(self, size: int) -> pyarrow.Table:
472513
"""
473514
Fetch the next set of rows of a query result, returning a PyArrow table.
474515
475516
An empty sequence is returned when no more rows are available.
476517
"""
477-
if n_rows < 0:
478-
raise ValueError("n_rows argument for fetchmany is %s but must be >= 0", n_rows)
479-
results = self.results.next_n_rows(n_rows)
480-
n_remaining_rows = n_rows - results.num_rows
518+
if size < 0:
519+
raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
520+
results = self.results.next_n_rows(size)
521+
n_remaining_rows = size - results.num_rows
481522
self._next_row_index += results.num_rows
482523

483524
while n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows:
@@ -519,13 +560,13 @@ def fetchall(self) -> List[Tuple]:
519560
"""
520561
return self._convert_arrow_table(self.fetchall_arrow())
521562

522-
def fetchmany(self, n_rows: int) -> List[Tuple]:
563+
def fetchmany(self, size: int) -> List[Tuple]:
523564
"""
524565
Fetch the next set of rows of a query result, returning a list of lists.
525566
526567
An empty sequence is returned when no more rows are available.
527568
"""
528-
return self._convert_arrow_table(self.fetchmany_arrow(n_rows))
569+
return self._convert_arrow_table(self.fetchmany_arrow(size))
529570

530571
def close(self) -> None:
531572
"""

cmdexec/clients/python/src/databricks/sql/utils.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from collections import namedtuple, OrderedDict
1+
from collections import namedtuple, OrderedDict, Iterable
2+
import datetime
23
from enum import Enum
34

45
import pyarrow
@@ -110,3 +111,58 @@ def user_friendly_error_message(self, no_retry_reason, attempt, elapsed):
110111
user_friendly_error_message, elapsed)
111112

112113
return user_friendly_error_message
114+
115+
116+
# Taken from PyHive
117+
class ParamEscaper:
118+
_DATE_FORMAT = "%Y-%m-%d"
119+
_TIME_FORMAT = "%H:%M:%S.%f"
120+
_DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT)
121+
122+
def escape_args(self, parameters):
123+
if isinstance(parameters, dict):
124+
return {k: self.escape_item(v) for k, v in parameters.items()}
125+
elif isinstance(parameters, (list, tuple)):
126+
return tuple(self.escape_item(x) for x in parameters)
127+
else:
128+
raise exc.ProgrammingError("Unsupported param format: {}".format(parameters))
129+
130+
def escape_number(self, item):
131+
return item
132+
133+
def escape_string(self, item):
134+
# Need to decode UTF-8 because of old sqlalchemy.
135+
# Newer SQLAlchemy checks dialect.supports_unicode_binds before encoding Unicode strings
136+
# as byte strings. The old version always encodes Unicode as byte strings, which breaks
137+
# string formatting here.
138+
if isinstance(item, bytes):
139+
item = item.decode('utf-8')
140+
# This is good enough when backslashes are literal, newlines are just followed, and the way
141+
# to escape a single quote is to put two single quotes.
142+
# (i.e. only special character is single quote)
143+
return "'{}'".format(item.replace("'", "''"))
144+
145+
def escape_sequence(self, item):
146+
l = map(str, map(self.escape_item, item))
147+
return '(' + ','.join(l) + ')'
148+
149+
def escape_datetime(self, item, format, cutoff=0):
150+
dt_str = item.strftime(format)
151+
formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str
152+
return "'{}'".format(formatted)
153+
154+
def escape_item(self, item):
155+
if item is None:
156+
return 'NULL'
157+
elif isinstance(item, (int, float)):
158+
return self.escape_number(item)
159+
elif isinstance(item, str):
160+
return self.escape_string(item)
161+
elif isinstance(item, Iterable):
162+
return self.escape_sequence(item)
163+
elif isinstance(item, datetime.datetime):
164+
return self.escape_datetime(item, self._DATETIME_FORMAT)
165+
elif isinstance(item, datetime.date):
166+
return self.escape_datetime(item, self._DATE_FORMAT)
167+
else:
168+
raise exc.ProgrammingError("Unsupported object {}".format(item))

0 commit comments

Comments
 (0)