Skip to content

Commit 6793e64

Browse files
committed
Add tests for kerberized hive client
1 parent 0d56a3b commit 6793e64

File tree

1 file changed

+105
-0
lines changed

1 file changed

+105
-0
lines changed

tests/catalog/test_hive.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=protected-access,redefined-outer-name
18+
import base64
1819
import copy
20+
import logging
21+
import struct
22+
import threading
1923
import uuid
24+
from collections.abc import Generator
2025
from copy import deepcopy
26+
from typing import Optional
2127
from unittest.mock import MagicMock, call, patch
2228

2329
import pytest
30+
import thrift.transport.TSocket
2431
from hive_metastore.ttypes import (
2532
AlreadyExistsException,
2633
FieldSchema,
@@ -43,6 +50,7 @@
4350
LOCK_CHECK_RETRIES,
4451
HiveCatalog,
4552
_construct_hive_storage_descriptor,
53+
_HiveClient,
4654
)
4755
from pyiceberg.exceptions import (
4856
NamespaceAlreadyExistsError,
@@ -183,6 +191,61 @@ def hive_database(tmp_path_factory: pytest.TempPathFactory) -> HiveDatabase:
183191
)
184192

185193

194+
class SaslServer(threading.Thread):
    """Background thread that answers every accepted connection with a fixed payload.

    The thread is marked as a daemon, so it never keeps the interpreter alive.
    The port the listening socket ended up bound to is exposed through ``port``.
    """

    def __init__(self, socket: thrift.transport.TSocket.TServerSocket, response: bytes) -> None:
        """Store the server socket and the canned response.

        Args:
            socket: Server socket to serve on; ``listen()`` is called on it from ``run()``.
            response: Bytes written to every client that connects.
        """
        super().__init__()
        self.daemon = True
        self._socket = socket
        self._response = response
        self._port: Optional[int] = None
        # Set once the port lookup has finished, whether it succeeded or not.
        self._port_bound = threading.Event()
        # Set by close() so that run() can tell a shutdown from a real failure.
        self._closed = threading.Event()

    def run(self) -> None:
        """Listen on the socket, publish the bound port and serve clients until closed."""
        try:
            # listen() is inside the try block so that a bind/listen failure
            # still releases anyone blocked on the `port` property.
            self._socket.listen()
            address = self._socket.handle.getsockname()
            # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
            # 4-tuples (host, port, ...), i.e. port is always at index 1.
            _host, self._port, *_ = address
        finally:
            self._port_bound.set()

        # Accept connections and respond to each connection with the same message.
        # The responsibility for closing the connection is on the client.
        while not self._closed.is_set():
            try:
                client = self._socket.accept()
                if client:
                    client.write(self._response)
                    client.flush()
            except Exception:
                if self._closed.is_set():
                    # The socket was closed underneath us: stop serving instead
                    # of retrying accept() on a dead socket forever.
                    return
                logging.exception(
                    "An error occurred while responding to client",
                )

    @property
    def port(self) -> Optional[int]:
        """Return the bound port, blocking until ``run()`` has attempted to determine it.

        Returns:
            The port number, or ``None`` if the port could not be determined.
        """
        self._port_bound.wait()
        return self._port

    def close(self) -> None:
        """Signal the serving loop to stop and close the listening socket."""
        # Flag the shutdown first so an accept() error caused by the close
        # below is recognised as intentional.
        self._closed.set()
        self._socket.close()
235+
236+
@pytest.fixture(scope="session")
def kerberized_hive_metastore_fake_url() -> Generator[str, None, None]:
    """Run a ``SaslServer`` for the test session and yield a thrift URL pointing at it.

    The server is started before the URL is yielded and closed once the
    fixture is torn down.

    Yields:
        A ``thrift://localhost:<port>`` URL using the port the server reports.
    """
    # Port 0 means pick any available port.
    listening_socket = thrift.transport.TSocket.TServerSocket(port=0)
    # Always return a message with status 5 (COMPLETE).
    complete_message = struct.pack(">BI", 5, 0)

    fake_metastore = SaslServer(socket=listening_socket, response=complete_message)
    fake_metastore.start()

    # Reading `port` blocks until the server thread has published it.
    bound_port = fake_metastore.port
    yield f"thrift://localhost:{bound_port}"
    fake_metastore.close()
248+
186249
def test_no_uri_supplied() -> None:
187250
with pytest.raises(KeyError):
188251
HiveCatalog("production")
@@ -1239,3 +1302,45 @@ def test_create_hive_client_failure() -> None:
12391302
with pytest.raises(Exception, match="Connection failed"):
12401303
HiveCatalog._create_hive_client(properties)
12411304
assert mock_hive_client.call_count == 2
1305+
1306+
1307+
def test_create_hive_client_with_kerberos(
    kerberized_hive_metastore_fake_url: str,
) -> None:
    """Check that ``_create_hive_client`` returns a client for Kerberos-enabled properties."""
    hive_client = HiveCatalog._create_hive_client(
        {
            "uri": kerberized_hive_metastore_fake_url,
            "ugi": "user",
            "kerberos_auth": True,
        }
    )
    assert hive_client is not None
1317+
1318+
1319+
def test_create_hive_client_with_kerberos_using_context_manager(
    kerberized_hive_metastore_fake_url: str,
) -> None:
    """Check that a Kerberos-enabled ``_HiveClient`` can be opened, closed and re-opened.

    The puresasl Kerberos calls are patched out, and the client is entered as
    a context manager twice, asserting each time that its transport is open.
    """
    client = _HiveClient(
        uri=kerberized_hive_metastore_fake_url,
        kerberos_auth=True,
    )

    # Build the patchers up front; each one only takes effect when entered below.
    gss_step = patch(
        "puresasl.mechanisms.kerberos.authGSSClientStep",
        return_value=None,
    )
    gss_response = patch(
        "puresasl.mechanisms.kerberos.authGSSClientResponse",
        return_value=base64.b64encode(b"Some Response"),
    )
    gss_complete = patch(
        "puresasl.mechanisms.GSSAPIMechanism.complete",
        return_value=True,
    )

    with gss_step, gss_response, gss_complete:
        # Use the context manager a second time to see if
        # closing and re-opening work as expected.
        for _attempt in range(2):
            with client as open_client:
                assert open_client._iprot.trans.isOpen()

0 commit comments

Comments
 (0)