diff --git a/pyiceberg/table/encryption.py b/pyiceberg/table/encryption.py
new file mode 100644
index 0000000000..4cb1c67cdd
--- /dev/null
+++ b/pyiceberg/table/encryption.py
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Dict, Optional
+
+from pydantic import Field
+
+from pyiceberg.typedef import IcebergBaseModel
+
+
+class EncryptedKey(IcebergBaseModel):
+    """Encryption key entry of a table, identified by its `key-id`."""
+
+    key_id: str = Field(alias="key-id", description="ID of the encryption key")
+    # NOTE(review): declared as bytes but described as base64 encoded; presumably callers pass
+    # already-encoded bytes -- confirm how IcebergBaseModel (de)serializes this field.
+    encrypted_key_metadata: bytes = Field(
+        alias="encrypted-key-metadata", description="Encrypted key and metadata, base64 encoded"
+    )
+    encrypted_by_id: Optional[str] = Field(
+        alias="encrypted-by-id", description="Optional ID of the key used to encrypt or wrap `key-metadata`", default=None
+    )
+    properties: Optional[Dict[str, str]] = Field(
+        alias="properties",
+        description="A string to string map of additional metadata used by the table's encryption scheme",
+        default=None,
+    )
diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py
index 9ab29815e9..6b51cff30d 100644
--- a/pyiceberg/table/metadata.py
+++ b/pyiceberg/table/metadata.py
@@ -27,6 +27,7 @@
 from pyiceberg.exceptions import ValidationError
 from pyiceberg.partitioning import 
PARTITION_FIELD_ID_START, PartitionSpec, assign_fresh_partition_spec_ids from pyiceberg.schema import Schema, assign_fresh_schema_ids +from pyiceberg.table.encryption import EncryptedKey from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import MetadataLogEntry, Snapshot, SnapshotLogEntry @@ -526,6 +527,7 @@ class TableMetadataV3(TableMetadataCommonFields, IcebergBaseModel): - Multi-argument transforms for partitioning and sorting - Row Lineage tracking - Binary deletion vectors + - Encryption Keys For more information: https://iceberg.apache.org/spec/?column-projection#version-3-extended-types-and-capabilities @@ -562,6 +564,9 @@ def construct_refs(cls, table_metadata: TableMetadata) -> TableMetadata: next_row_id: Optional[int] = Field(alias="next-row-id", default=None) """A long higher than all assigned row IDs; the next snapshot's `first-row-id`.""" + encryption_keys: List[EncryptedKey] = Field(alias="encryption-keys", default=[]) + """The list of encryption keys for this table.""" + def model_dump_json( self, exclude_none: bool = True, exclude: Optional[Any] = None, by_alias: bool = True, **kwargs: Any ) -> str: diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index 60ad7219e1..d64ad9e5ee 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -244,6 +244,7 @@ class Snapshot(IcebergBaseModel): manifest_list: str = Field(alias="manifest-list", description="Location of the snapshot's manifest list file") summary: Optional[Summary] = Field(default=None) schema_id: Optional[int] = Field(alias="schema-id", default=None) + key_id: Optional[str] = Field(alias="key-id", default=None, description="The id of the encryption key") def __str__(self) -> str: """Return the string representation of the Snapshot class.""" diff --git a/pyiceberg/table/update/__init__.py 
b/pyiceberg/table/update/__init__.py
index 30315b0cc1..2c54e0ff7c 100644
--- a/pyiceberg/table/update/__init__.py
+++ b/pyiceberg/table/update/__init__.py
@@ -28,6 +28,7 @@
 from pyiceberg.exceptions import CommitFailedException
 from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec
 from pyiceberg.schema import Schema
+from pyiceberg.table.encryption import EncryptedKey
 from pyiceberg.table.metadata import SUPPORTED_TABLE_FORMAT_VERSION, TableMetadata, TableMetadataUtil
 from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
 from pyiceberg.table.snapshots import (
@@ -90,6 +91,20 @@ class UpgradeFormatVersionUpdate(IcebergBaseModel):
     format_version: int = Field(alias="format-version")
 
 
+class AddEncryptedKeyUpdate(IcebergBaseModel):
+    """Table update that adds an encryption key to the table metadata."""
+
+    action: Literal["add-encryption-key"] = Field(default="add-encryption-key")
+    key: EncryptedKey = Field(alias="key")
+
+
+class RemoveEncryptedKeyUpdate(IcebergBaseModel):
+    """Table update that removes the encryption key with the given ID from the table metadata."""
+
+    action: Literal["remove-encryption-key"] = Field(default="remove-encryption-key")
+    key_id: str = Field(alias="key-id")
+
+
 class AddSchemaUpdate(IcebergBaseModel):
     action: Literal["add-schema"] = Field(default="add-schema")
     schema_: Schema = Field(alias="schema")
@@ -230,6 +245,8 @@ class RemovePartitionStatisticsUpdate(IcebergBaseModel):
         RemoveSchemasUpdate,
         SetPartitionStatisticsUpdate,
         RemovePartitionStatisticsUpdate,
+        AddEncryptedKeyUpdate,
+        RemoveEncryptedKeyUpdate,
     ],
     Field(discriminator="action"),
 ]
@@ -595,6 +612,47 @@ def _(update: RemoveStatisticsUpdate, base_metadata: TableMetadata, context: _Ta
     return base_metadata.model_copy(update={"statistics": statistics})
 
 
+@_apply_table_update.register(AddEncryptedKeyUpdate)
+def _(update: AddEncryptedKeyUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
+    """Register a new encryption key on a table with format version 3 or later.
+
+    A key whose ID is already registered is ignored and the metadata is returned as-is.
+
+    Raises:
+        ValueError: If the table uses format version 1 or 2.
+    """
+    if base_metadata.format_version <= 2:
+        raise ValueError("Cannot add encryption keys to Iceberg v1 or v2 tables")
+
+    # Removal matches keys by ID, so a second entry with the same ID would be ambiguous.
+    # NOTE(review): skipping silently keeps repeated updates idempotent -- confirm vs. raising.
+    if any(key.key_id == update.key.key_id for key in base_metadata.encryption_keys):
+        return base_metadata
+
+    # Record the update only once it is known to be applicable.
+    context.add_update(update)
+    return base_metadata.model_copy(update={"encryption_keys": [*base_metadata.encryption_keys, update.key]})
+
+
+@_apply_table_update.register(RemoveEncryptedKeyUpdate)
+def _(update: RemoveEncryptedKeyUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
+    """Drop the encryption key with the given ID from a table with format version 3 or later.
+
+    Raises:
+        ValueError: If the table uses format version 1 or 2, or if no key has the requested ID.
+    """
+    if base_metadata.format_version <= 2:
+        raise ValueError("Cannot remove encryption keys from Iceberg v1 or v2 tables")
+
+    encryption_keys = [key for key in base_metadata.encryption_keys if key.key_id != update.key_id]
+    if len(encryption_keys) == len(base_metadata.encryption_keys):
+        raise ValueError(f"Encryption key {update.key_id} not found")
+
+    # Record the update only once it is known to be applicable.
+    context.add_update(update)
+    return base_metadata.model_copy(update={"encryption_keys": encryption_keys})
+
+
 @_apply_table_update.register(RemoveSchemasUpdate)
 def _(update: RemoveSchemasUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
     # This method should error if any schemas do not exist.
diff --git a/tests/conftest.py b/tests/conftest.py
index 2b571d7320..c8b75510f2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -72,7 +72,7 @@
 from pyiceberg.schema import Accessor, Schema
 from pyiceberg.serializers import ToOutputFile
 from pyiceberg.table import FileScanTask, Table
-from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2
+from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2, TableMetadataV3
 from pyiceberg.transforms import DayTransform, IdentityTransform
 from pyiceberg.types import (
     BinaryType,
@@ -2468,6 +2468,18 @@ def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table:
     )
 
 
+@pytest.fixture
+def table_v3(example_table_metadata_v3: Dict[str, Any]) -> Table:
+    table_metadata = TableMetadataV3(**example_table_metadata_v3)
+    return Table(
+        identifier=("database", "table"),
+        metadata=table_metadata,
+        metadata_location=f"{table_metadata.location}/uuid.metadata.json",
+        io=load_file_io(),
+        catalog=NoopCatalog("NoopCatalog"),
+    )
+
+
@pytest.fixture
 def table_v2_orc(example_table_metadata_v2: Dict[str, Any]) -> Table:
     import copy
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index cd81df4d97..7b45d4496f 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint:disable=redefined-outer-name
+import base64
 import json
 import uuid
 from copy import copy
@@ -49,6 +50,7 @@
     TableIdentifier,
     _match_deletes_to_data_file,
 )
+from pyiceberg.table.encryption import EncryptedKey
 from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id
 from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
 from pyiceberg.table.snapshots import (
@@ -66,6 +68,7 @@
 )
 from pyiceberg.table.statistics import BlobMetadata, PartitionStatisticsFile, StatisticsFile
 from pyiceberg.table.update import (
+    AddEncryptedKeyUpdate,
     AddSnapshotUpdate,
     AddSortOrderUpdate,
     AssertCreate,
@@ -76,6 +79,7 @@
     AssertLastAssignedPartitionId,
     AssertRefSnapshotId,
     AssertTableUUID,
+    RemoveEncryptedKeyUpdate,
     RemovePartitionStatisticsUpdate,
     RemovePropertiesUpdate,
     RemoveSchemasUpdate,
@@ -1477,3 +1481,51 @@ def test_remove_partition_statistics_update_with_invalid_snapshot_id(table_v2_wi
             table_v2_with_statistics.metadata,
             (RemovePartitionStatisticsUpdate(snapshot_id=123456789),),
         )
+
+
+def test_add_encryption_key(table_v3: Table) -> None:
+    update = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test", encrypted_key_metadata=base64.b64encode(b"hello")))
+
+    expected = """
+    {
+        "key-id": "test",
+        "encrypted-key-metadata": "aGVsbG8="
+    }"""
+
+    assert table_v3.metadata.encryption_keys == []
+    add_metadata = update_table_metadata(table_v3.metadata, (update,))
+    assert len(add_metadata.encryption_keys) == 1
+    # The update must produce new metadata rather than mutate the table's current metadata.
+    assert table_v3.metadata.encryption_keys == []
+
+    assert json.loads(add_metadata.encryption_keys[0].model_dump_json()) == json.loads(expected)
+
+
+def test_remove_encryption_key(table_v3: Table) -> None:
+    update_add = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test", encrypted_key_metadata=base64.b64encode(b"hello")))
+    add_metadata = update_table_metadata(table_v3.metadata, (update_add,))
+    assert len(add_metadata.encryption_keys) == 1
+
+    update_remove = RemoveEncryptedKeyUpdate(key_id="test")
+    remove_metadata = update_table_metadata(add_metadata, (update_remove,))
+    assert len(remove_metadata.encryption_keys) == 0
+
+
+def test_remove_non_existent_encryption_key(table_v3: Table) -> None:
+    update_add = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test", encrypted_key_metadata=base64.b64encode(b"hello")))
+    add_metadata = update_table_metadata(table_v3.metadata, (update_add,))
+    assert len(add_metadata.encryption_keys) == 1
+
+    update_remove = RemoveEncryptedKeyUpdate(key_id="non_existent_key")
+    with pytest.raises(ValueError, match=r"Encryption key non_existent_key not found"):
+        update_table_metadata(add_metadata, (update_remove,))
+
+
+def test_add_remove_encryption_key_v2_table(table_v2: Table) -> None:
+    update_add = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test_v2", encrypted_key_metadata=base64.b64encode(b"hello_v2")))
+    with pytest.raises(ValueError, match=r"Cannot add encryption keys to Iceberg v1 or v2 tables"):
+        update_table_metadata(table_v2.metadata, (update_add,))
+
+    update_remove = RemoveEncryptedKeyUpdate(key_id="test_v2")
+    with pytest.raises(ValueError, match=r"Cannot remove encryption keys from Iceberg v1 or v2 tables"):
+        update_table_metadata(table_v2.metadata, (update_remove,))
diff --git a/tests/table/test_snapshots.py b/tests/table/test_snapshots.py
index d26562ad8f..3a7391c288 100644
--- a/tests/table/test_snapshots.py
+++ b/tests/table/test_snapshots.py
@@ -139,7 +139,7 @@ def test_deserialize_snapshot_with_properties(snapshot_with_properties: Snapshot
 def test_snapshot_repr(snapshot: Snapshot) -> None:
     assert (
         repr(snapshot)
-        == """Snapshot(snapshot_id=25, parent_snapshot_id=19, sequence_number=200, timestamp_ms=1602638573590, manifest_list='s3:/a/b/c.avro', summary=Summary(Operation.APPEND), schema_id=3)"""
+        == """Snapshot(snapshot_id=25, parent_snapshot_id=19, sequence_number=200, 
timestamp_ms=1602638573590, manifest_list='s3:/a/b/c.avro', summary=Summary(Operation.APPEND), schema_id=3, key_id=None)""" ) assert snapshot == eval(repr(snapshot)) @@ -147,7 +147,7 @@ def test_snapshot_repr(snapshot: Snapshot) -> None: def test_snapshot_with_properties_repr(snapshot_with_properties: Snapshot) -> None: assert ( repr(snapshot_with_properties) - == """Snapshot(snapshot_id=25, parent_snapshot_id=19, sequence_number=200, timestamp_ms=1602638573590, manifest_list='s3:/a/b/c.avro', summary=Summary(Operation.APPEND, **{'foo': 'bar'}), schema_id=3)""" + == """Snapshot(snapshot_id=25, parent_snapshot_id=19, sequence_number=200, timestamp_ms=1602638573590, manifest_list='s3:/a/b/c.avro', summary=Summary(Operation.APPEND, **{'foo': 'bar'}), schema_id=3, key_id=None)""" ) assert snapshot_with_properties == eval(repr(snapshot_with_properties))