Skip to content

Commit 86b9055

Browse files
NiallEgan and susodapop
authored and committed
Add initial namespace to open session
This PR adds initial catalog and schema parameters to `connect`.

Unit tests:
* Passthrough tests
* Test canUseMultipleCatalogs
* Test error behaviour when incorrect namespace returned
* Test protocol version checks vs initial namespace
1 parent f8a4799 commit 86b9055

File tree

5 files changed

+135
-8
lines changed

5 files changed

+135
-8
lines changed

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def __init__(self,
2525
access_token: str,
2626
metadata: Optional[List[Tuple[str, str]]] = None,
2727
session_configuration: Dict[str, Any] = None,
28+
catalog: Optional[str] = None,
29+
schema: Optional[str] = None,
2830
**kwargs) -> None:
2931
"""
3032
Connect to a Databricks SQL endpoint or a Databricks cluster.
@@ -36,6 +38,8 @@ def __init__(self,
3638
:param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request
3739
:param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
3840
Execute the SQL command `SET -v` to get a full list of available commands.
41+
:param catalog: An optional initial catalog to use. Requires DBR version 9.0+
42+
:param schema: An optional initial schema to use. Requires DBR version 9.0+
3943
"""
4044

4145
# Internal arguments in **kwargs:
@@ -88,7 +92,8 @@ def __init__(self,
8892
self.thrift_backend = ThriftBackend(self.host, self.port, http_path,
8993
(metadata or []) + base_headers, **kwargs)
9094

91-
self._session_handle = self.thrift_backend.open_session(session_configuration)
95+
self._session_handle = self.thrift_backend.open_session(session_configuration, catalog,
96+
schema)
9297
self.open = True
9398
logger.info("Successfully opened session " + str(self.get_session_id()))
9499
self._cursors = [] # type: List[Cursor]

cmdexec/clients/python/src/databricks/sql/exc.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
### PEP-249 Mandated ###
12
class Error(Exception):
23
pass
34

@@ -36,3 +37,9 @@ class DataError(DatabaseError):
3637

3738
class NotSupportedError(DatabaseError):
3839
pass
40+
41+
42+
### Custom error classes ###
43+
class InvalidServerResponseError(OperationalError):
44+
""" Thrown if the server does not set the initial namespace correctly"""
45+
pass

cmdexec/clients/python/src/databricks/sql/thrift_backend.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,22 @@ def _check_protocol_version(self, t_open_session_resp):
289289
"SPARK_CLI_SERVICE_PROTOCOL_V3, "
290290
"instead got: {}".format(protocol_version))
291291

292+
def _check_initial_namespace(self, catalog, schema, response):
293+
if not (catalog or schema):
294+
return
295+
296+
if response.serverProtocolVersion < \
297+
ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4:
298+
raise InvalidServerResponseError(
299+
"Setting initial namespace not supported by the DBR version, "
300+
"Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.")
301+
302+
if catalog:
303+
if not response.canUseMultipleCatalogs:
304+
raise InvalidServerResponseError(
305+
"Unexpected response from server: Trying to set initial catalog to {}, " +
306+
"but server does not support multiple catalogs.".format(catalog))
307+
292308
def _check_session_configuration(self, session_configuration):
293309
# This client expects timestampAsString to be false, so we do not allow users to modify that
294310
if session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() != "false":
@@ -297,7 +313,7 @@ def _check_session_configuration(self, session_configuration):
297313
TIMESTAMP_AS_STRING_CONFIG,
298314
session_configuration[TIMESTAMP_AS_STRING_CONFIG]))
299315

300-
def open_session(self, session_configuration):
316+
def open_session(self, session_configuration, catalog, schema):
301317
try:
302318
self._transport.open()
303319
session_configuration = {k: str(v) for (k, v) in (session_configuration or {}).items()}
@@ -306,11 +322,19 @@ def open_session(self, session_configuration):
306322
# We also set it in confOverlay in TExecuteStatementReq on a per-query basis,
307323
# but it doesn't hurt to also set for the whole session.
308324
session_configuration[TIMESTAMP_AS_STRING_CONFIG] = "false"
325+
if catalog or schema:
326+
initial_namespace = ttypes.TNamespace(catalogName=catalog, schemaName=schema)
327+
else:
328+
initial_namespace = None
329+
309330
open_session_req = ttypes.TOpenSessionReq(
310331
client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4,
311332
client_protocol=None,
333+
initialNamespace=initial_namespace,
334+
canUseMultipleCatalogs=True,
312335
configuration=session_configuration)
313336
response = self.make_request(self._client.OpenSession, open_session_req)
337+
self._check_initial_namespace(catalog, schema, response)
314338
self._check_protocol_version(response)
315339
return response.sessionHandle
316340
except:

cmdexec/clients/python/tests/test_thrift_backend.py

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass):
117117

118118
with self.assertRaises(OperationalError) as cm:
119119
thrift_backend = self._make_fake_thrift_backend()
120-
thrift_backend.open_session({})
120+
thrift_backend.open_session({}, None, None)
121121

122122
self.assertIn("expected server to use a protocol version", str(cm.exception))
123123

@@ -134,7 +134,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass):
134134
status=self.okay_status, serverProtocolVersion=protocol_version)
135135

136136
thrift_backend = self._make_fake_thrift_backend()
137-
thrift_backend.open_session({})
137+
thrift_backend.open_session({}, None, None)
138138

139139
@patch("thrift.transport.THttpClient.THttpClient")
140140
def test_headers_are_set(self, t_http_client_class):
@@ -673,7 +673,7 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class
673673
tcli_service_instance.OpenSession.return_value = self.open_session_resp
674674

675675
thrift_backend = ThriftBackend("foobar", 443, "path", [])
676-
thrift_backend.open_session({})
676+
thrift_backend.open_session({}, None, None)
677677
self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1)
678678

679679
@patch("databricks.sql.thrift_backend.TCLIService.Client")
@@ -1057,7 +1057,7 @@ def test_configuration_passthrough(self, tcli_client_class):
10571057
}
10581058

10591059
backend = ThriftBackend("foobar", 443, "path", [])
1060-
backend.open_session(mock_config)
1060+
backend.open_session(mock_config, None, None)
10611061

10621062
open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0]
10631063
self.assertEqual(open_session_req.configuration, expected_config)
@@ -1070,10 +1070,92 @@ def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class):
10701070
backend = ThriftBackend("foobar", 443, "path", [])
10711071

10721072
with self.assertRaises(databricks.sql.Error) as cm:
1073-
backend.open_session(mock_config)
1073+
backend.open_session(mock_config, None, None)
10741074

10751075
self.assertIn("timestampAsString cannot be changed", str(cm.exception))
10761076

1077+
def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, schem):
1078+
return ttypes.TOpenSessionResp(
1079+
status=self.okay_status,
1080+
serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4,
1081+
canUseMultipleCatalogs=can_use_multiple_cats,
1082+
initialNamespace=ttypes.TNamespace(catalogName=cat, schemaName=schem))
1083+
1084+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
1085+
def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class):
1086+
tcli_service_instance = tcli_client_class.return_value
1087+
1088+
backend = ThriftBackend("foobar", 443, "path", [])
1089+
initial_cat_schem_args = [("cat", None), (None, "schem"), ("cat", "schem")]
1090+
1091+
for cat, schem in initial_cat_schem_args:
1092+
with self.subTest(cat=cat, schem=schem):
1093+
tcli_service_instance.OpenSession.return_value = \
1094+
self._construct_open_session_with_namespace(True, cat, schem)
1095+
1096+
backend.open_session({}, cat, schem)
1097+
1098+
open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0]
1099+
self.assertEqual(open_session_req.initialNamespace.catalogName, cat)
1100+
self.assertEqual(open_session_req.initialNamespace.schemaName, schem)
1101+
1102+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
1103+
def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_class):
1104+
tcli_service_instance = tcli_client_class.return_value
1105+
tcli_service_instance.OpenSession.return_value = self.open_session_resp
1106+
1107+
backend = ThriftBackend("foobar", 443, "path", [])
1108+
backend.open_session({}, None, None)
1109+
1110+
open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0]
1111+
self.assertTrue(open_session_req.canUseMultipleCatalogs)
1112+
1113+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
1114+
def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog(self, tcli_client_class):
1115+
tcli_service_instance = tcli_client_class.return_value
1116+
1117+
backend = ThriftBackend("foobar", 443, "path", [])
1118+
# If the initial catalog is set, but server returns canUseMultipleCatalogs=False, we
1119+
# expect failure. If the initial catalog isn't set, then canUseMultipleCatalogs=False
1120+
# is fine
1121+
failing_ns_args = [("cat", None), ("cat", "schem")]
1122+
passing_ns_args = [(None, None), (None, "schem")]
1123+
1124+
for cat, schem in failing_ns_args:
1125+
tcli_service_instance.OpenSession.return_value = \
1126+
self._construct_open_session_with_namespace(False, cat, schem)
1127+
1128+
with self.assertRaises(InvalidServerResponseError) as cm:
1129+
backend.open_session({}, cat, schem)
1130+
1131+
self.assertIn("server does not support multiple catalogs", str(cm.exception),
1132+
"incorrect error thrown for initial namespace {}".format((cat, schem)))
1133+
1134+
for cat, schem in passing_ns_args:
1135+
tcli_service_instance.OpenSession.return_value = \
1136+
self._construct_open_session_with_namespace(False, cat, schem)
1137+
backend.open_session({}, cat, schem)
1138+
1139+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
1140+
def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class):
1141+
tcli_service_instance = tcli_client_class.return_value
1142+
1143+
tcli_service_instance.OpenSession.return_value = \
1144+
ttypes.TOpenSessionResp(
1145+
status=self.okay_status,
1146+
serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3,
1147+
canUseMultipleCatalogs=True,
1148+
initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem")
1149+
)
1150+
1151+
backend = ThriftBackend("foobar", 443, "path", [])
1152+
1153+
with self.assertRaises(InvalidServerResponseError) as cm:
1154+
backend.open_session({}, "cat", "schem")
1155+
1156+
self.assertIn("Setting initial namespace not supported by the DBR version",
1157+
str(cm.exception))
1158+
10771159

10781160
if __name__ == '__main__':
10791161
unittest.main()

cmdexec/clients/python/tests/tests.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,16 @@ def test_configuration_passthrough(self, mock_client_class):
353353
self.assertEqual(mock_client_class.return_value.open_session.call_args[0][0],
354354
mock_session_config)
355355

356-
def test_execute_parameter_passhthrough(self):
356+
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
357+
def test_initial_namespace_passthrough(self, mock_client_class):
358+
mock_cat = Mock()
359+
mock_schem = Mock()
360+
361+
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem)
362+
self.assertEqual(mock_client_class.return_value.open_session.call_args[0][1], mock_cat)
363+
self.assertEqual(mock_client_class.return_value.open_session.call_args[0][2], mock_schem)
364+
365+
def test_execute_parameter_passthrough(self):
357366
mock_thrift_backend = Mock()
358367
cursor = client.Cursor(Mock(), mock_thrift_backend)
359368

0 commit comments

Comments
 (0)