Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions pyiceberg/table/encryption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Optional

from pydantic import Field

from pyiceberg.typedef import IcebergBaseModel


class EncryptedKey(IcebergBaseModel):
key_id: str = Field(alias="key-id", description="ID of the encryption key")
encrypted_key_metadata: bytes = Field(
alias="encrypted-key-metadata", description="Encrypted key and metadata, base64 encoded"
)
encrypted_by_id: Optional[str] = Field(
alias="encrypted-by-id", description="Optional ID of the key used to encrypt or wrap `key-metadata`", default=None
)
properties: Optional[dict[str, str]] = Field(
alias="properties",
description="A string to string map of additional metadata used by the table's encryption scheme",
default=None,
)
5 changes: 5 additions & 0 deletions pyiceberg/table/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pyiceberg.exceptions import ValidationError
from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec, assign_fresh_partition_spec_ids
from pyiceberg.schema import Schema, assign_fresh_schema_ids
from pyiceberg.table.encryption import EncryptedKey
from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import MetadataLogEntry, Snapshot, SnapshotLogEntry
Expand Down Expand Up @@ -526,6 +527,7 @@ class TableMetadataV3(TableMetadataCommonFields, IcebergBaseModel):
- Multi-argument transforms for partitioning and sorting
- Row Lineage tracking
- Binary deletion vectors
- Encryption Keys

For more information:
https://iceberg.apache.org/spec/?column-projection#version-3-extended-types-and-capabilities
Expand Down Expand Up @@ -562,6 +564,9 @@ def construct_refs(cls, table_metadata: TableMetadata) -> TableMetadata:
next_row_id: Optional[int] = Field(alias="next-row-id", default=None)
"""A long higher than all assigned row IDs; the next snapshot's `first-row-id`."""

encryption_keys: List[EncryptedKey] = Field(alias="encryption-keys", default=[])
"""The list of encryption keys for this table."""

def model_dump_json(
self, exclude_none: bool = True, exclude: Optional[Any] = None, by_alias: bool = True, **kwargs: Any
) -> str:
Expand Down
1 change: 1 addition & 0 deletions pyiceberg/table/snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ class Snapshot(IcebergBaseModel):
manifest_list: str = Field(alias="manifest-list", description="Location of the snapshot's manifest list file")
summary: Optional[Summary] = Field(default=None)
schema_id: Optional[int] = Field(alias="schema-id", default=None)
key_id: Optional[str] = Field(alias="key-id", default=None, description="The id of the encryption key")

def __str__(self) -> str:
"""Return the string representation of the Snapshot class."""
Expand Down
37 changes: 37 additions & 0 deletions pyiceberg/table/update/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pyiceberg.exceptions import CommitFailedException
from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table.encryption import EncryptedKey
from pyiceberg.table.metadata import SUPPORTED_TABLE_FORMAT_VERSION, TableMetadata, TableMetadataUtil
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import (
Expand Down Expand Up @@ -90,6 +91,16 @@ class UpgradeFormatVersionUpdate(IcebergBaseModel):
format_version: int = Field(alias="format-version")


class AddEncryptedKeyUpdate(IcebergBaseModel):
action: Literal["add-encryption-key"] = Field(default="add-encryption-key")
key: EncryptedKey = Field(alias="key")


class RemoveEncryptedKeyUpdate(IcebergBaseModel):
action: Literal["remove-encryption-key"] = Field(default="remove-encryption-key")
key_id: str = Field(alias="key-id")


class AddSchemaUpdate(IcebergBaseModel):
action: Literal["add-schema"] = Field(default="add-schema")
schema_: Schema = Field(alias="schema")
Expand Down Expand Up @@ -230,6 +241,8 @@ class RemovePartitionStatisticsUpdate(IcebergBaseModel):
RemoveSchemasUpdate,
SetPartitionStatisticsUpdate,
RemovePartitionStatisticsUpdate,
AddEncryptedKeyUpdate,
RemoveEncryptedKeyUpdate,
],
Field(discriminator="action"),
]
Expand Down Expand Up @@ -595,6 +608,30 @@ def _(update: RemoveStatisticsUpdate, base_metadata: TableMetadata, context: _Ta
return base_metadata.model_copy(update={"statistics": statistics})


@_apply_table_update.register(AddEncryptedKeyUpdate)
def _(update: AddEncryptedKeyUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
context.add_update(update)

if base_metadata.format_version <= 2:
raise ValueError("Cannot add encryption keys to Iceberg v1 or v2 tables")

return base_metadata.model_copy(update={"encryption_keys": base_metadata.encryption_keys + [update.key]})


@_apply_table_update.register(RemoveEncryptedKeyUpdate)
def _(update: RemoveEncryptedKeyUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
context.add_update(update)

if base_metadata.format_version <= 2:
raise ValueError("Cannot remove encryption keys from Iceberg v1 or v2 tables")

encryption_keys = [key for key in base_metadata.encryption_keys if key.key_id != update.key_id]
if len(encryption_keys) == len(base_metadata.encryption_keys):
raise ValueError(f"Encryption key {update.key_id} not found")

return base_metadata.model_copy(update={"encryption_keys": encryption_keys})


@_apply_table_update.register(RemoveSchemasUpdate)
def _(update: RemoveSchemasUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
# This method should error if any schemas do not exist.
Expand Down
14 changes: 13 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
from pyiceberg.schema import Accessor, Schema
from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2
from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2, TableMetadataV3
from pyiceberg.transforms import DayTransform, IdentityTransform
from pyiceberg.types import (
BinaryType,
Expand Down Expand Up @@ -2468,6 +2468,18 @@ def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table:
)


@pytest.fixture
def table_v3(example_table_metadata_v3: Dict[str, Any]) -> Table:
table_metadata = TableMetadataV3(**example_table_metadata_v3)
return Table(
identifier=("database", "table"),
metadata=table_metadata,
metadata_location=f"{table_metadata.location}/uuid.metadata.json",
io=load_file_io(),
catalog=NoopCatalog("NoopCatalog"),
)


@pytest.fixture
def table_v2_orc(example_table_metadata_v2: Dict[str, Any]) -> Table:
import copy
Expand Down
46 changes: 46 additions & 0 deletions tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
import base64
import json
import uuid
from copy import copy
Expand Down Expand Up @@ -49,6 +50,7 @@
TableIdentifier,
_match_deletes_to_data_file,
)
from pyiceberg.table.encryption import EncryptedKey
from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import (
Expand All @@ -66,6 +68,7 @@
)
from pyiceberg.table.statistics import BlobMetadata, PartitionStatisticsFile, StatisticsFile
from pyiceberg.table.update import (
AddEncryptedKeyUpdate,
AddSnapshotUpdate,
AddSortOrderUpdate,
AssertCreate,
Expand All @@ -76,6 +79,7 @@
AssertLastAssignedPartitionId,
AssertRefSnapshotId,
AssertTableUUID,
RemoveEncryptedKeyUpdate,
RemovePartitionStatisticsUpdate,
RemovePropertiesUpdate,
RemoveSchemasUpdate,
Expand Down Expand Up @@ -1477,3 +1481,45 @@ def test_remove_partition_statistics_update_with_invalid_snapshot_id(table_v2_wi
table_v2_with_statistics.metadata,
(RemovePartitionStatisticsUpdate(snapshot_id=123456789),),
)


def test_add_encryption_key(table_v3: Table) -> None:
update = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test", encrypted_key_metadata=base64.b64encode(b"hello")))

expected = """
{
"key-id": "test",
"encrypted-key-metadata": "aGVsbG8="
}"""

assert table_v3.metadata.encryption_keys == []
add_metadata = update_table_metadata(table_v3.metadata, (update,))
assert len(add_metadata.encryption_keys) == 1

assert json.loads(add_metadata.encryption_keys[0].model_dump_json()) == json.loads(expected)


def test_remove_encryption_key(table_v3: Table) -> None:
update_add = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test", encrypted_key_metadata=base64.b64encode(b"hello")))
add_metadata = update_table_metadata(table_v3.metadata, (update_add,))
assert len(add_metadata.encryption_keys) == 1

update_remove = RemoveEncryptedKeyUpdate(key_id="test")
remove_metadata = update_table_metadata(add_metadata, (update_remove,))
assert len(remove_metadata.encryption_keys) == 0


def test_remove_non_existent_encryption_key(table_v3: Table) -> None:
update_add = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test", encrypted_key_metadata=base64.b64encode(b"hello")))
add_metadata = update_table_metadata(table_v3.metadata, (update_add,))
assert len(add_metadata.encryption_keys) == 1

update_remove = RemoveEncryptedKeyUpdate(key_id="non_existent_key")
with pytest.raises(ValueError, match=r"Encryption key non_existent_key not found"):
update_table_metadata(add_metadata, (update_remove,))


def test_add_remove_encryption_key_v2_table(table_v2: Table) -> None:
update_add = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test_v2", encrypted_key_metadata=base64.b64encode(b"hello_v2")))
with pytest.raises(ValueError, match=r"Cannot add encryption keys to Iceberg v1 or v2 table"):
update_table_metadata(table_v2.metadata, (update_add,))
4 changes: 2 additions & 2 deletions tests/table/test_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,15 @@ def test_deserialize_snapshot_with_properties(snapshot_with_properties: Snapshot
def test_snapshot_repr(snapshot: Snapshot) -> None:
assert (
repr(snapshot)
== """Snapshot(snapshot_id=25, parent_snapshot_id=19, sequence_number=200, timestamp_ms=1602638573590, manifest_list='s3:/a/b/c.avro', summary=Summary(Operation.APPEND), schema_id=3)"""
== """Snapshot(snapshot_id=25, parent_snapshot_id=19, sequence_number=200, timestamp_ms=1602638573590, manifest_list='s3:/a/b/c.avro', summary=Summary(Operation.APPEND), schema_id=3, key_id=None)"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we want to omit key_id for < V3 tables instead of serialising it.

The spec says:

image

with blanks for format versions 1 and 2. #2146 (comment) then makes me think we shouldn't write this field for those versions. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#1973 makes it sound like we don't want to version these. What are your thoughts?

)
assert snapshot == eval(repr(snapshot))


def test_snapshot_with_properties_repr(snapshot_with_properties: Snapshot) -> None:
assert (
repr(snapshot_with_properties)
== """Snapshot(snapshot_id=25, parent_snapshot_id=19, sequence_number=200, timestamp_ms=1602638573590, manifest_list='s3:/a/b/c.avro', summary=Summary(Operation.APPEND, **{'foo': 'bar'}), schema_id=3)"""
== """Snapshot(snapshot_id=25, parent_snapshot_id=19, sequence_number=200, timestamp_ms=1602638573590, manifest_list='s3:/a/b/c.avro', summary=Summary(Operation.APPEND, **{'foo': 'bar'}), schema_id=3, key_id=None)"""
)
assert snapshot_with_properties == eval(repr(snapshot_with_properties))

Expand Down