Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions pyiceberg/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

import base64
import datetime as py_datetime
import importlib
import struct
import types
from abc import ABC, abstractmethod
from enum import IntEnum
from functools import singledispatch
Expand All @@ -28,6 +30,7 @@
import mmh3
from pydantic import Field, PositiveInt, PrivateAttr

from pyiceberg.exceptions import NotInstalledError
from pyiceberg.expressions import (
BoundEqualTo,
BoundGreaterThan,
Expand Down Expand Up @@ -106,6 +109,17 @@
TRUNCATE_PARSER = ParseNumberFromBrackets(TRUNCATE)


def _try_import(module_name: str, extras_name: Optional[str] = None) -> types.ModuleType:
try:
return importlib.import_module(module_name)
except ImportError:
if extras_name:
msg = f'{module_name} needs to be installed. pip install "pyiceberg[{extras_name}]"'
else:
msg = f"{module_name} needs to be installed."
raise NotInstalledError(msg) from None


def _transform_literal(func: Callable[[L], L], lit: Literal[L]) -> Literal[L]:
"""Small helper to upwrap the value from the literal, and wrap it again."""
return literal(func(lit.value))
Expand Down Expand Up @@ -382,8 +396,7 @@ def __repr__(self) -> str:
return f"BucketTransform(num_buckets={self._num_buckets})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
from pyiceberg_core import transform as pyiceberg_core_transform

pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)

@property
Expand Down Expand Up @@ -509,9 +522,8 @@ def __repr__(self) -> str:
return "YearTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
from pyiceberg_core import transform as pyiceberg_core_transform

pa = _try_import("pyarrow")
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.year, expected_type=pa.int32())


Expand Down Expand Up @@ -570,8 +582,8 @@ def __repr__(self) -> str:
return "MonthTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
from pyiceberg_core import transform as pyiceberg_core_transform
pa = _try_import("pyarrow")
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform

return _pyiceberg_transform_wrapper(pyiceberg_core_transform.month, expected_type=pa.int32())

Expand Down Expand Up @@ -639,8 +651,8 @@ def __repr__(self) -> str:
return "DayTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
from pyiceberg_core import transform as pyiceberg_core_transform
pa = _try_import("pyarrow", extras_name="pyarrow")
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform

return _pyiceberg_transform_wrapper(pyiceberg_core_transform.day, expected_type=pa.int32())

Expand Down Expand Up @@ -692,7 +704,7 @@ def __repr__(self) -> str:
return "HourTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
from pyiceberg_core import transform as pyiceberg_core_transform
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform

return _pyiceberg_transform_wrapper(pyiceberg_core_transform.hour)

Expand Down Expand Up @@ -915,7 +927,7 @@ def __repr__(self) -> str:
return f"TruncateTransform(width={self._width})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
from pyiceberg_core import transform as pyiceberg_core_transform
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform

return _pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)

Expand Down
14 changes: 14 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
RootModel,
WithJsonSchema,
)
from pytest_mock import MockFixture

from pyiceberg.exceptions import NotInstalledError
from pyiceberg.expressions import (
AlwaysFalse,
BooleanExpression,
Expand Down Expand Up @@ -1668,3 +1670,15 @@ def test_truncate_pyarrow_transforms(
) -> None:
transform: Transform[Any, Any] = TruncateTransform(width=width)
assert expected == transform.pyarrow_transform(source_type)(input_arr)


@pytest.mark.parametrize(
"transform", [BucketTransform(num_buckets=5), TruncateTransform(width=5), YearTransform(), MonthTransform(), DayTransform()]
)
def test_calling_pyarrow_transform_without_pyiceberg_core_installed_correctly_raises_not_imported_error(
transform, mocker: MockFixture
) -> None:
mocker.patch.dict("sys.modules", {"pyiceberg_core": None})

with pytest.raises(NotInstalledError):
transform.pyarrow_transform(StringType())