1515# specific language governing permissions and limitations
1616# under the License.
1717from __future__ import annotations
18-
18+ from abc import ABC , abstractmethod
1919from functools import cached_property
2020from typing import (
2121 Any ,
22+ cast ,
2223 Dict ,
24+ Generic ,
2325 List ,
2426 Optional ,
2527 Tuple ,
28+ TypeVar
2629)
2730
2831from pydantic import (
3437from typing_extensions import Annotated
3538
3639from pyiceberg .schema import Schema
37- from pyiceberg .transforms import Transform , parse_transform
40+ from pyiceberg .transforms import (
41+ BucketTransform ,
42+ DayTransform ,
43+ IdentityTransform ,
44+ HourTransform ,
45+ Transform ,
46+ TruncateTransform ,
47+ UnknownTransform ,
48+ VoidTransform ,
49+ YearTransform ,
50+ parse_transform
51+ )
52+
3853from pyiceberg .typedef import IcebergBaseModel
3954from pyiceberg .types import NestedField , StructType
4055
@@ -215,3 +230,118 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
215230 )
216231 )
217232 return PartitionSpec (* partition_fields , spec_id = INITIAL_PARTITION_SPEC_ID )
233+
234+ T = TypeVar ("T" )
235+ class PartitionSpecVisitor (Generic [T ], ABC ):
236+ @abstractmethod
237+ def identity (field_id : int , source_name : str , source_id : int ) -> T :
238+ """
239+ Visit identity partition field
240+ """
241+
242+ @abstractmethod
243+ def bucket (field_id : int , source_name : str , source_id : int , num_buckets : int ) -> T :
244+ """
245+ Visit bucket partition field
246+ """
247+
248+ @abstractmethod
249+ def truncate (field_id : int , source_name : str , source_id : int , width : int ) -> T :
250+ """
251+ Visit truncate partition field
252+ """
253+
254+ @abstractmethod
255+ def year (field_id : int , source_name : str , source_id : int ) -> T :
256+ """
257+ Visit year partition field
258+ """
259+
260+ @abstractmethod
261+ def month (field_id : int , source_name : str , source_id : int ) -> T :
262+ """
263+ Visit month partition field
264+ """
265+
266+ @abstractmethod
267+ def day (field_id : int , source_name : str , source_id : int ) -> T :
268+ """
269+ Visit day partition field
270+ """
271+
272+ @abstractmethod
273+ def hour (field_id : int , source_name : str , source_id : int ) -> T :
274+ """
275+ Visit hour partition field
276+ """
277+
278+ @abstractmethod
279+ def always_null (field_id : int , source_name : str , source_id : int ) -> T :
280+ """
281+ Visit void partition field
282+ """
283+
284+ @abstractmethod
285+ def unknown (field_id : int , source_name : str , source_id : int , transform : str ) -> T :
286+ """
287+ Visit unknown partition field
288+ """
289+ raise ValueError (f"Unknown transform { transform } is not supported" )
290+
291+ class _PartitionNameGenerator (PartitionSpecVisitor [str ]):
292+ def identity (field_id : int , source_name : str , source_id : int ) -> str :
293+ return source_name
294+
295+ def bucket (field_id : int , source_name : str , source_id : int , num_buckets : int ) -> str :
296+ return source_name + "_bucket_" + num_buckets
297+
298+ def truncate (field_id : int , source_name : str , source_id : int , width : int ) -> str :
299+ return source_name + "_trunc_" + width
300+
301+ def year (field_id : int , source_name : str , source_id : int ) -> str :
302+ return source_name + "_year"
303+
304+ def month (field_id : int , source_name : str , source_id : int ) -> str :
305+ return source_name + "_month"
306+
307+ def day (field_id : int , source_name : str , source_id : int ) -> str :
308+ return source_name + "_day"
309+
310+ def hour (field_id : int , source_name : str , source_id : int ) -> str :
311+ return source_name + "_hour"
312+
313+ def always_null (field_id : int , source_name : str , source_id : int ) -> str :
314+ return source_name + "_null"
315+
316+ R = TypeVar ("R" )
317+
318+ @staticmethod
319+ def _visit (spec : PartitionSpec , schema : Schema , visitor : PartitionSpecVisitor [R ]) -> [R ]:
320+ results = []
321+ for field in spec .fields :
322+ results .append (_visit (schema , field , visitor ))
323+ return results
324+
325+ @staticmethod
326+ def _visit (schema : Schema , field : PartitionField , visitor : PartitionSpecVisitor [R ]) -> [R ]:
327+ source_name = schema .find_column_name (field .source_id )
328+ transform = field .transform
329+ if isinstance (transform , IdentityTransform ):
330+ visitor .identity (field .field_id , source_name , field .source_id )
331+ elif isinstance (transform , BucketTransform ):
332+ visitor .bucket (field .field_id , source_name , field .source_id , cast (BucketTransform , transform ).num_buckets )
333+ elif isinstance (transform , TruncateTransform ):
334+ visitor .truncate (field .field_id , source_name , field .source_id , cast (TruncateTransform , transform ).width )
335+ elif isinstance (transform , DayTransform ):
336+ visitor .day (field .field_id , source_name , field .source_id )
337+ elif isinstance (transform , HourTransform ):
338+ visitor .hour (field .field_id , source_name , field .source_id )
339+ elif isinstance (transform , YearTransform ):
340+ visitor .year (field .field_id , source_name , field .source_id )
341+ elif isinstance (transform , VoidTransform ):
342+ visitor .always_null (field .field_id , source_name , field .source_id )
343+ pass
344+ elif isinstance (transform , UnknownTransform ):
345+ visitor .unknown (field .field_id , source_name , field .source_id )
346+ else :
347+ raise ValueError (f"Unknown transform { transform } " )
0 commit comments