Skip to content

Commit d00fc9d

Browse files
NiallEgan authored and susodapop committed
Add session state properties passthrough
This PR adds the `session_configuration` parameter to allow users to set confs on the session.
1 parent 2dd0f9a commit d00fc9d

File tree

4 files changed

+64
-11
lines changed

4 files changed

+64
-11
lines changed

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(self,
2424
http_path: str,
2525
access_token: str,
2626
metadata: Optional[List[Tuple[str, str]]] = None,
27+
session_configuration: Dict[str, Any] = None,
2728
**kwargs) -> None:
2829
"""
2930
Connect to a Databricks SQL endpoint or a Databricks cluster.
@@ -33,6 +34,8 @@ def __init__(self,
3334
or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
3435
:param access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
3536
:param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request
37+
:param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
38+
Execute the SQL command `SET -v` to get a full list of available configuration parameters.
3639
"""
3740

3841
# Internal arguments in **kwargs:
@@ -82,7 +85,7 @@ def __init__(self,
8285
self.thrift_backend = ThriftBackend(self.host, self.port, http_path,
8386
(metadata or []) + base_headers, **kwargs)
8487

85-
self._session_handle = self.thrift_backend.open_session()
88+
self._session_handle = self.thrift_backend.open_session(session_configuration)
8689
self.open = True
8790
logger.info("Successfully opened session " + str(self.get_session_id()))
8891
self._cursors = [] # type: List[Cursor]

cmdexec/clients/python/src/databricks/sql/thrift_backend.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
DATABRICKS_ERROR_OR_REDIRECT_HEADER = "x-databricks-error-or-redirect-message"
2424
DATABRICKS_REASON_HEADER = "x-databricks-reason-phrase"
2525

26+
TIMESTAMP_AS_STRING_CONFIG = "spark.thriftserver.arrowBasedRowSet.timestampAsString"
27+
2628
# see Connection.__init__ for parameter descriptions.
2729
# - Min/Max avoids unsustainable configs (sane values are far more constrained)
2830
# - 900s attempts-duration lines up w ODBC/JDBC drivers (for cluster startup > 10 mins)
@@ -278,18 +280,27 @@ def _check_protocol_version(self, t_open_session_resp):
278280
"SPARK_CLI_SERVICE_PROTOCOL_V3, "
279281
"instead got: {}".format(protocol_version))
280282

281-
def open_session(self):
283+
def _check_session_configuration(self, session_configuration):
284+
# This client expects timestampAsString to be false, so we do not allow users to modify that
285+
if session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() != "false":
286+
raise Error("Invalid session configuration: {} cannot be changed "
287+
"while using the Databricks SQL connector, it must be false not {}".format(
288+
TIMESTAMP_AS_STRING_CONFIG,
289+
session_configuration[TIMESTAMP_AS_STRING_CONFIG]))
290+
291+
def open_session(self, session_configuration):
282292
try:
283293
self._transport.open()
294+
session_configuration = {k: str(v) for (k, v) in (session_configuration or {}).items()}
295+
self._check_session_configuration(session_configuration)
296+
# We want to receive proper Timestamp arrow types.
297+
# We also set it in confOverlay in TExecuteStatementReq on a per-query basis,
298+
# but it doesn't hurt to also set for the whole session.
299+
session_configuration[TIMESTAMP_AS_STRING_CONFIG] = "false"
284300
open_session_req = ttypes.TOpenSessionReq(
285301
client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4,
286302
client_protocol=None,
287-
configuration={
288-
# We want to receive proper Timestamp arrow types.
289-
# We also set it in confOverlay in TExecuteStatementReq on a per-query basis,
290-
# but it doesn't hurt to also set for the whole session.
291-
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
292-
})
303+
configuration=session_configuration)
293304
response = self.make_request(self._client.OpenSession, open_session_req)
294305
self._check_protocol_version(response)
295306
return response.sessionHandle

cmdexec/clients/python/tests/test_thrift_backend.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass):
117117

118118
with self.assertRaises(OperationalError) as cm:
119119
thrift_backend = self._make_fake_thrift_backend()
120-
thrift_backend.open_session()
120+
thrift_backend.open_session({})
121121

122122
self.assertIn("expected server to use a protocol version", str(cm.exception))
123123

@@ -134,7 +134,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass):
134134
status=self.okay_status, serverProtocolVersion=protocol_version)
135135

136136
thrift_backend = self._make_fake_thrift_backend()
137-
thrift_backend.open_session()
137+
thrift_backend.open_session({})
138138

139139
@patch("thrift.transport.THttpClient.THttpClient")
140140
def test_headers_are_set(self, t_http_client_class):
@@ -664,7 +664,7 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class
664664
tcli_service_instance.OpenSession.return_value = self.open_session_resp
665665

666666
thrift_backend = ThriftBackend("foobar", 443, "path", [])
667-
thrift_backend.open_session()
667+
thrift_backend.open_session({})
668668
self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1)
669669

670670
@patch("databricks.sql.thrift_backend.TCLIService.Client")
@@ -1035,6 +1035,36 @@ def test_retry_args_bounding(self, mock_http_client):
10351035
for (arg, val) in retry_delay_expected_vals.items():
10361036
self.assertEqual(getattr(backend, arg), val)
10371037

1038+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
1039+
def test_configuration_passthrough(self, tcli_client_class):
1040+
tcli_service_instance = tcli_client_class.return_value
1041+
tcli_service_instance.OpenSession.return_value = self.open_session_resp
1042+
mock_config = {"foo": "bar", "baz": True, "42": 42}
1043+
expected_config = {
1044+
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false",
1045+
"foo": "bar",
1046+
"baz": "True",
1047+
"42": "42"
1048+
}
1049+
1050+
backend = ThriftBackend("foobar", 443, "path", [])
1051+
backend.open_session(mock_config)
1052+
1053+
open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0]
1054+
self.assertEqual(open_session_req.configuration, expected_config)
1055+
1056+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
1057+
def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class):
1058+
tcli_service_instance = tcli_client_class.return_value
1059+
tcli_service_instance.OpenSession.return_value = self.open_session_resp
1060+
mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True}
1061+
backend = ThriftBackend("foobar", 443, "path", [])
1062+
1063+
with self.assertRaises(databricks.sql.Error) as cm:
1064+
backend.open_session(mock_config)
1065+
1066+
self.assertIn("timestampAsString cannot be changed", str(cm.exception))
1067+
10381068

10391069
if __name__ == '__main__':
10401070
unittest.main()

cmdexec/clients/python/tests/tests.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,15 @@ def test_version_is_canonical(self):
339339
r'(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$'
340340
self.assertIsNotNone(re.match(canonical_version_re, version))
341341

342+
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
343+
def test_configuration_passthrough(self, mock_client_class):
344+
mock_session_config = Mock()
345+
databricks.sql.connect(
346+
session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS)
347+
348+
self.assertEqual(mock_client_class.return_value.open_session.call_args[0][0],
349+
mock_session_config)
350+
342351

343352
if __name__ == '__main__':
344353
suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])

0 commit comments

Comments
 (0)