Skip to content

Commit a8f86ce

Browse files
committed
re-init transport
1 parent 7ddb751 commit a8f86ce

File tree

1 file changed

+15
-26
lines changed

1 file changed

+15
-26
lines changed

pyiceberg/catalog/hive.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import logging
1919
import socket
2020
import time
21+
from functools import cached_property
2122
from types import TracebackType
2223
from typing import (
2324
TYPE_CHECKING,
@@ -141,52 +142,40 @@
141142
class _HiveClient:
142143
"""Helper class to nicely open and close the transport."""
143144

144-
_transport: TTransport
145-
_client: Client
146-
_ugi: Optional[List[str]]
147-
148145
def __init__(self, uri: str, ugi: Optional[str] = None, kerberos_auth: Optional[bool] = HIVE_KERBEROS_AUTH_DEFAULT):
149146
self._uri = uri
150147
self._kerberos_auth = kerberos_auth
151148
self._ugi = ugi.split(":") if ugi else None
152149

153-
self._init_thrift_client()
154-
155-
def _init_thrift_client(self) -> None:
150+
def _init_thrift_transport(self) -> TTransport:
156151
url_parts = urlparse(self._uri)
157-
158152
socket = TSocket.TSocket(url_parts.hostname, url_parts.port)
159-
160153
if not self._kerberos_auth:
161-
self._transport = TTransport.TBufferedTransport(socket)
154+
return TTransport.TBufferedTransport(socket)
162155
else:
163-
self._transport = TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service="hive")
156+
return TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service="hive")
164157

158+
@cached_property
159+
def _client(self) -> Client:
160+
self._transport = self._init_thrift_transport()
165161
protocol = TBinaryProtocol.TBinaryProtocol(self._transport)
166-
167-
self._client = Client(protocol)
168-
169-
def __enter__(self) -> Client:
170-
"""Ensure transport is open before returning the client."""
171-
if self._transport is None or not self._transport.isOpen():
172-
self._init_thrift_client() # Reinitialize transport if closed
173-
174-
if not self._transport.isOpen():
175-
self._transport.open()
176-
162+
client = Client(protocol)
177163
if self._ugi:
178-
self._client.set_ugi(*self._ugi)
164+
client.set_ugi(*self._ugi)
165+
return client
179166

167+
def __enter__(self) -> Client:
168+
"""Reinitialize transport if was closed."""
169+
if self._transport and not self._transport.isOpen():
170+
self._transport = self._init_thrift_transport()
180171
return self._client
181172

182173
def __exit__(
183174
self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType]
184175
) -> None:
185176
"""Close transport if it was opened."""
186-
if self._transport:
177+
if self._transport and self._transport.isOpen():
187178
self._transport.close()
188-
self._transport = None # Reset transport so a new one is created next time
189-
self._client = None
190179

191180

192181
def _construct_hive_storage_descriptor(

0 commit comments

Comments
 (0)