Skip to content

Commit 76695ea

Browse files
NiallEgansusodapop
authored andcommitted
Client-side Decimal deserialisation
This PR adds client-side deserialisation of Decimals - New unit tests - New integration tests
1 parent 91d2f89 commit 76695ea

File tree

7 files changed

+153
-64
lines changed

7 files changed

+153
-64
lines changed

cmdexec/clients/python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
version="0.0.0",
66
package_dir={"": "src"},
77
packages=setuptools.find_packages(where="src"),
8-
install_requires=["pyarrow", 'thrift>=0.10.0'],
8+
install_requires=["pyarrow", 'thrift>=0.10.0', "pandas"],
99
author="Databricks",
1010
)

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
1818
DEFAULT_ARRAY_SIZE = 100000
1919

20-
TYPES_CONVERTER = {
21-
"decimal": Decimal,
22-
}
23-
2420

2521
class Connection:
2622
def __init__(self,
@@ -440,7 +436,6 @@ def __init__(self,
440436
self.has_more_rows = execute_response.has_more_rows
441437
self.buffer_size_bytes = result_buffer_size_bytes
442438
self._row_index = 0
443-
self.description = None
444439
self.arraysize = arraysize
445440
self.thrift_backend = thrift_backend
446441
self.description = execute_response.description
@@ -465,24 +460,14 @@ def __iter__(self):
465460
def _fill_results_buffer(self):
466461
results, has_more_rows = self.thrift_backend.fetch_results(
467462
self.command_id, self.arraysize, self.buffer_size_bytes, self._row_index,
468-
self._arrow_schema)
463+
self._arrow_schema, self.description)
469464
self.results = results
470465
self.has_more_rows = has_more_rows
471466

472-
@staticmethod
473-
def parse_type(type_, value):
474-
converter = TYPES_CONVERTER.get(type_)
475-
if converter:
476-
return value if value is None else converter(value)
477-
else:
478-
return value
479-
480467
def _convert_arrow_table(self, table):
481468
n_rows, _ = table.shape
482-
list_repr = [[
483-
self.parse_type(self.description[col_index][1], col[row_index].as_py())
484-
for col_index, col in enumerate(table.itercolumns())
485-
] for row_index in range(n_rows)]
469+
list_repr = [[col[row_index].as_py() for col in table.itercolumns()]
470+
for row_index in range(n_rows)]
486471
return list_repr
487472

488473
def fetchmany_arrow(self, n_rows: int) -> pyarrow.Table:

cmdexec/clients/python/src/databricks/sql/thrift_backend.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from decimal import Decimal
12
import logging
23
import time
34
import threading
@@ -190,13 +191,32 @@ def _poll_for_status(self, op_handle):
190191
)
191192
return self.make_request(self._client.GetOperationStatus, req)
192193

193-
def _create_arrow_table(self, t_row_set, schema):
194+
def _create_arrow_table(self, t_row_set, arrow_schema, description):
194195
if t_row_set.columns is not None:
195-
return ThriftBackend._convert_column_based_set_to_arrow_table(t_row_set.columns, schema)
196-
if t_row_set.arrowBatches is not None:
197-
return ThriftBackend._convert_arrow_based_set_to_arrow_table(
198-
t_row_set.arrowBatches, schema)
199-
raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
196+
arrow_table, num_rows = ThriftBackend._convert_column_based_set_to_arrow_table(
197+
t_row_set.columns, arrow_schema)
198+
elif t_row_set.arrowBatches is not None:
199+
arrow_table, num_rows = ThriftBackend._convert_arrow_based_set_to_arrow_table(
200+
t_row_set.arrowBatches, arrow_schema)
201+
else:
202+
raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
203+
return self._convert_decimals_in_arrow_table(arrow_table, description), num_rows
204+
205+
@staticmethod
206+
def _convert_decimals_in_arrow_table(table, description):
207+
for (i, col) in enumerate(table.itercolumns()):
208+
if description[i][1] == 'decimal':
209+
decimal_col = col.to_pandas().apply(lambda v: v if v is None else Decimal(v))
210+
precision, scale = description[i][4], description[i][5]
211+
assert scale is not None
212+
assert precision is not None
213+
# Spark limits decimal to a maximum scale of 38,
214+
# so 128 is guaranteed to be big enough
215+
dtype = pyarrow.decimal128(precision, scale)
216+
col_data = pyarrow.array(decimal_col, type=dtype)
217+
field = table.field(i).with_type(dtype)
218+
table = table.set_column(i, field, col_data)
219+
return table
200220

201221
@staticmethod
202222
def _convert_arrow_based_set_to_arrow_table(arrow_batches, schema):
@@ -304,17 +324,32 @@ def convert_col(t_column_desc):
304324
return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])
305325

306326
@staticmethod
307-
def _hive_schema_to_description(t_table_schema):
308-
def clean_type(typeEntry):
309-
if typeEntry.primitiveEntry:
310-
name = ttypes.TTypeId._VALUES_TO_NAMES[typeEntry.primitiveEntry.type]
311-
# Drop _TYPE suffix
312-
return (name[:-5] if name.endswith("_TYPE") else name).lower()
327+
def _col_to_description(col):
328+
type_entry = col.typeDesc.types[0]
329+
330+
if type_entry.primitiveEntry:
331+
name = ttypes.TTypeId._VALUES_TO_NAMES[type_entry.primitiveEntry.type]
332+
# Drop _TYPE suffix
333+
cleaned_type = (name[:-5] if name.endswith("_TYPE") else name).lower()
334+
else:
335+
raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry")
336+
337+
if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE:
338+
qualifiers = type_entry.primitiveEntry.typeQualifiers.qualifiers
339+
if qualifiers and "precision" in qualifiers and "scale" in qualifiers:
340+
precision, scale = qualifiers["precision"].i32Value, qualifiers["scale"].i32Value
313341
else:
314-
raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry")
342+
raise OperationalError(
343+
"Decimal type did not provide typeQualifier precision, scale in "
344+
"primitiveEntry {}".format(type_entry.primitiveEntry))
345+
else:
346+
precision, scale = None, None
315347

316-
return [(col.columnName, clean_type(col.typeDesc.types[0]), None, None, None, None, None)
317-
for col in t_table_schema.columns]
348+
return col.columnName, cleaned_type, None, None, precision, scale, None
349+
350+
@staticmethod
351+
def _hive_schema_to_description(t_table_schema):
352+
return [ThriftBackend._col_to_description(col) for col in t_table_schema.columns]
318353

319354
def _results_message_to_execute_response(self, resp, operation_state):
320355
if resp.directResults and resp.directResults.resultSetMetadata:
@@ -341,7 +376,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
341376
assert (direct_results.resultSet.results.startRowOffset == 0)
342377
assert (direct_results.resultSetMetadata)
343378
arrow_results, n_rows = self._create_arrow_table(direct_results.resultSet.results,
344-
arrow_schema)
379+
arrow_schema, description)
345380
arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
346381
else:
347382
arrow_queue_opt = None
@@ -477,7 +512,7 @@ def _handle_execute_response(self, resp, cursor):
477512

478513
return self._results_message_to_execute_response(resp, final_operation_state)
479514

480-
def fetch_results(self, op_handle, max_rows, max_bytes, row_offset, arrow_schema):
515+
def fetch_results(self, op_handle, max_rows, max_bytes, row_offset, arrow_schema, description):
481516
assert (op_handle is not None)
482517

483518
req = ttypes.TFetchResultsReq(
@@ -493,7 +528,7 @@ def fetch_results(self, op_handle, max_rows, max_bytes, row_offset, arrow_schema
493528
)
494529

495530
resp = self.make_request(self._client.FetchResults, req)
496-
arrow_results, n_rows = self._create_arrow_table(resp.results, arrow_schema)
531+
arrow_results, n_rows = self._create_arrow_table(resp.results, arrow_schema, description)
497532
arrow_queue = ArrowQueue(arrow_results, n_rows, row_offset - resp.results.startRowOffset)
498533

499534
return arrow_queue, resp.hasMoreRows

cmdexec/clients/python/test-container-with-reqs.dockerfile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ RUN pip install grpcio==1.41.0 \
44
pyarrow==5.0.0 \
55
protobuf==3.18.1 \
66
cryptography==35.0.0 \
7-
thrift==0.13.0
7+
thrift==0.13.0 \
8+
pandas==1.3.4
89

910
ENTRYPOINT ["./docker-entrypoint.sh"]

cmdexec/clients/python/tests/test_fetches.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
5151
def make_dummy_result_set_from_batch_list(batch_list):
5252
batch_index = 0
5353

54-
def fetch_results(op_handle, max_rows, max_bytes, row_offset, arrow_schema):
54+
def fetch_results(op_handle, max_rows, max_bytes, row_offset, arrow_schema, description):
5555
nonlocal batch_index
5656
results = FetchTests.make_arrow_queue(batch_list[batch_index])
5757
batch_index += 1

cmdexec/clients/python/tests/test_thrift_backend.py

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1+
from collections import OrderedDict
2+
from decimal import Decimal
3+
import itertools
14
import unittest
25
from unittest.mock import patch, MagicMock, Mock
3-
import itertools
6+
47
import pyarrow
58

69
from databricks.sql.thrift_api.TCLIService import ttypes
7-
810
from databricks.sql import *
911
from databricks.sql.thrift_backend import ThriftBackend
1012

1113

12-
class TestThriftBackend(unittest.TestCase):
14+
class ThriftBackendTestSuite(unittest.TestCase):
1315
okay_status = ttypes.TStatus(statusCode=ttypes.TStatusCode.SUCCESS_STATUS)
1416

1517
bad_status = ttypes.TStatus(
@@ -55,7 +57,7 @@ def _make_fake_thrift_backend(self):
5557
thrift_backend._hive_schema_to_arrow_schema = Mock()
5658
thrift_backend._hive_schema_to_description = Mock()
5759
thrift_backend._create_arrow_table = MagicMock()
58-
thrift_backend._create_arrow_table.return_value = (Mock(), Mock())
60+
thrift_backend._create_arrow_table.return_value = (MagicMock(), Mock())
5961
return thrift_backend
6062

6163
def test_hive_schema_to_arrow_schema_preserves_column_names(self):
@@ -173,6 +175,28 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self):
173175
("", "struct", None, None, None, None, None),
174176
])
175177

178+
def test_hive_schema_to_description_preserves_scale_and_precision(self):
179+
columns = [
180+
ttypes.TColumnDesc(
181+
columnName="column 1",
182+
typeDesc=ttypes.TTypeDesc(types=[
183+
ttypes.TTypeEntry(
184+
ttypes.TPrimitiveTypeEntry(
185+
type=ttypes.TTypeId.DECIMAL_TYPE,
186+
typeQualifiers=ttypes.TTypeQualifiers(
187+
qualifiers={
188+
"precision": ttypes.TTypeQualifierValue(i32Value=10),
189+
"scale": ttypes.TTypeQualifierValue(i32Value=100),
190+
})))
191+
])),
192+
]
193+
t_table_schema = ttypes.TTableSchema(columns)
194+
195+
description = ThriftBackend._hive_schema_to_description(t_table_schema)
196+
self.assertEqual(description, [
197+
("column 1", "decimal", None, None, 10, 100, None),
198+
])
199+
176200
def test_make_request_checks_status_code(self):
177201
error_codes = [ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS]
178202
thrift_backend = ThriftBackend("foo", 123, "bar", [])
@@ -390,7 +414,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
390414
self.execute_response_types):
391415
with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type):
392416
tcli_service_instance = tcli_service_class.return_value
393-
results_mock = Mock()
417+
results_mock = MagicMock()
394418
results_mock.startRowOffset = 0
395419

396420
execute_resp = resp_type(
@@ -415,7 +439,8 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
415439
thrift_backend = self._make_fake_thrift_backend()
416440

417441
thrift_backend._handle_execute_response(execute_resp, Mock())
418-
_, has_more_rows_resp = thrift_backend.fetch_results(Mock(), 1, 1, 0, Mock())
442+
_, has_more_rows_resp = thrift_backend.fetch_results(Mock(), 1, 1, 0, Mock(),
443+
Mock())
419444

420445
self.assertEqual(has_more_rows, has_more_rows_resp)
421446

@@ -603,25 +628,27 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self):
603628
t_row_set = ttypes.TRowSet()
604629
thrift_backend = ThriftBackend("foobar", 443, "path", [])
605630
with self.assertRaises(OperationalError):
606-
thrift_backend._create_arrow_table(t_row_set, None)
631+
thrift_backend._create_arrow_table(t_row_set, None, Mock())
607632

608633
@patch.object(ThriftBackend, "_convert_arrow_based_set_to_arrow_table")
609634
@patch.object(ThriftBackend, "_convert_column_based_set_to_arrow_table")
610635
def test_create_arrow_table_calls_correct_conversion_method(self, convert_col_mock,
611636
convert_arrow_mock):
612637
thrift_backend = ThriftBackend("foobar", 443, "path", [])
638+
convert_arrow_mock.return_value = (MagicMock(), Mock())
639+
convert_col_mock.return_value = (MagicMock(), Mock())
613640

614641
schema = Mock()
615642
cols = Mock()
616643
arrow_batches = Mock()
617644

618645
t_col_set = ttypes.TRowSet(columns=cols)
619-
thrift_backend._create_arrow_table(t_col_set, schema)
646+
thrift_backend._create_arrow_table(t_col_set, schema, Mock())
620647
convert_arrow_mock.assert_not_called()
621648
convert_col_mock.assert_called_once_with(cols, schema)
622649

623650
t_arrow_set = ttypes.TRowSet(arrowBatches=arrow_batches)
624-
thrift_backend._create_arrow_table(t_arrow_set, schema)
651+
thrift_backend._create_arrow_table(t_arrow_set, schema, Mock())
625652
convert_arrow_mock.assert_called_once_with(arrow_batches, schema)
626653
convert_col_mock.assert_called_once_with(cols, schema)
627654

@@ -818,6 +845,60 @@ def test_make_request_will_read_X_Thriftserver_Error_Message_if_set(self, t_tran
818845

819846
self.assertEqual(mock_method.call_count, 13 + 1)
820847

848+
@staticmethod
849+
def make_table_and_desc(height, n_decimal_cols, width, precision, scale, int_constant,
850+
decimal_constant):
851+
int_col = [int_constant for _ in range(height)]
852+
decimal_col = [decimal_constant for _ in range(height)]
853+
data = OrderedDict({"col{}".format(i): int_col for i in range(width - n_decimal_cols)})
854+
decimals = OrderedDict({"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)})
855+
data.update(decimals)
856+
857+
int_desc = ([("", "int")] * (width - n_decimal_cols))
858+
decimal_desc = ([("", "decimal", None, None, precision, scale, None)] * n_decimal_cols)
859+
description = int_desc + decimal_desc
860+
861+
table = pyarrow.Table.from_pydict(data)
862+
return table, description
863+
864+
def test_arrow_decimal_conversion(self):
865+
# Broader tests in DecimalTestSuite
866+
width = 10
867+
int_constant = 12345
868+
precision, scale = 10, 5
869+
decimal_constant = "1.345"
870+
871+
for n_decimal_cols in [0, 1, 10]:
872+
for height in [0, 1, 10]:
873+
with self.subTest(n_decimal_cols=n_decimal_cols, height=height):
874+
table, description = self.make_table_and_desc(height, n_decimal_cols, width,
875+
precision, scale, int_constant,
876+
decimal_constant)
877+
decimal_converted_table = ThriftBackend._convert_decimals_in_arrow_table(
878+
table, description)
879+
880+
for i in range(width):
881+
if height > 0:
882+
if i < width - n_decimal_cols:
883+
self.assertEqual(
884+
decimal_converted_table.field(i).type, pyarrow.int64())
885+
else:
886+
self.assertEqual(
887+
decimal_converted_table.field(i).type,
888+
pyarrow.decimal128(precision=precision, scale=scale))
889+
890+
int_col = [int_constant for _ in range(height)]
891+
decimal_col = [Decimal(decimal_constant) for _ in range(height)]
892+
expected_result = OrderedDict(
893+
{"col{}".format(i): int_col
894+
for i in range(width - n_decimal_cols)})
895+
decimals = OrderedDict(
896+
{"col_dec{}".format(i): decimal_col
897+
for i in range(n_decimal_cols)})
898+
expected_result.update(decimals)
899+
900+
self.assertEqual(decimal_converted_table.to_pydict(), expected_result)
901+
821902

822903
if __name__ == '__main__':
823904
unittest.main()

cmdexec/clients/python/tests/tests.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from databricks.sql import InterfaceError, DatabaseError, Error
1111

1212
from cmdexec.clients.python.tests.test_fetches import FetchTests
13-
from cmdexec.clients.python.tests.test_thrift_backend import TestThriftBackend
13+
from cmdexec.clients.python.tests.test_thrift_backend import ThriftBackendTestSuite
1414

1515

16-
class ClientTests(unittest.TestCase):
16+
class ClientTestSuite(unittest.TestCase):
1717
"""
1818
Unit tests for isolated client behaviour. See
1919
qa/test/cmdexec/python/suites/simple_connection_test.py for integration tests that
@@ -332,23 +332,10 @@ def test_max_number_of_retries_passthrough(self, mock_client_class):
332332
self.assertEqual(mock_client_class.call_args[1]["_max_number_of_retries"], 53)
333333

334334

335-
class ResultSetTests(unittest.TestCase):
336-
def test_parse_type_converts_decimal(self):
337-
for input in [None, 0, "0", 5, "5", 2.33, "2.33"]:
338-
with self.subTest(input=input):
339-
res = client.ResultSet.parse_type("decimal", input)
340-
if input != None:
341-
self.assertEqual(type(res), Decimal)
342-
self.assertEqual(res, Decimal(input))
343-
else:
344-
self.assertEqual(type(res), type(None))
345-
self.assertEqual(res, None)
346-
347-
348335
if __name__ == '__main__':
349336
suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
350337
loader = unittest.TestLoader()
351-
test_classes = [ClientTests, ResultSetTests, FetchTests, TestThriftBackend]
338+
test_classes = [ClientTestSuite, FetchTests, ThriftBackendTestSuite]
352339
suites_list = []
353340
for test_class in test_classes:
354341
suite = loader.loadTestsFromTestCase(test_class)

0 commit comments

Comments
 (0)