Skip to content

Commit 561c351

Browse files
update unit tests to address ThriftBackend through session instead of through Connection
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent b45871e commit 561c351

File tree

1 file changed

+91
-5
lines changed

1 file changed

+91
-5
lines changed

tests/unit/test_client.py

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,97 @@ class ClientTestSuite(unittest.TestCase):
8181
"access_token": "tok",
8282
}
8383

84-
@patch(
85-
"%s.session.ThriftDatabricksClient" % PACKAGE_NAME,
86-
ThriftDatabricksClientMockFactory.new(),
87-
)
88-
def test_closing_connection_closes_commands(self):
84+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
85+
def test_close_uses_the_correct_session_id(self, mock_client_class):
86+
instance = mock_client_class.return_value
87+
88+
mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
89+
mock_open_session_resp.sessionHandle.sessionId = b"\x22"
90+
instance.open_session.return_value = mock_open_session_resp
91+
92+
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
93+
connection.close()
94+
95+
# Check the close session request has an id of x22
96+
close_session_id = instance.close_session.call_args[0][0].sessionId
97+
self.assertEqual(close_session_id, b"\x22")
98+
99+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
100+
def test_auth_args(self, mock_client_class):
101+
# Test that the following auth args work:
102+
# token = foo,
103+
# token = None, _tls_client_cert_file = something, _use_cert_as_auth = True
104+
connection_args = [
105+
{
106+
"server_hostname": "foo",
107+
"http_path": None,
108+
"access_token": "tok",
109+
},
110+
{
111+
"server_hostname": "foo",
112+
"http_path": None,
113+
"_tls_client_cert_file": "something",
114+
"_use_cert_as_auth": True,
115+
"access_token": None,
116+
},
117+
]
118+
119+
for args in connection_args:
120+
connection = databricks.sql.connect(**args)
121+
host, port, http_path, *_ = mock_client_class.call_args[0]
122+
self.assertEqual(args["server_hostname"], host)
123+
self.assertEqual(args["http_path"], http_path)
124+
connection.close()
125+
126+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
127+
def test_http_header_passthrough(self, mock_client_class):
128+
http_headers = [("foo", "bar")]
129+
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers)
130+
131+
call_args = mock_client_class.call_args[0][3]
132+
self.assertIn(("foo", "bar"), call_args)
133+
134+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
135+
def test_tls_arg_passthrough(self, mock_client_class):
136+
databricks.sql.connect(
137+
**self.DUMMY_CONNECTION_ARGS,
138+
_tls_verify_hostname="hostname",
139+
_tls_trusted_ca_file="trusted ca file",
140+
_tls_client_cert_key_file="trusted client cert",
141+
_tls_client_cert_key_password="key password",
142+
)
143+
144+
kwargs = mock_client_class.call_args[1]
145+
self.assertEqual(kwargs["_tls_verify_hostname"], "hostname")
146+
self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file")
147+
self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert")
148+
self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password")
149+
150+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
151+
def test_useragent_header(self, mock_client_class):
152+
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
153+
154+
http_headers = mock_client_class.call_args[0][3]
155+
user_agent_header = (
156+
"User-Agent",
157+
"{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__),
158+
)
159+
self.assertIn(user_agent_header, http_headers)
160+
161+
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar")
162+
user_agent_header_with_entry = (
163+
"User-Agent",
164+
"{}/{} ({})".format(
165+
databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar"
166+
),
167+
)
168+
http_headers = mock_client_class.call_args[0][3]
169+
self.assertIn(user_agent_header_with_entry, http_headers)
170+
171+
@patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new())
172+
@patch("%s.client.ResultSet" % PACKAGE_NAME)
173+
def test_closing_connection_closes_commands(self, mock_result_set_class):
174+
# Test once with has_been_closed_server side, once without
89175
for closed in (True, False):
90176
with self.subTest(closed=closed):
91177
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)

0 commit comments

Comments
 (0)