|
18 | 18 | import logging |
19 | 19 | import socket |
20 | 20 | import time |
| 21 | +from functools import cached_property |
21 | 22 | from types import TracebackType |
22 | 23 | from typing import ( |
23 | 24 | TYPE_CHECKING, |
|
142 | 143 | class _HiveClient: |
143 | 144 | """Helper class to nicely open and close the transport.""" |
144 | 145 |
|
145 | | - _transport: TTransport |
146 | | - _client: Client |
147 | | - _ugi: Optional[List[str]] |
148 | | - |
149 | 146 | def __init__(self, uri: str, ugi: Optional[str] = None, kerberos_auth: Optional[bool] = HIVE_KERBEROS_AUTH_DEFAULT): |
150 | 147 | self._uri = uri |
151 | 148 | self._kerberos_auth = kerberos_auth |
152 | 149 | self._ugi = ugi.split(":") if ugi else None |
153 | 150 |
|
154 | | - self._init_thrift_client() |
155 | | - |
156 | | - def _init_thrift_client(self) -> None: |
| 151 | + def _init_thrift_transport(self) -> TTransport: |
157 | 152 | url_parts = urlparse(self._uri) |
158 | | - |
159 | 153 | socket = TSocket.TSocket(url_parts.hostname, url_parts.port) |
160 | | - |
161 | 154 | if not self._kerberos_auth: |
162 | | - self._transport = TTransport.TBufferedTransport(socket) |
| 155 | + return TTransport.TBufferedTransport(socket) |
163 | 156 | else: |
164 | | - self._transport = TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service="hive") |
| 157 | + return TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service="hive") |
165 | 158 |
|
| 159 | + @cached_property |
| 160 | + def _client(self) -> Client: |
| 161 | + self._transport = self._init_thrift_transport() |
166 | 162 | protocol = TBinaryProtocol.TBinaryProtocol(self._transport) |
167 | | - |
168 | | - self._client = Client(protocol) |
169 | | - |
170 | | - def __enter__(self) -> Client: |
171 | | - """Ensure transport is open before returning the client.""" |
172 | | - if self._transport is None or not self._transport.isOpen(): |
173 | | - self._init_thrift_client() # Reinitialize transport if closed |
174 | | - |
175 | | - if not self._transport.isOpen(): |
176 | | - self._transport.open() |
177 | | - |
| 163 | + client = Client(protocol) |
178 | 164 | if self._ugi: |
179 | | - self._client.set_ugi(*self._ugi) |
| 165 | + client.set_ugi(*self._ugi) |
| 166 | + return client |
180 | 167 |
|
| 168 | + def __enter__(self) -> Client: |
| 169 | + """Reinitialize transport if was closed.""" |
| 170 | + if self._transport and not self._transport.isOpen(): |
| 171 | + self._transport = self._init_thrift_transport() |
181 | 172 | return self._client |
182 | 173 |
|
183 | 174 | def __exit__( |
184 | 175 | self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] |
185 | 176 | ) -> None: |
186 | 177 | """Close transport if it was opened.""" |
187 | | - if self._transport: |
| 178 | + if self._transport and self._transport.isOpen(): |
188 | 179 | self._transport.close() |
189 | | - self._transport = None # Reset transport so a new one is created next time |
190 | | - self._client = None |
191 | 180 |
|
192 | 181 |
|
193 | 182 | def _construct_hive_storage_descriptor( |
|
0 commit comments