11import base64
2+ import datetime
3+ from decimal import Decimal
24import logging
35import re
46import time
79import grpc
810import pyarrow
911
10- from databricks .sql .errors import OperationalError , InterfaceError , DatabaseError , Error
12+ from databricks .sql .errors import OperationalError , InterfaceError , DatabaseError , Error , DataError
1113from databricks .sql .api import messages_pb2
1214from databricks .sql .api .sql_cmd_service_pb2_grpc import SqlCommandServiceStub
1315from databricks .sql import USER_AGENT_NAME , __version__
1416
1517logger = logging .getLogger (__name__ )
1618
1719DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
20+ DEFAULT_ARRAY_SIZE = 100000
21+
22+ _TIMESTAMP_PATTERN = re .compile (r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)' )
23+
24+ TYPES_CONVERTER = {"decimal" : Decimal }
1825
1926
2027class Connection :
@@ -101,11 +108,13 @@ def _http_path_to_routing_header(self, http_path):
101108 else :
102109 raise ValueError ("Please provide a valid http_path" )
103110
104- def cursor (self , buffer_size_bytes = DEFAULT_RESULT_BUFFER_SIZE_BYTES ):
111+ def cursor (self ,
112+ arraysize = DEFAULT_ARRAY_SIZE ,
113+ buffer_size_bytes = DEFAULT_RESULT_BUFFER_SIZE_BYTES ):
105114 if not self .open :
106115 raise Error ("Cannot create cursor from closed connection" )
107116
108- cursor = Cursor (self , buffer_size_bytes )
117+ cursor = Cursor (self , arraysize = arraysize , result_buffer_size_bytes = buffer_size_bytes )
109118 self ._cursors .append (cursor )
110119 return cursor
111120
@@ -122,12 +131,12 @@ class Cursor:
122131 def __init__ (self ,
123132 connection ,
124133 result_buffer_size_bytes = DEFAULT_RESULT_BUFFER_SIZE_BYTES ,
125- arraysize = 10000 ):
134+ arraysize = DEFAULT_ARRAY_SIZE ):
126135 self .connection = connection
127136 self .rowcount = - 1
128- self .arraysize = arraysize
129137 self .buffer_size_bytes = result_buffer_size_bytes
130138 self .active_result_set = None
139+ self .arraysize = arraysize
131140 # Note that Cursor closed => active result set closed, but not vice versa
132141 self .open = True
133142 self .executing_command_id = None
@@ -351,9 +360,19 @@ def _fill_results_buffer(self):
351360 self .description = description
352361
353362 @staticmethod
354- def _convert_arrow_table (table ):
363+ def parse_type (type_ , value ):
364+ converter = TYPES_CONVERTER .get (type_ )
365+ if converter :
366+ return converter (value )
367+ else :
368+ return value
369+
370+ def _convert_arrow_table (self , table ):
355371 n_rows , _ = table .shape
356- list_repr = [[col [i ].as_py () for col in table .itercolumns ()] for i in range (n_rows )]
372+ list_repr = [[
373+ self .parse_type (self .description [col_index ][1 ], col [row_index ].as_py ())
374+ for col_index , col in enumerate (table .itercolumns ())
375+ ] for row_index in range (n_rows )]
357376 return list_repr
358377
359378 def fetchmany_arrow (self , n_rows ):
@@ -397,7 +416,6 @@ def fetchone(self):
397416 Fetch the next row of a query result set, returning a single sequence,
398417 or None when no more data is available.
399418 """
400- self ._row_index += 1
401419 res = self ._convert_arrow_table (self .fetchmany_arrow (1 ))
402420 if len (res ) > 0 :
403421 return res [0 ]
0 commit comments