Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions tests/catalog/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=protected-access,redefined-outer-name
import base64
import copy
import logging
import struct
import threading
import uuid
from collections.abc import Generator
from copy import deepcopy
from typing import Optional
from unittest.mock import MagicMock, call, patch

import pytest
import thrift.transport.TSocket
from hive_metastore.ttypes import (
AlreadyExistsException,
FieldSchema,
Expand All @@ -43,6 +50,7 @@
LOCK_CHECK_RETRIES,
HiveCatalog,
_construct_hive_storage_descriptor,
_HiveClient,
)
from pyiceberg.exceptions import (
NamespaceAlreadyExistsError,
Expand Down Expand Up @@ -183,6 +191,61 @@ def hive_database(tmp_path_factory: pytest.TempPathFactory) -> HiveDatabase:
)


class SaslServer(threading.Thread):
    """Background thread that answers every incoming connection with a fixed payload.

    The thread listens on the given Thrift server socket, records the port the
    socket was bound to, and then writes ``response`` to each client that
    connects. Clients are responsible for closing their own connections.

    Args:
        socket: Server socket to listen on and accept connections from.
        response: Bytes written verbatim to every accepted client.
    """

    def __init__(self, socket: thrift.transport.TSocket.TServerSocket, response: bytes) -> None:
        super().__init__()
        # Daemon thread: a pending accept() must not keep the interpreter alive.
        self.daemon = True
        self._socket = socket
        self._response = response
        self._port: Optional[int] = None
        # Set once the bind attempt has finished (successfully or not).
        self._port_bound = threading.Event()
        # Set by close() so the accept loop knows it should stop.
        self._closed = threading.Event()

    def run(self) -> None:
        """Bind the socket, publish the port, and serve clients until closed."""
        try:
            # listen() is inside the guarded section so that a bind failure still
            # releases anyone waiting on ``port`` instead of blocking them forever.
            self._socket.listen()
            address = self._socket.handle.getsockname()
            # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
            # 4-tuples (host, port, ...), i.e. port is always at index 1.
            _host, self._port, *_ = address
        finally:
            self._port_bound.set()

        # Accept connections and respond to each connection with the same message.
        # The responsibility for closing the connection is on the client.
        while not self._closed.is_set():
            try:
                client = self._socket.accept()
                if client:
                    client.write(self._response)
                    client.flush()
            except Exception:
                if self._closed.is_set():
                    # Errors raised by a socket we closed ourselves are expected;
                    # stop instead of spinning and logging forever.
                    break
                logging.exception(
                    "An error occurred while responding to client",
                )

    @property
    def port(self) -> Optional[int]:
        """Port the server socket is bound to, or ``None`` if binding failed.

        Blocks until the server thread has finished its bind attempt.
        """
        self._port_bound.wait()
        return self._port

    def close(self) -> None:
        """Ask the accept loop to stop and close the server socket."""
        # Flag first, so an error provoked by the close is recognised as shutdown.
        self._closed.set()
        self._socket.close()


@pytest.fixture(scope="session")
def kerberized_hive_metastore_fake_url() -> Generator[str, None, None]:
    """Start a fake SASL server for the session and yield its ``thrift://`` URL.

    The server replies to every connection with a 5-byte message: a one-byte
    status followed by a big-endian 32-bit length of zero. It is closed once
    the session is over.
    """
    # Port 0 means pick any available port.
    listening_socket = thrift.transport.TSocket.TServerSocket(port=0)
    # Always return a message with status 5 (COMPLETE).
    complete_message = struct.pack(">BI", 5, 0)

    server = SaslServer(socket=listening_socket, response=complete_message)
    server.start()
    yield f"thrift://localhost:{server.port}"
    server.close()


def test_no_uri_supplied() -> None:
with pytest.raises(KeyError):
HiveCatalog("production")
Expand Down Expand Up @@ -1239,3 +1302,45 @@ def test_create_hive_client_failure() -> None:
with pytest.raises(Exception, match="Connection failed"):
HiveCatalog._create_hive_client(properties)
assert mock_hive_client.call_count == 2


def test_create_hive_client_with_kerberos(
    kerberized_hive_metastore_fake_url: str,
) -> None:
    """A client can be created from catalog properties that enable Kerberos auth."""
    properties = dict(
        uri=kerberized_hive_metastore_fake_url,
        ugi="user",
        kerberos_auth=True,
    )
    hive_client = HiveCatalog._create_hive_client(properties)
    assert hive_client is not None


def test_create_hive_client_with_kerberos_using_context_manager(
    kerberized_hive_metastore_fake_url: str,
) -> None:
    """A Kerberos-enabled client can be entered as a context manager more than once.

    The puresasl GSSAPI step, response and completion calls are patched with
    canned values for the duration of the test.
    """
    client = _HiveClient(
        uri=kerberized_hive_metastore_fake_url,
        kerberos_auth=True,
    )

    step_patch = patch(
        "puresasl.mechanisms.kerberos.authGSSClientStep",
        return_value=None,
    )
    response_patch = patch(
        "puresasl.mechanisms.kerberos.authGSSClientResponse",
        return_value=base64.b64encode(b"Some Response"),
    )
    complete_patch = patch(
        "puresasl.mechanisms.GSSAPIMechanism.complete",
        return_value=True,
    )

    with step_patch, response_patch, complete_patch:
        # Use the context manager twice to see if
        # closing and re-opening work as expected.
        for _attempt in range(2):
            with client as open_client:
                assert open_client._iprot.trans.isOpen()
Loading