Skip to content

Commit 8b99176

Browse files
juliuszsompolskisusodapop
authored andcommitted
Make Arrow Timestamp physical type be returned by new Thrift client.
Pass the `spark.thriftserver.arrowBasedRowSet.timestampAsString` config to make Thriftserver return actual Timestamp type.
1 parent 5a6352e commit 8b99176

File tree

3 files changed

+13
-43
lines changed

3 files changed

+13
-43
lines changed

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,8 @@
1717
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
1818
DEFAULT_ARRAY_SIZE = 100000
1919

20-
_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
21-
22-
23-
def _parse_timestamp(value):
24-
if type(value) is datetime.datetime:
25-
# The cmd exec server will return a datetime.datetime, so no further parsing is needed
26-
return value
27-
elif value:
28-
match = _TIMESTAMP_PATTERN.match(value)
29-
if match:
30-
if match.group(2):
31-
format = '%Y-%m-%d %H:%M:%S.%f'
32-
# use the pattern to truncate the value
33-
value = match.group()
34-
else:
35-
format = '%Y-%m-%d %H:%M:%S'
36-
value = datetime.datetime.strptime(value, format)
37-
return value
38-
else:
39-
raise Exception('Cannot convert "{}" into a datetime'.format(value))
40-
else:
41-
return None
42-
43-
4420
TYPES_CONVERTER = {
4521
"decimal": Decimal,
46-
"timestamp": _parse_timestamp,
4722
}
4823

4924

cmdexec/clients/python/src/databricks/sql/thrift_backend.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,13 @@ def open_session(self, session_id=None):
155155
open_session_req = ttypes.TOpenSessionReq(
156156
sessionId=handle_identifier,
157157
client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4,
158-
client_protocol=None)
158+
client_protocol=None,
159+
configuration={
160+
# We want to receive proper Timestamp arrow types.
161+
# We set it also in confOverlay in TExecuteStatementReq on a per query basic,
162+
# but it doesn't hurt to also set for the whole session.
163+
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
164+
})
159165
response = self.make_request(self._client.OpenSession, open_session_req)
160166
self._check_protocol_version(response)
161167
return response.sessionHandle
@@ -271,7 +277,7 @@ def map_type(t_type_entry):
271277
ttypes.TTypeId.FLOAT_TYPE: pyarrow.float32(),
272278
ttypes.TTypeId.DOUBLE_TYPE: pyarrow.float64(),
273279
ttypes.TTypeId.STRING_TYPE: pyarrow.string(),
274-
ttypes.TTypeId.TIMESTAMP_TYPE: pyarrow.string(),
280+
ttypes.TTypeId.TIMESTAMP_TYPE: pyarrow.timestamp('us', None),
275281
ttypes.TTypeId.BINARY_TYPE: pyarrow.binary(),
276282
ttypes.TTypeId.ARRAY_TYPE: pyarrow.string(),
277283
ttypes.TTypeId.MAP_TYPE: pyarrow.string(),
@@ -383,7 +389,11 @@ def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor
383389
getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
384390
canReadArrowResult=True,
385391
canDecompressLZ4Result=False,
386-
canDownloadResult=False)
392+
canDownloadResult=False,
393+
confOverlay={
394+
# We want to receive proper Timestamp arrow types.
395+
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
396+
})
387397
resp = self.make_request(self._client.ExecuteStatement, req)
388398
return self._handle_execute_response(resp, cursor)
389399

cmdexec/clients/python/tests/tests.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -344,21 +344,6 @@ def test_parse_type_converts_decimal(self):
344344
self.assertEqual(type(res), type(None))
345345
self.assertEqual(res, None)
346346

347-
def test_parse_type_converts_timestamp(self):
348-
# Supported formats: %Y-%m-%d %H:%M:%S.%f and %Y-%m-%d %H:%M:%S
349-
for input, output in [(None, None),
350-
("2021-11-15 11:18:32.349293",
351-
datetime(2021, 11, 15, 11, 18, 32, 349293)),
352-
("2021-11-15 11:18:32.3492931235",
353-
datetime(2021, 11, 15, 11, 18, 32, 349293)),
354-
("2021-11-15 11:18:32.34", datetime(2021, 11, 15, 11, 18, 32,
355-
340000)),
356-
("2021-11-15 11:18:32", datetime(2021, 11, 15, 11, 18, 32))]:
357-
with self.subTest(input=input, output=output):
358-
res = client.ResultSet.parse_type("timestamp", input)
359-
self.assertEqual(type(res), type(output))
360-
self.assertEqual(res, output)
361-
362347

363348
if __name__ == '__main__':
364349
suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])

0 commit comments

Comments
 (0)