|
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | 17 | # pylint: disable=protected-access,redefined-outer-name |
| 18 | +import base64 |
18 | 19 | import copy |
| 20 | +import logging |
| 21 | +import struct |
| 22 | +import threading |
19 | 23 | import uuid |
| 24 | +from collections.abc import Generator |
20 | 25 | from copy import deepcopy |
| 26 | +from typing import Optional |
21 | 27 | from unittest.mock import MagicMock, call, patch |
22 | 28 |
|
23 | 29 | import pytest |
| 30 | +import thrift.transport.TSocket |
24 | 31 | from hive_metastore.ttypes import ( |
25 | 32 | AlreadyExistsException, |
26 | 33 | FieldSchema, |
|
43 | 50 | LOCK_CHECK_RETRIES, |
44 | 51 | HiveCatalog, |
45 | 52 | _construct_hive_storage_descriptor, |
| 53 | + _HiveClient, |
46 | 54 | ) |
47 | 55 | from pyiceberg.exceptions import ( |
48 | 56 | NamespaceAlreadyExistsError, |
@@ -183,6 +191,61 @@ def hive_database(tmp_path_factory: pytest.TempPathFactory) -> HiveDatabase: |
183 | 191 | ) |
184 | 192 |
|
185 | 193 |
|
class SaslServer(threading.Thread):
    """Minimal fake SASL endpoint for tests.

    Listens on the given thrift server socket and answers every incoming
    connection with the same fixed ``response`` bytes, simulating the SASL
    negotiation of a kerberized Hive metastore without a real KDC.
    The client is responsible for closing each connection.
    """

    def __init__(self, socket: thrift.transport.TSocket.TServerSocket, response: bytes) -> None:
        super().__init__()
        # Daemon thread so a server that is never close()d cannot block
        # interpreter shutdown.
        self.daemon = True
        self._socket = socket
        self._response = response
        self._port: Optional[int] = None
        # Signalled once the listening port is known (or binding failed).
        self._port_bound = threading.Event()
        # Signalled by close(); lets run() terminate instead of spinning on
        # accept() errors after the listening socket has been shut down.
        self._closed = threading.Event()

    def run(self) -> None:
        self._socket.listen()

        try:
            address = self._socket.handle.getsockname()
            # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
            # 4-tuples (host, port, ...), i.e. port is always at index 1.
            _host, self._port, *_ = address
        finally:
            # Always unblock waiters on `port`, even if getsockname() raised.
            self._port_bound.set()

        # Accept connections and respond to each connection with the same
        # message, until close() is called.
        while not self._closed.is_set():
            try:
                client = self._socket.accept()
                if client:
                    client.write(self._response)
                    client.flush()
            except Exception:
                if self._closed.is_set():
                    # The accept() failure was caused by close(); exit
                    # quietly instead of busy-looping and logging forever.
                    break
                logging.exception(
                    "An error occurred while responding to client",
                )

    @property
    def port(self) -> Optional[int]:
        """Block until the listening port is known, then return it.

        Returns None if the socket failed to bind.
        """
        self._port_bound.wait()
        return self._port

    def close(self) -> None:
        """Shut down the listening socket and stop the accept loop."""
        self._closed.set()
        self._socket.close()
| 234 | + |
| 235 | + |
@pytest.fixture(scope="session")
def kerberized_hive_metastore_fake_url() -> Generator[str, None, None]:
    """Yield a thrift:// URL for a fake metastore that always completes SASL."""
    # SASL status byte 5 (COMPLETE) followed by a zero-length payload.
    sasl_complete = struct.pack(">BI", 5, 0)
    # Port 0 asks the OS to pick any free port.
    listener = thrift.transport.TSocket.TServerSocket(port=0)
    server = SaslServer(socket=listener, response=sasl_complete)
    server.start()
    yield f"thrift://localhost:{server.port}"
    server.close()
| 247 | + |
| 248 | + |
186 | 249 | def test_no_uri_supplied() -> None: |
187 | 250 | with pytest.raises(KeyError): |
188 | 251 | HiveCatalog("production") |
@@ -1239,3 +1302,45 @@ def test_create_hive_client_failure() -> None: |
1239 | 1302 | with pytest.raises(Exception, match="Connection failed"): |
1240 | 1303 | HiveCatalog._create_hive_client(properties) |
1241 | 1304 | assert mock_hive_client.call_count == 2 |
| 1305 | + |
| 1306 | + |
def test_create_hive_client_with_kerberos(
    kerberized_hive_metastore_fake_url: str,
) -> None:
    """Creating a Hive client with kerberos_auth enabled should succeed
    against the fake SASL endpoint."""
    client = HiveCatalog._create_hive_client(
        {
            "uri": kerberized_hive_metastore_fake_url,
            "ugi": "user",
            "kerberos_auth": True,
        }
    )
    assert client is not None
| 1317 | + |
| 1318 | + |
def test_create_hive_client_with_kerberos_using_context_manager(
    kerberized_hive_metastore_fake_url: str,
) -> None:
    """A kerberized _HiveClient should open cleanly via the context manager,
    and re-open cleanly after being closed."""
    client = _HiveClient(
        uri=kerberized_hive_metastore_fake_url,
        kerberos_auth=True,
    )
    # Stub out the GSSAPI handshake so no real KDC/ticket is required.
    with patch("puresasl.mechanisms.kerberos.authGSSClientStep", return_value=None):
        with patch(
            "puresasl.mechanisms.kerberos.authGSSClientResponse",
            return_value=base64.b64encode(b"Some Response"),
        ):
            with patch("puresasl.mechanisms.GSSAPIMechanism.complete", return_value=True):
                with client as open_client:
                    assert open_client._iprot.trans.isOpen()

                # Use the context manager a second time to see if
                # closing and re-opening work as expected.
                with client as open_client:
                    assert open_client._iprot.trans.isOpen()
0 commit comments