|
3 | 3 | import itertools |
4 | 4 | import unittest |
5 | 5 | from unittest.mock import patch, MagicMock, Mock |
| 6 | +from ssl import CERT_NONE, CERT_REQUIRED |
6 | 7 |
|
7 | 8 | import pyarrow |
8 | 9 |
|
@@ -128,6 +129,52 @@ def test_headers_are_set(self, t_http_client_class): |
128 | 129 | ThriftBackend("foo", 123, "bar", [("header", "value")]) |
129 | 130 | t_http_client_class.return_value.setCustomHeaders.assert_called_with({"header": "value"}) |
130 | 131 |
|
| 132 | + @patch("thrift.transport.THttpClient.THttpClient") |
| 133 | + @patch("databricks.sql.thrift_backend.create_default_context") |
| 134 | + def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_client_class): |
| 135 | + mock_cert_key_file = Mock() |
| 136 | + mock_cert_key_password = Mock() |
| 137 | + mock_trusted_ca_file = Mock() |
| 138 | + mock_cert_file = Mock() |
| 139 | + |
| 140 | + ThriftBackend( |
| 141 | + "foo", |
| 142 | + 123, |
| 143 | + "bar", [], |
| 144 | + _tls_client_cert_file=mock_cert_file, |
| 145 | + _tls_client_cert_key_file=mock_cert_key_file, |
| 146 | + _tls_client_cert_key_password=mock_cert_key_password, |
| 147 | + _tls_trusted_ca_file=mock_trusted_ca_file) |
| 148 | + |
| 149 | + mock_create_default_context.assert_called_once_with(cafile=mock_trusted_ca_file) |
| 150 | + mock_ssl_context = mock_create_default_context.return_value |
| 151 | + mock_ssl_context.load_cert_chain.assert_called_once_with( |
| 152 | + certfile=mock_cert_file, keyfile=mock_cert_key_file, password=mock_cert_key_password) |
| 153 | + self.assertTrue(mock_ssl_context.check_hostname) |
| 154 | + self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) |
| 155 | + self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) |
| 156 | + |
| 157 | + @patch("thrift.transport.THttpClient.THttpClient") |
| 158 | + @patch("databricks.sql.thrift_backend.create_default_context") |
| 159 | + def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_client_class): |
| 160 | + ThriftBackend("foo", 123, "bar", [], _tls_no_verify=True) |
| 161 | + |
| 162 | + mock_ssl_context = mock_create_default_context.return_value |
| 163 | + self.assertFalse(mock_ssl_context.check_hostname) |
| 164 | + self.assertEqual(mock_ssl_context.verify_mode, CERT_NONE) |
| 165 | + self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) |
| 166 | + |
| 167 | + @patch("thrift.transport.THttpClient.THttpClient") |
| 168 | + @patch("databricks.sql.thrift_backend.create_default_context") |
| 169 | + def test_tls_verify_hostname_is_respected(self, mock_create_default_context, |
| 170 | + t_http_client_class): |
| 171 | + ThriftBackend("foo", 123, "bar", [], _tls_verify_hostname=False) |
| 172 | + |
| 173 | + mock_ssl_context = mock_create_default_context.return_value |
| 174 | + self.assertFalse(mock_ssl_context.check_hostname) |
| 175 | + self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) |
| 176 | + self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) |
| 177 | + |
131 | 178 | @patch("thrift.transport.THttpClient.THttpClient") |
132 | 179 | def test_port_and_host_are_respected(self, t_http_client_class): |
133 | 180 | ThriftBackend("hostname", 123, "path_value", []) |
|
0 commit comments