diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py index 07cd79d4c7..afb327f69b 100644 --- a/tests/catalog/test_hive.py +++ b/tests/catalog/test_hive.py @@ -15,12 +15,19 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=protected-access,redefined-outer-name +import base64 import copy +import logging +import struct +import threading import uuid +from collections.abc import Generator from copy import deepcopy +from typing import Optional from unittest.mock import MagicMock, call, patch import pytest +import thrift.transport.TSocket from hive_metastore.ttypes import ( AlreadyExistsException, FieldSchema, @@ -43,6 +50,7 @@ LOCK_CHECK_RETRIES, HiveCatalog, _construct_hive_storage_descriptor, + _HiveClient, ) from pyiceberg.exceptions import ( NamespaceAlreadyExistsError, @@ -183,6 +191,61 @@ def hive_database(tmp_path_factory: pytest.TempPathFactory) -> HiveDatabase: ) +class SaslServer(threading.Thread): + def __init__(self, socket: thrift.transport.TSocket.TServerSocket, response: bytes) -> None: + super().__init__() + self.daemon = True + self._socket = socket + self._response = response + self._port = None + self._port_bound = threading.Event() + + def run(self) -> None: + self._socket.listen() + + try: + address = self._socket.handle.getsockname() + # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are + # 4-tuples (host, port, ...), i.e. port is always at index 1. + _host, self._port, *_ = address + finally: + self._port_bound.set() + + # Accept connections and respond to each connection with the same message. + # The responsibility for closing the connection is on the client + while True: + try: + client = self._socket.accept() + if client: + client.write(self._response) + client.flush() + except Exception: + logging.exception( + "An error occurred while responding to client", + ) + + @property + def port(self) -> Optional[int]: + self._port_bound.wait() + return self._port + + def close(self) -> None: + self._socket.close() + + +@pytest.fixture(scope="session") +def kerberized_hive_metastore_fake_url() -> Generator[str, None, None]: + server = SaslServer( + # Port 0 means pick any available port. + socket=thrift.transport.TSocket.TServerSocket(port=0), + # Always return a message with status 5 (COMPLETE). + response=struct.pack(">BI", 5, 0), + ) + server.start() + yield f"thrift://localhost:{server.port}" + server.close() + + def test_no_uri_supplied() -> None: with pytest.raises(KeyError): HiveCatalog("production") @@ -1239,3 +1302,45 @@ def test_create_hive_client_failure() -> None: with pytest.raises(Exception, match="Connection failed"): HiveCatalog._create_hive_client(properties) assert mock_hive_client.call_count == 2 + + +def test_create_hive_client_with_kerberos( + kerberized_hive_metastore_fake_url: str, +) -> None: + properties = { + "uri": kerberized_hive_metastore_fake_url, + "ugi": "user", + "kerberos_auth": True, + } + client = HiveCatalog._create_hive_client(properties) + assert client is not None + + +def test_create_hive_client_with_kerberos_using_context_manager( + kerberized_hive_metastore_fake_url: str, +) -> None: + client = _HiveClient( + uri=kerberized_hive_metastore_fake_url, + kerberos_auth=True, + ) + with ( + patch( + "puresasl.mechanisms.kerberos.authGSSClientStep", + return_value=None, + ), + patch( + "puresasl.mechanisms.kerberos.authGSSClientResponse", + return_value=base64.b64encode(b"Some Response"), + ), + patch( + "puresasl.mechanisms.GSSAPIMechanism.complete", + return_value=True, + ), + ): + with client as open_client: + assert open_client._iprot.trans.isOpen() + + # Use the context manager a second time to see if + # closing and re-opening work as expected. + with client as open_client: + assert open_client._iprot.trans.isOpen()