Skip to content

Commit 0c8929d

Browse files
NiallEgansusodapop
authored andcommitted
Fix SSL arg bug
This PR fixes `__tls_client_cert_key_file` => `_tls_client_cert_key_file` and adds some rudimentary unit tests for this behaviour on PyHive. I also removed the accidentaly commit of the package
1 parent 426f3e0 commit 0c8929d

File tree

4 files changed

+50
-8
lines changed

4 files changed

+50
-8
lines changed

cmdexec/clients/python/src/databricks/sql/thrift_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(self, server_hostname: str, port, http_path: str, http_headers, **k
7070
ssl_context.verify_mode = CERT_REQUIRED
7171

7272
tls_client_cert_file = kwargs.get("_tls_client_cert_file")
73-
tls_client_cert_key_file = kwargs.get("__tls_client_cert_key_file")
73+
tls_client_cert_key_file = kwargs.get("_tls_client_cert_key_file")
7474
tls_client_cert_key_password = kwargs.get("_tls_client_cert_key_password")
7575
if tls_client_cert_file:
7676
ssl_context.load_cert_chain(

cmdexec/clients/python/tests/docker-run-unit.sh

Lines changed: 0 additions & 5 deletions
This file was deleted.

cmdexec/clients/python/tests/test_thrift_backend.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import itertools
44
import unittest
55
from unittest.mock import patch, MagicMock, Mock
6+
from ssl import CERT_NONE, CERT_REQUIRED
67

78
import pyarrow
89

@@ -128,6 +129,52 @@ def test_headers_are_set(self, t_http_client_class):
128129
ThriftBackend("foo", 123, "bar", [("header", "value")])
129130
t_http_client_class.return_value.setCustomHeaders.assert_called_with({"header": "value"})
130131

132+
@patch("thrift.transport.THttpClient.THttpClient")
133+
@patch("databricks.sql.thrift_backend.create_default_context")
134+
def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_client_class):
135+
mock_cert_key_file = Mock()
136+
mock_cert_key_password = Mock()
137+
mock_trusted_ca_file = Mock()
138+
mock_cert_file = Mock()
139+
140+
ThriftBackend(
141+
"foo",
142+
123,
143+
"bar", [],
144+
_tls_client_cert_file=mock_cert_file,
145+
_tls_client_cert_key_file=mock_cert_key_file,
146+
_tls_client_cert_key_password=mock_cert_key_password,
147+
_tls_trusted_ca_file=mock_trusted_ca_file)
148+
149+
mock_create_default_context.assert_called_once_with(cafile=mock_trusted_ca_file)
150+
mock_ssl_context = mock_create_default_context.return_value
151+
mock_ssl_context.load_cert_chain.assert_called_once_with(
152+
certfile=mock_cert_file, keyfile=mock_cert_key_file, password=mock_cert_key_password)
153+
self.assertTrue(mock_ssl_context.check_hostname)
154+
self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED)
155+
self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context)
156+
157+
@patch("thrift.transport.THttpClient.THttpClient")
158+
@patch("databricks.sql.thrift_backend.create_default_context")
159+
def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_client_class):
160+
ThriftBackend("foo", 123, "bar", [], _tls_no_verify=True)
161+
162+
mock_ssl_context = mock_create_default_context.return_value
163+
self.assertFalse(mock_ssl_context.check_hostname)
164+
self.assertEqual(mock_ssl_context.verify_mode, CERT_NONE)
165+
self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context)
166+
167+
@patch("thrift.transport.THttpClient.THttpClient")
168+
@patch("databricks.sql.thrift_backend.create_default_context")
169+
def test_tls_verify_hostname_is_respected(self, mock_create_default_context,
170+
t_http_client_class):
171+
ThriftBackend("foo", 123, "bar", [], _tls_verify_hostname=False)
172+
173+
mock_ssl_context = mock_create_default_context.return_value
174+
self.assertFalse(mock_ssl_context.check_hostname)
175+
self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED)
176+
self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context)
177+
131178
@patch("thrift.transport.THttpClient.THttpClient")
132179
def test_port_and_host_are_respected(self, t_http_client_class):
133180
ThriftBackend("hostname", 123, "path_value", [])

cmdexec/clients/python/tests/tests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import databricks.sql.client as client
1010
from databricks.sql import InterfaceError, DatabaseError, Error
1111

12-
from cmdexec.clients.python.tests.test_fetches import FetchTests
13-
from cmdexec.clients.python.tests.test_thrift_backend import ThriftBackendTestSuite
12+
from test_fetches import FetchTests
13+
from test_thrift_backend import ThriftBackendTestSuite
1414

1515

1616
class ClientTestSuite(unittest.TestCase):

0 commit comments

Comments
 (0)