Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/python-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ jobs:
python-version: ${{ matrix.python }}
cache: poetry
cache-dependency-path: ./poetry.lock
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libkrb5-dev # for kerberos
- name: Install
run: make install-dependencies
- name: Linters
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/python-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 2
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libkrb5-dev # for kerberos
- name: Install
run: make install
- name: Run integration tests
Expand Down
47 changes: 31 additions & 16 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions pyiceberg/catalog/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import logging
import socket
import time
from functools import cached_property
from types import TracebackType
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -160,7 +159,6 @@ def _init_thrift_transport(self) -> TTransport:
else:
return TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service="hive")

@cached_property
def _client(self) -> Client:
protocol = TBinaryProtocol.TBinaryProtocol(self._transport)
client = Client(protocol)
Expand All @@ -173,11 +171,11 @@ def __enter__(self) -> Client:
if not self._transport.isOpen():
try:
self._transport.open()
except TTransport.TTransportException:
except (TypeError, TTransport.TTransportException):
# reinitialize _transport
self._transport = self._init_thrift_transport()
self._transport.open()
return self._client
return self._client() # recreate the client

def __exit__(
self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType]
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ cachetools = "^5.5.0"
pyiceberg-core = { version = "^0.4.0", optional = true }
polars = { version = "^1.21.0", optional = true }
thrift-sasl = { version = ">=0.4.3", optional = true }
kerberos = {version = "^1.3.1", optional = true}

[tool.poetry.group.dev.dependencies]
pytest = "7.4.4"
Expand Down Expand Up @@ -295,7 +296,7 @@ daft = ["getdaft"]
polars = ["polars"]
snappy = ["python-snappy"]
hive = ["thrift"]
hive-kerberos = ["thrift", "thrift_sasl"]
hive-kerberos = ["thrift", "thrift_sasl", "kerberos"]
s3fs = ["s3fs"]
glue = ["boto3", "mypy-boto3-glue"]
adlfs = ["adlfs"]
Expand Down
103 changes: 103 additions & 0 deletions tests/catalog/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=protected-access,redefined-outer-name
import base64
import copy
import struct
import threading
import uuid
from collections.abc import Generator
from copy import deepcopy
from typing import Optional
from unittest.mock import MagicMock, call, patch

import pytest
import thrift.transport.TSocket
from hive_metastore.ttypes import (
AlreadyExistsException,
FieldSchema,
Expand All @@ -38,11 +44,13 @@

from pyiceberg.catalog import PropertiesUpdateSummary
from pyiceberg.catalog.hive import (
HIVE_KERBEROS_AUTH,
LOCK_CHECK_MAX_WAIT_TIME,
LOCK_CHECK_MIN_WAIT_TIME,
LOCK_CHECK_RETRIES,
HiveCatalog,
_construct_hive_storage_descriptor,
_HiveClient,
)
from pyiceberg.exceptions import (
NamespaceAlreadyExistsError,
Expand Down Expand Up @@ -183,6 +191,59 @@ def hive_database(tmp_path_factory: pytest.TempPathFactory) -> HiveDatabase:
)


class SaslServer(threading.Thread):
def __init__(self, socket: thrift.transport.TSocket.TServerSocket, response: bytes) -> None:
super().__init__()
self.daemon = True
self._socket = socket
self._response = response
self._port = None
self._port_bound = threading.Event()

def run(self) -> None:
self._socket.listen()

try:
address = self._socket.handle.getsockname()
# AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
# 4-tuples (host, port, ...), i.e. port is always at index 1.
_host, self._port, *_ = address
finally:
self._port_bound.set()

# Accept connections and respond to each connection with the same message.
# The responsibility for closing the connection is on the client
while True:
try:
client = self._socket.accept()
if client:
client.write(self._response)
client.flush()
except Exception:
pass

@property
def port(self) -> Optional[int]:
self._port_bound.wait()
return self._port

def close(self) -> None:
self._socket.close()


@pytest.fixture(scope="session")
def kerberized_hive_metastore_fake_url() -> Generator[str, None, None]:
server = SaslServer(
# Port 0 means pick any available port.
socket=thrift.transport.TSocket.TServerSocket(port=0),
# Always return a message with status 5 (COMPLETE).
response=struct.pack(">BI", 5, 0),
)
server.start()
yield f"thrift://localhost:{server.port}"
server.close()


def test_no_uri_supplied() -> None:
with pytest.raises(KeyError):
HiveCatalog("production")
Expand Down Expand Up @@ -1239,3 +1300,45 @@ def test_create_hive_client_failure() -> None:
with pytest.raises(Exception, match="Connection failed"):
HiveCatalog._create_hive_client(properties)
assert mock_hive_client.call_count == 2


def test_create_hive_client_with_kerberos(
kerberized_hive_metastore_fake_url: str,
) -> None:
properties = {
"uri": kerberized_hive_metastore_fake_url,
"ugi": "user",
HIVE_KERBEROS_AUTH: "true",
}
client = HiveCatalog._create_hive_client(properties)
assert client is not None


def test_create_hive_client_with_kerberos_using_context_manager(
kerberized_hive_metastore_fake_url: str,
) -> None:
client = _HiveClient(
uri=kerberized_hive_metastore_fake_url,
kerberos_auth=True,
)
with (
patch(
"puresasl.mechanisms.kerberos.authGSSClientStep",
return_value=None,
),
patch(
"puresasl.mechanisms.kerberos.authGSSClientResponse",
return_value=base64.b64encode(b"Some Response"),
),
patch(
"puresasl.mechanisms.GSSAPIMechanism.complete",
return_value=True,
),
):
with client as open_client:
assert open_client._iprot.trans.isOpen()

# Use the context manager a second time to see if
# closing and re-opening work as expected.
with client as open_client:
assert open_client._iprot.trans.isOpen()