1+ from collections import OrderedDict
2+ from decimal import Decimal
3+ import itertools
14import unittest
25from unittest .mock import patch , MagicMock , Mock
3- import itertools
6+
47import pyarrow
58
69from databricks .sql .thrift_api .TCLIService import ttypes
7-
810from databricks .sql import *
911from databricks .sql .thrift_backend import ThriftBackend
1012
1113
12- class TestThriftBackend (unittest .TestCase ):
14+ class ThriftBackendTestSuite (unittest .TestCase ):
1315 okay_status = ttypes .TStatus (statusCode = ttypes .TStatusCode .SUCCESS_STATUS )
1416
1517 bad_status = ttypes .TStatus (
@@ -55,7 +57,7 @@ def _make_fake_thrift_backend(self):
5557 thrift_backend ._hive_schema_to_arrow_schema = Mock ()
5658 thrift_backend ._hive_schema_to_description = Mock ()
5759 thrift_backend ._create_arrow_table = MagicMock ()
58- thrift_backend ._create_arrow_table .return_value = (Mock (), Mock ())
60+ thrift_backend ._create_arrow_table .return_value = (MagicMock (), Mock ())
5961 return thrift_backend
6062
6163 def test_hive_schema_to_arrow_schema_preserves_column_names (self ):
@@ -173,6 +175,28 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self):
173175 ("" , "struct" , None , None , None , None , None ),
174176 ])
175177
def test_hive_schema_to_description_preserves_scale_and_precision(self):
    # A DECIMAL column's "precision" and "scale" type qualifiers must appear in
    # the 5th and 6th slots of the 7-item description tuple for that column.

    # Assemble the Thrift schema inside-out so each nesting level has a name.
    decimal_qualifiers = ttypes.TTypeQualifiers(
        qualifiers={
            "precision": ttypes.TTypeQualifierValue(i32Value=10),
            "scale": ttypes.TTypeQualifierValue(i32Value=100),
        })
    primitive_entry = ttypes.TPrimitiveTypeEntry(
        type=ttypes.TTypeId.DECIMAL_TYPE,
        typeQualifiers=decimal_qualifiers)
    type_desc = ttypes.TTypeDesc(types=[ttypes.TTypeEntry(primitive_entry)])
    decimal_column = ttypes.TColumnDesc(columnName="column 1", typeDesc=type_desc)
    t_table_schema = ttypes.TTableSchema([decimal_column])

    actual_description = ThriftBackend._hive_schema_to_description(t_table_schema)

    expected_description = [
        ("column 1", "decimal", None, None, 10, 100, None),
    ]
    self.assertEqual(actual_description, expected_description)
199+
176200 def test_make_request_checks_status_code (self ):
177201 error_codes = [ttypes .TStatusCode .ERROR_STATUS , ttypes .TStatusCode .INVALID_HANDLE_STATUS ]
178202 thrift_backend = ThriftBackend ("foo" , 123 , "bar" , [])
@@ -390,7 +414,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
390414 self .execute_response_types ):
391415 with self .subTest (has_more_rows = has_more_rows , resp_type = resp_type ):
392416 tcli_service_instance = tcli_service_class .return_value
393- results_mock = Mock ()
417+ results_mock = MagicMock ()
394418 results_mock .startRowOffset = 0
395419
396420 execute_resp = resp_type (
@@ -415,7 +439,8 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
415439 thrift_backend = self ._make_fake_thrift_backend ()
416440
417441 thrift_backend ._handle_execute_response (execute_resp , Mock ())
418- _ , has_more_rows_resp = thrift_backend .fetch_results (Mock (), 1 , 1 , 0 , Mock ())
442+ _ , has_more_rows_resp = thrift_backend .fetch_results (Mock (), 1 , 1 , 0 , Mock (),
443+ Mock ())
419444
420445 self .assertEqual (has_more_rows , has_more_rows_resp )
421446
@@ -603,25 +628,27 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self):
603628 t_row_set = ttypes .TRowSet ()
604629 thrift_backend = ThriftBackend ("foobar" , 443 , "path" , [])
605630 with self .assertRaises (OperationalError ):
606- thrift_backend ._create_arrow_table (t_row_set , None )
631+ thrift_backend ._create_arrow_table (t_row_set , None , Mock () )
607632
608633 @patch .object (ThriftBackend , "_convert_arrow_based_set_to_arrow_table" )
609634 @patch .object (ThriftBackend , "_convert_column_based_set_to_arrow_table" )
610635 def test_create_arrow_table_calls_correct_conversion_method (self , convert_col_mock ,
611636 convert_arrow_mock ):
612637 thrift_backend = ThriftBackend ("foobar" , 443 , "path" , [])
638+ convert_arrow_mock .return_value = (MagicMock (), Mock ())
639+ convert_col_mock .return_value = (MagicMock (), Mock ())
613640
614641 schema = Mock ()
615642 cols = Mock ()
616643 arrow_batches = Mock ()
617644
618645 t_col_set = ttypes .TRowSet (columns = cols )
619- thrift_backend ._create_arrow_table (t_col_set , schema )
646+ thrift_backend ._create_arrow_table (t_col_set , schema , Mock () )
620647 convert_arrow_mock .assert_not_called ()
621648 convert_col_mock .assert_called_once_with (cols , schema )
622649
623650 t_arrow_set = ttypes .TRowSet (arrowBatches = arrow_batches )
624- thrift_backend ._create_arrow_table (t_arrow_set , schema )
651+ thrift_backend ._create_arrow_table (t_arrow_set , schema , Mock () )
625652 convert_arrow_mock .assert_called_once_with (arrow_batches , schema )
626653 convert_col_mock .assert_called_once_with (cols , schema )
627654
@@ -818,6 +845,60 @@ def test_make_request_will_read_X_Thriftserver_Error_Message_if_set(self, t_tran
818845
819846 self .assertEqual (mock_method .call_count , 13 + 1 )
820847
@staticmethod
def make_table_and_desc(height, n_decimal_cols, width, precision, scale, int_constant,
                        decimal_constant):
    """Build a pyarrow table and a matching description list for decimal tests.

    The table has ``width`` columns of ``height`` rows each. The first
    ``width - n_decimal_cols`` columns are named ``col0``, ``col1``, ... and are
    filled with ``int_constant``; the remaining ``n_decimal_cols`` columns are
    named ``col_dec0``, ``col_dec1``, ... and are filled with ``decimal_constant``.

    Returns:
        A ``(table, description)`` tuple. ``description`` holds one entry per
        column, in column order: ``("", "int")`` for each int column and
        ``("", "decimal", None, None, precision, scale, None)`` for each
        decimal column.
    """
    n_int_cols = width - n_decimal_cols

    # Every column of a given kind shares the same value list.
    int_values = [int_constant] * height
    decimal_values = [decimal_constant] * height

    # Int columns first, decimal columns last; insertion order is column order.
    columns = OrderedDict()
    for idx in range(n_int_cols):
        columns["col{}".format(idx)] = int_values
    for idx in range(n_decimal_cols):
        columns["col_dec{}".format(idx)] = decimal_values

    description = [("", "int") for _ in range(n_int_cols)]
    description += [("", "decimal", None, None, precision, scale, None)
                    for _ in range(n_decimal_cols)]

    return pyarrow.Table.from_pydict(columns), description
863+
def test_arrow_decimal_conversion(self):
    # Broader tests in DecimalTestSuite
    #
    # For several (decimal column count, row count) combinations, convert a
    # table whose trailing columns are described as decimals and check both the
    # resulting column types and the resulting values.
    width = 10
    int_constant = 12345
    precision, scale = 10, 5
    decimal_constant = "1.345"

    # product() walks the pairs in the same order as two nested loops would.
    for n_decimal_cols, height in itertools.product([0, 1, 10], repeat=2):
        with self.subTest(n_decimal_cols=n_decimal_cols, height=height):
            n_int_cols = width - n_decimal_cols
            table, description = self.make_table_and_desc(
                height, n_decimal_cols, width, precision, scale, int_constant,
                decimal_constant)
            converted = ThriftBackend._convert_decimals_in_arrow_table(table, description)

            # Column types are only checked for non-empty tables.
            # NOTE(review): presumably an empty column gives pyarrow nothing to
            # infer a type from - confirm.
            if height > 0:
                for col_idx in range(width):
                    actual_type = converted.field(col_idx).type
                    if col_idx < n_int_cols:
                        self.assertEqual(actual_type, pyarrow.int64())
                    else:
                        self.assertEqual(
                            actual_type,
                            pyarrow.decimal128(precision=precision, scale=scale))

            # Values: ints stay as they were, decimal strings become Decimal objects.
            expected_ints = [int_constant] * height
            expected_decimals = [Decimal(decimal_constant)] * height
            expected_result = OrderedDict()
            for col_idx in range(n_int_cols):
                expected_result["col{}".format(col_idx)] = expected_ints
            for col_idx in range(n_decimal_cols):
                expected_result["col_dec{}".format(col_idx)] = expected_decimals

            self.assertEqual(converted.to_pydict(), expected_result)
901+
821902
822903if __name__ == '__main__' :
823904 unittest .main ()
0 commit comments