Skip to content

Commit 1151736

Browse files
Niall Egan authored and susodapop committed
Testing changes
Author: Niall Egan <niall.egan@databricks.com>
1 parent 6588a32 commit 1151736

File tree

3 files changed

+42
-11
lines changed

3 files changed

+42
-11
lines changed

cmdexec/clients/python/src/databricks/sql/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,21 @@
1+
from databricks.sql.errors import *
2+
3+
14
class _DBAPITypeObject(object):
25
def __init__(self, *values):
36
self.values = values
47

58
def __eq__(self, other):
69
return other in self.values
710

11+
def __repr__(self):
12+
return "DBAPITypeObject(%s)" % self.values
13+
814

915
STRING = _DBAPITypeObject('string')
1016
BINARY = _DBAPITypeObject('binary')
11-
NUMBER = _DBAPITypeObject('boolean', 'byte', 'short', 'integer', 'long', 'double', 'decimal')
17+
NUMBER = _DBAPITypeObject('boolean', 'tinyint', 'smallint', 'int', 'bigint', 'float', 'double',
18+
'decimal')
1219
DATETIME = _DBAPITypeObject('timestamp')
1320
DATE = _DBAPITypeObject('date')
1421
ROWID = _DBAPITypeObject()

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import base64
2+
import datetime
3+
from decimal import Decimal
24
import logging
35
import re
46
import time
@@ -7,14 +9,19 @@
79
import grpc
810
import pyarrow
911

10-
from databricks.sql.errors import OperationalError, InterfaceError, DatabaseError, Error
12+
from databricks.sql.errors import OperationalError, InterfaceError, DatabaseError, Error, DataError
1113
from databricks.sql.api import messages_pb2
1214
from databricks.sql.api.sql_cmd_service_pb2_grpc import SqlCommandServiceStub
1315
from databricks.sql import USER_AGENT_NAME, __version__
1416

1517
logger = logging.getLogger(__name__)
1618

1719
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
20+
DEFAULT_ARRAY_SIZE = 100000
21+
22+
_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
23+
24+
TYPES_CONVERTER = {"decimal": Decimal}
1825

1926

2027
class Connection:
@@ -101,11 +108,13 @@ def _http_path_to_routing_header(self, http_path):
101108
else:
102109
raise ValueError("Please provide a valid http_path")
103110

104-
def cursor(self, buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
111+
def cursor(self,
112+
arraysize=DEFAULT_ARRAY_SIZE,
113+
buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
105114
if not self.open:
106115
raise Error("Cannot create cursor from closed connection")
107116

108-
cursor = Cursor(self, buffer_size_bytes)
117+
cursor = Cursor(self, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes)
109118
self._cursors.append(cursor)
110119
return cursor
111120

@@ -122,12 +131,12 @@ class Cursor:
122131
def __init__(self,
123132
connection,
124133
result_buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES,
125-
arraysize=10000):
134+
arraysize=DEFAULT_ARRAY_SIZE):
126135
self.connection = connection
127136
self.rowcount = -1
128-
self.arraysize = arraysize
129137
self.buffer_size_bytes = result_buffer_size_bytes
130138
self.active_result_set = None
139+
self.arraysize = arraysize
131140
# Note that Cursor closed => active result set closed, but not vice versa
132141
self.open = True
133142
self.executing_command_id = None
@@ -351,9 +360,19 @@ def _fill_results_buffer(self):
351360
self.description = description
352361

353362
@staticmethod
354-
def _convert_arrow_table(table):
363+
def parse_type(type_, value):
364+
converter = TYPES_CONVERTER.get(type_)
365+
if converter:
366+
return converter(value)
367+
else:
368+
return value
369+
370+
def _convert_arrow_table(self, table):
355371
n_rows, _ = table.shape
356-
list_repr = [[col[i].as_py() for col in table.itercolumns()] for i in range(n_rows)]
372+
list_repr = [[
373+
self.parse_type(self.description[col_index][1], col[row_index].as_py())
374+
for col_index, col in enumerate(table.itercolumns())
375+
] for row_index in range(n_rows)]
357376
return list_repr
358377

359378
def fetchmany_arrow(self, n_rows):
@@ -397,7 +416,6 @@ def fetchone(self):
397416
Fetch the next row of a query result set, returning a single sequence,
398417
or None when no more data is available.
399418
"""
400-
self._row_index += 1
401419
res = self._convert_arrow_table(self.fetchmany_arrow(1))
402420
if len(res) > 0:
403421
return res[0]

cmdexec/clients/python/tests/test_fetches.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ def make_arrow_queue(batch):
4040
def make_dummy_result_set_from_initial_results(initial_results):
4141
# If the initial results have been set, then we should never try and fetch more
4242
arrow_ipc_stream = FetchTests.make_arrow_ipc_stream(initial_results)
43-
return client.ResultSet(
43+
rs = client.ResultSet(
4444
connection=None,
4545
command_id=None,
4646
status=None,
@@ -49,6 +49,9 @@ def make_dummy_result_set_from_initial_results(initial_results):
4949
arrow_ipc_stream=arrow_ipc_stream,
5050
num_valid_rows=len(initial_results),
5151
schema_message=MagicMock())
52+
num_cols = len(initial_results[0]) if initial_results else 0
53+
rs.description = [('', 'integer', None, None, None, None, None) * num_cols]
54+
return rs
5255

5356
@staticmethod
5457
def make_dummy_result_set_from_batch_list(batch_list):
@@ -63,7 +66,10 @@ def _fetch_and_deserialize_results(self):
6366
return results, batch_index < len(batch_list), \
6467
[('id', 'integer', None, None, None, None, None)]
6568

66-
return SemiFakeResultSet(None, None, None, False, False)
69+
rs = SemiFakeResultSet(None, None, None, False, False)
70+
num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0
71+
rs.description = [('', 'integer', None, None, None, None, None) * num_cols]
72+
return rs
6773

6874
def test_fetchmany_with_initial_results(self):
6975
# Fetch all in one go

0 commit comments

Comments
 (0)