Skip to content

Commit b7e0713

Browse files
Initial partition evolution
1 parent 567ec49 commit b7e0713

File tree

2 files changed

+339
-3
lines changed

2 files changed

+339
-3
lines changed

pyiceberg/partitioning.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
from __future__ import annotations
18-
18+
from abc import ABC, abstractmethod
1919
from functools import cached_property
2020
from typing import (
2121
Any,
22+
cast,
2223
Dict,
24+
Generic,
2325
List,
2426
Optional,
2527
Tuple,
28+
TypeVar
2629
)
2730

2831
from pydantic import (
@@ -34,7 +37,19 @@
3437
from typing_extensions import Annotated
3538

3639
from pyiceberg.schema import Schema
37-
from pyiceberg.transforms import Transform, parse_transform
40+
from pyiceberg.transforms import (
41+
BucketTransform,
42+
DayTransform,
43+
IdentityTransform,
44+
HourTransform,
45+
Transform,
46+
TruncateTransform,
47+
UnknownTransform,
48+
VoidTransform,
49+
YearTransform,
50+
parse_transform
51+
)
52+
3853
from pyiceberg.typedef import IcebergBaseModel
3954
from pyiceberg.types import NestedField, StructType
4055

@@ -215,3 +230,118 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
215230
)
216231
)
217232
return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID)
233+
234+
T = TypeVar("T")
235+
class PartitionSpecVisitor(Generic[T], ABC):
236+
@abstractmethod
237+
def identity(field_id: int, source_name: str, source_id: int) -> T:
238+
"""
239+
Visit identity partition field
240+
"""
241+
242+
@abstractmethod
243+
def bucket(field_id: int, source_name: str, source_id: int, num_buckets: int) -> T:
244+
"""
245+
Visit bucket partition field
246+
"""
247+
248+
@abstractmethod
249+
def truncate(field_id: int, source_name: str, source_id: int, width: int) -> T:
250+
"""
251+
Visit truncate partition field
252+
"""
253+
254+
@abstractmethod
255+
def year(field_id: int, source_name: str, source_id: int) -> T:
256+
"""
257+
Visit year partition field
258+
"""
259+
260+
@abstractmethod
261+
def month(field_id: int, source_name: str, source_id: int) -> T:
262+
"""
263+
Visit month partition field
264+
"""
265+
266+
@abstractmethod
267+
def day(field_id: int, source_name: str, source_id: int) -> T:
268+
"""
269+
Visit day partition field
270+
"""
271+
272+
@abstractmethod
273+
def hour(field_id: int, source_name: str, source_id: int) -> T:
274+
"""
275+
Visit hour partition field
276+
"""
277+
278+
@abstractmethod
279+
def always_null(field_id: int, source_name: str, source_id: int) -> T:
280+
"""
281+
Visit void partition field
282+
"""
283+
284+
@abstractmethod
285+
def unknown(field_id: int, source_name: str, source_id: int, transform: str) -> T:
286+
"""
287+
Visit unknown partition field
288+
"""
289+
raise ValueError(f"Unknown transform {transform} is not supported")
290+
291+
class _PartitionNameGenerator(PartitionSpecVisitor[str]):
292+
def identity(field_id: int, source_name: str, source_id: int) -> str:
293+
return source_name
294+
295+
def bucket(field_id: int, source_name: str, source_id: int, num_buckets: int) -> str:
296+
return source_name + "_bucket_" + num_buckets
297+
298+
def truncate(field_id: int, source_name: str, source_id: int, width: int) -> str:
299+
return source_name + "_trunc_" + width
300+
301+
def year(field_id: int, source_name: str, source_id: int) -> str:
302+
return source_name + "_year"
303+
304+
def month(field_id: int, source_name: str, source_id: int) -> str:
305+
return source_name + "_month"
306+
307+
def day(field_id: int, source_name: str, source_id: int) -> str:
308+
return source_name + "_day"
309+
310+
def hour(field_id: int, source_name: str, source_id: int) -> str:
311+
return source_name + "_hour"
312+
313+
def always_null(field_id: int, source_name: str, source_id: int) -> str:
314+
return source_name + "_null"
315+
316+
R = TypeVar("R")
317+
318+
@staticmethod
319+
def _visit(spec: PartitionSpec, schema: Schema, visitor: PartitionSpecVisitor[R]) -> [R]:
320+
results = []
321+
for field in spec.fields:
322+
results.append(_visit(schema, field, visitor))
323+
return results
324+
325+
@staticmethod
326+
def _visit(schema: Schema, field: PartitionField, visitor: PartitionSpecVisitor[R]) -> [R]:
327+
source_name = schema.find_column_name(field.source_id)
328+
transform = field.transform
329+
if isinstance(transform, IdentityTransform):
330+
visitor.identity(field.field_id, source_name, field.source_id)
331+
elif isinstance(transform, BucketTransform):
332+
visitor.bucket(field.field_id, source_name, field.source_id, cast(BucketTransform, transform).num_buckets)
333+
elif isinstance(transform, TruncateTransform):
334+
visitor.truncate(field.field_id, source_name, field.source_id, cast(TruncateTransform, transform).width)
335+
elif isinstance(transform, DayTransform):
336+
visitor.day(field.field_id, source_name, field.source_id)
337+
elif isinstance(transform, HourTransform):
338+
visitor.hour(field.field_id, source_name, field.source_id)
339+
elif isinstance(transform, YearTransform):
340+
visitor.year(field.field_id, source_name, field.source_id)
341+
elif isinstance(transform, VoidTransform):
342+
visitor.always_null(field.field_id, source_name, field.source_id)
343+
pass
344+
elif isinstance(transform, UnknownTransform):
345+
visitor.unknown(field.field_id, source_name, field.source_id)
346+
else:
347+
raise ValueError(f"Unknown transform {transform}")

0 commit comments

Comments
 (0)