Skip to content
Open
21 changes: 16 additions & 5 deletions dev/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ WORKDIR ${SPARK_HOME}
ENV SPARK_VERSION=3.5.6
ENV ICEBERG_SPARK_RUNTIME_VERSION=3.5_2.12
ENV ICEBERG_VERSION=1.9.1
ENV PYICEBERG_VERSION=0.9.1

RUN curl --retry 5 -s -C - https://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop3.tgz -o spark-${SPARK_VERSION}-bin-hadoop3.tgz \
&& tar xzf spark-${SPARK_VERSION}-bin-hadoop3.tgz --directory /opt/spark --strip-components 1 \
Expand All @@ -55,18 +54,30 @@ RUN curl --retry 5 -s https://repo1.maven.org/maven2/org/apache/iceberg/iceberg-
RUN curl --retry 5 -s https://repo1.maven.org/maven2/org/apache/iceberg/iceberg-aws-bundle/${ICEBERG_VERSION}/iceberg-aws-bundle-${ICEBERG_VERSION}.jar \
-Lo /opt/spark/jars/iceberg-aws-bundle-${ICEBERG_VERSION}.jar

COPY spark-defaults.conf /opt/spark/conf
COPY dev/spark-defaults.conf /opt/spark/conf
ENV PATH="/opt/spark/sbin:/opt/spark/bin:${PATH}"

RUN chmod u+x /opt/spark/sbin/* && \
chmod u+x /opt/spark/bin/*

RUN pip3 install -q ipython

RUN pip3 install "pyiceberg[s3fs,hive,pyarrow]==${PYICEBERG_VERSION}"
# Copy the local pyiceberg source code and install locally
COPY pyiceberg/ /tmp/pyiceberg/pyiceberg
COPY pyproject.toml /tmp/pyiceberg/
COPY build-module.py /tmp/pyiceberg/
COPY vendor/ /tmp/pyiceberg/vendor
COPY README.md /tmp/pyiceberg/
COPY NOTICE /tmp/pyiceberg/

COPY entrypoint.sh .
COPY provision.py .
# Install pyiceberg from the copied source
RUN cd /tmp/pyiceberg && pip3 install ".[s3fs,hive,pyarrow]"

# Clean up
RUN rm -rf /tmp/pyiceberg

COPY dev/entrypoint.sh ${SPARK_HOME}/
COPY dev/provision.py ${SPARK_HOME}/

ENTRYPOINT ["./entrypoint.sh"]
CMD ["notebook"]
4 changes: 3 additions & 1 deletion dev/docker-compose-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ services:
spark-iceberg:
image: python-integration
container_name: pyiceberg-spark
build: .
build:
context: ..
dockerfile: dev/Dockerfile
networks:
iceberg_net:
depends_on:
Expand Down
56 changes: 41 additions & 15 deletions pyiceberg/catalog/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
# specific language governing permissions and limitations
# under the License.
import getpass
import importlib
import logging
import socket
import time
from collections import namedtuple
from types import TracebackType
from typing import (
TYPE_CHECKING,
Expand All @@ -32,12 +34,14 @@
)
from urllib.parse import urlparse

from hive_metastore.ThriftHiveMetastore import Client
from hive_metastore.ttypes import (
from hive_metastore.v3.ThriftHiveMetastore import Client
from hive_metastore.v3.ttypes import (
AlreadyExistsException,
CheckLockRequest,
EnvironmentContext,
FieldSchema,
GetTableRequest,
GetTablesRequest,
InvalidOperationException,
LockComponent,
LockLevel,
Expand All @@ -51,8 +55,8 @@
StorageDescriptor,
UnlockRequest,
)
from hive_metastore.ttypes import Database as HiveDatabase
from hive_metastore.ttypes import Table as HiveTable
from hive_metastore.v3.ttypes import Database as HiveDatabase
from hive_metastore.v3.ttypes import Table as HiveTable
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from thrift.protocol import TBinaryProtocol
from thrift.transport import TSocket, TTransport
Expand Down Expand Up @@ -141,6 +145,7 @@
DEFAULT_LOCK_CHECK_RETRIES = 4
DO_NOT_UPDATE_STATS = "DO_NOT_UPDATE_STATS"
DO_NOT_UPDATE_STATS_DEFAULT = "true"
HiveVersion = namedtuple("HiveVersion", "major minor patch")

logger = logging.getLogger(__name__)

Expand All @@ -150,6 +155,9 @@ class _HiveClient:

_transport: TTransport
_ugi: Optional[List[str]]
_hive_version: HiveVersion = HiveVersion(4, 0, 0)
_hms_v3: object
_hms_v4: object

def __init__(
self,
Expand All @@ -163,9 +171,19 @@ def __init__(
self._kerberos_service_name = kerberos_service_name
self._ugi = ugi.split(":") if ugi else None
self._transport = self._init_thrift_transport()
self.hms_v3 = importlib.import_module("hive_metastore.v3.ThriftHiveMetastore")
self.hms_v4 = importlib.import_module("hive_metastore.v4.ThriftHiveMetastore")
self._hive_version = self._get_hive_version()

def _get_hive_version(self) -> HiveVersion:
    """Ask the metastore for its version and parse it into components.

    Opens a temporary client connection, reads the server-reported
    version string (e.g. ``"4.0.0"``), and splits it on dots.

    Returns:
        HiveVersion: Named tuple of integer (major, minor, patch).
    """
    with self as open_client:
        components = [int(part) for part in open_client.getVersion().split(".")]
        return HiveVersion(*components)

def _init_thrift_transport(self) -> TTransport:
url_parts = urlparse(self._uri)
if not url_parts.hostname or not url_parts.port:
raise ValueError("hive hostname and port must be set")
socket = TSocket.TSocket(url_parts.hostname, url_parts.port)
if not self._kerberos_auth:
return TTransport.TBufferedTransport(socket)
Expand All @@ -174,7 +192,8 @@ def _init_thrift_transport(self) -> TTransport:

def _client(self) -> Client:
    """Build a Thrift metastore client over the already-open transport.

    The v4 generated bindings are selected only when the server reports
    major >= 4 AND patch > 0; everything else — including 4.0.0 itself —
    uses the v3 bindings.  If a UGI was configured it is forwarded to the
    server immediately after construction.

    NOTE(review): the check ignores ``minor``, so e.g. 4.1.0 would also
    fall back to the v3 bindings — confirm this is intended.
    """
    binary_protocol = TBinaryProtocol.TBinaryProtocol(self._transport)
    use_v4 = self._hive_version.major >= 4 and self._hive_version.patch > 0
    bindings = self.hms_v4 if use_v4 else self.hms_v3
    thrift_client: Client = bindings.Client(binary_protocol)
    if self._ugi:
        thrift_client.set_ugi(*self._ugi)
    return thrift_client
Expand Down Expand Up @@ -387,11 +406,18 @@ def _create_hive_table(self, open_client: Client, hive_table: HiveTable) -> None
except AlreadyExistsException as e:
raise TableAlreadyExistsError(f"Table {hive_table.dbName}.{hive_table.tableName} already exists") from e

def _get_hive_table(self, open_client: Client, database_name: str, table_name: str) -> HiveTable:
def _get_hive_table(self, open_client: Client, *, dbname: str, tbl_name: str) -> HiveTable:
    """Fetch a single table definition from the Hive metastore.

    Uses the request-object API (``get_table_req``) when the metastore is
    newer than 4.0.0, otherwise falls back to the positional
    ``get_table`` call.
    NOTE(review): the version check ignores ``minor`` (e.g. 4.1.0 takes
    the fallback path) — confirm this is intended.

    Args:
        open_client: An open metastore Thrift client.
        dbname: Database (namespace) name.
        tbl_name: Table name.

    Returns:
        The ``HiveTable`` object for the requested table.

    Raises:
        NoSuchTableError: If the table is not present in the metastore.
    """
    try:
        # Prefer a plain boolean expression over all((...)) for a two-term check.
        if self._client._hive_version.major >= 4 and self._client._hive_version.patch > 0:
            return open_client.get_table_req(GetTableRequest(dbName=dbname, tblName=tbl_name)).table
        return open_client.get_table(dbname=dbname, tbl_name=tbl_name)
    except NoSuchObjectException as e:
        # Fixed grammar in the user-facing message ("does not exists" -> "does not exist").
        raise NoSuchTableError(f"Table does not exist: {tbl_name}") from e

def _get_table_objects_by_name(self, open_client: Client, *, dbname: str, tbl_names: List[str]) -> List[HiveTable]:
    """Bulk-fetch table definitions for several tables in one database.

    Uses the request-object API (``get_table_objects_by_name_req``) when
    the metastore is newer than 4.0.0, otherwise falls back to the
    positional ``get_table_objects_by_name`` call.
    NOTE(review): the version check ignores ``minor`` (e.g. 4.1.0 takes
    the fallback path) — confirm this is intended.

    Args:
        open_client: An open metastore Thrift client.
        dbname: Database (namespace) name.
        tbl_names: Names of the tables to fetch.

    Returns:
        The ``HiveTable`` objects for the requested tables.
    """
    # Use typing.List for consistency with the rest of this module
    # (lowercase list[...] generics require Python 3.9+).
    if self._client._hive_version.major >= 4 and self._client._hive_version.patch > 0:
        return open_client.get_table_objects_by_name_req(GetTablesRequest(dbName=dbname, tblNames=tbl_names)).tables
    return open_client.get_table_objects_by_name(dbname=dbname, tbl_names=tbl_names)

def create_table(
self,
Expand Down Expand Up @@ -435,7 +461,7 @@ def create_table(

with self._client as open_client:
self._create_hive_table(open_client, tbl)
hive_table = open_client.get_table(dbname=database_name, tbl_name=table_name)
hive_table = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)

return self._convert_hive_into_iceberg(hive_table)

Expand Down Expand Up @@ -465,7 +491,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
tbl = self._convert_iceberg_into_hive(staged_table)
with self._client as open_client:
self._create_hive_table(open_client, tbl)
hive_table = open_client.get_table(dbname=database_name, tbl_name=table_name)
hive_table = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)

return self._convert_hive_into_iceberg(hive_table)

Expand Down Expand Up @@ -538,7 +564,7 @@ def commit_table(
hive_table: Optional[HiveTable]
current_table: Optional[Table]
try:
hive_table = self._get_hive_table(open_client, database_name, table_name)
hive_table = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)
current_table = self._convert_hive_into_iceberg(hive_table)
except NoSuchTableError:
hive_table = None
Expand Down Expand Up @@ -612,7 +638,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
database_name, table_name = self.identifier_to_database_and_table(identifier, NoSuchTableError)

with self._client as open_client:
hive_table = self._get_hive_table(open_client, database_name, table_name)
hive_table = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)

return self._convert_hive_into_iceberg(hive_table)

Expand Down Expand Up @@ -661,7 +687,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U

try:
with self._client as open_client:
tbl = open_client.get_table(dbname=from_database_name, tbl_name=from_table_name)
tbl = self._get_hive_table(open_client, dbname=from_database_name, tbl_name=from_table_name)
tbl.dbName = to_database_name
tbl.tableName = to_table_name
open_client.alter_table_with_environment_context(
Expand Down Expand Up @@ -733,8 +759,8 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
with self._client as open_client:
return [
(database_name, table.tableName)
for table in open_client.get_table_objects_by_name(
dbname=database_name, tbl_names=open_client.get_all_tables(db_name=database_name)
for table in self._get_table_objects_by_name(
open_client, dbname=database_name, tbl_names=open_client.get_all_tables(db_name=database_name)
)
if table.parameters.get(TABLE_TYPE, "").lower() == ICEBERG
]
Expand Down
Loading