2 changes: 1 addition & 1 deletion .tool-versions
@@ -1,2 +1,2 @@
-poetry 1.5.1
+poetry 1.8.5
 python 3.10.12 3.8.12 3.9.12 3.11.5
16 changes: 8 additions & 8 deletions nhs_aws_helpers/__init__.py
@@ -60,7 +60,7 @@
 from mypy_boto3_lambda.client import LambdaClient
 from mypy_boto3_logs.client import CloudWatchLogsClient
 from mypy_boto3_s3.client import S3Client
-from mypy_boto3_s3.service_resource import Bucket, Object, S3ServiceResource
+from mypy_boto3_s3.service_resource import Bucket, Object, ObjectSummary, ObjectVersion, S3ServiceResource
 from mypy_boto3_s3.type_defs import (
     CompletedPartTypeDef,
     DeleteMarkerEntryTypeDef,
@@ -203,7 +203,7 @@ def register_retry_handler(

     retry_quota = RetryQuotaChecker(quota.RetryQuota())

-    max_attempts = client.meta.config.retries.get("total_max_attempts", DEFAULT_MAX_ATTEMPTS)
+    max_attempts = client.meta.config.retries.get("total_max_attempts", DEFAULT_MAX_ATTEMPTS)  # type: ignore[attr-defined]

     service_id = client.meta.service_model.service_id
     service_event_name = service_id.hyphenize()
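
The new suppression targets mypy rather than runtime behaviour: botocore's Config assembles its attributes dynamically, so botocore-stubs evidently doesn't declare `retries` even though the attribute is a plain dict at runtime, hence the `attr-defined` code. A quick illustration of the runtime shape (standard boto3, illustrative values):

    import boto3
    from botocore.config import Config

    # Config stores whatever retry options you pass; the stubs just don't declare the attribute.
    client = boto3.client(
        "s3",
        region_name="eu-west-2",
        config=Config(retries={"total_max_attempts": 5, "mode": "standard"}),
    )
    print(client.meta.config.retries)  # {'total_max_attempts': 5, 'mode': 'standard'}
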
@@ -228,12 +228,12 @@ def register_retry_handler(
     # Re-register with our own handler
     client.meta.events.register(
         f"needs-retry.{service_event_name}",
-        handler.needs_retry,
+        handler.needs_retry,  # type: ignore[arg-type]
         unique_id=unique_id,
     )

     def on_response_received(**kwargs):
-        if on_error is not None and kwargs.get("exception") or kwargs.get("parsed_response", {}).get("Error"):
+        if (on_error is not None and kwargs.get("exception")) or kwargs.get("parsed_response", {}).get("Error"):
             assert on_error
             on_error(**kwargs)
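
Since `and` binds tighter than `or`, the added parentheses don't change what the condition evaluates to; they make the existing grouping explicit for readers (and linters). A tiny check of the equivalence:

    # `X and Y or Z` already parses as `(X and Y) or Z` in Python.
    on_error, exception, parsed_error = None, None, {"Code": "Throttling"}

    implicit = on_error is not None and exception or parsed_error
    explicit = (on_error is not None and exception) or parsed_error
    assert implicit == explicit == {"Code": "Throttling"}
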

@@ -413,7 +413,7 @@ def s3_get_all_keys(
     page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix)
     keys = []
     for page in page_iterator:
-        keys.extend([content["Key"] for content in page["Contents"]])
+        keys.extend([content["Key"] for content in page.get("Contents", [])])

     return keys
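
`page["Contents"]` raises KeyError on an empty result, because S3 omits the `Contents` element entirely from a ListObjectsV2 page that matches nothing; `page.get("Contents", [])` makes an empty prefix yield an empty list instead (the new `test_s3_get_all_keys` case with no keys exercises exactly this). A minimal illustration:

    # S3 omits "Contents" from a ListObjectsV2 page when nothing matches the prefix.
    empty_page = {"KeyCount": 0, "IsTruncated": False}  # illustrative page shape

    # old: empty_page["Contents"] -> KeyError
    keys = [obj["Key"] for obj in empty_page.get("Contents", [])]
    assert keys == []
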

@@ -524,7 +524,7 @@ def s3_ls(
     versioning: bool = False,
     session: Optional[Session] = None,
     config: Optional[Config] = None,
-) -> Generator[Object, None, None]:
+) -> Generator[Union[ObjectSummary, ObjectVersion], None, None]:
     _, bucket, path = s3_split_path(uri)

     yield from s3_list_bucket(
@@ -540,7 +540,7 @@ def s3_list_bucket(
     versioning: bool = False,
     session: Optional[Session] = None,
     config: Optional[Config] = None,
-) -> Generator[Object, None, None]:
+) -> Generator[Union[ObjectSummary, ObjectVersion], None, None]:
     """list contents of S3 bucket based on filter criteria and versioning flag

     Args:
@@ -553,7 +553,7 @@ def s3_list_bucket(
         config (Config): optional botocore config

     Returns:
-        Generator[object, None, None]: resulting objects or versions
+        Generator[object, None, None]: resulting objects (ObjectSummary) or versions (ObjectVersion)
     """
     buck = s3_bucket(bucket, session=session, config=config)
     bc_objects = buck.object_versions if versioning else buck.objects
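The widened return annotation matches what the underlying collections actually yield: `Bucket.objects` iterates ObjectSummary items, while `Bucket.object_versions` iterates ObjectVersion items (which add a version `id`). A caller-side sketch, with an illustrative bucket and prefix:

    from typing import List, cast

    from mypy_boto3_s3.service_resource import ObjectSummary, ObjectVersion

    from nhs_aws_helpers import s3_build_uri, s3_ls

    # versioning=False (default): ObjectSummary items
    summaries = cast(List[ObjectSummary], list(s3_ls(s3_build_uri("my-bucket", "some/prefix"))))
    print([obj.key for obj in summaries])

    # versioning=True: ObjectVersion items, one per stored version of each key
    versions = cast(List[ObjectVersion], list(s3_ls(s3_build_uri("my-bucket", "some/prefix"), versioning=True)))
    print([(v.key, v.id) for v in versions])
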
2 changes: 1 addition & 1 deletion nhs_aws_helpers/dynamodb_model_store/base_model_store.py
@@ -133,7 +133,7 @@ def table_key_fields(self) -> List[str]:

     @classmethod
     def _deserialise_field(cls, field: dataclasses.Field, value: Any, **kwargs) -> Any:
-        return cls.deserialise_value(field.type, value, **kwargs)
+        return cls.deserialise_value(cast(type, field.type), value, **kwargs)

     @classmethod
     def deserialise_value(cls, value_type: type, value: Any, **kwargs) -> Any:  # noqa: C901
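The `cast` addresses a typing change rather than a runtime one: recent typeshed stubs declare `dataclasses.Field.type` as possibly a string, because under PEP 563 (`from __future__ import annotations`) the attribute really holds the annotation's text. A small runtime illustration:

    from __future__ import annotations  # postponed evaluation: annotations kept as strings

    import dataclasses


    @dataclasses.dataclass
    class Record:
        pk: str


    field = dataclasses.fields(Record)[0]
    print(repr(field.type))  # 'str' -- the string "str", not the built-in type
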
42 changes: 35 additions & 7 deletions nhs_aws_helpers/fixtures.py
@@ -12,18 +12,21 @@
     ddb_table,
     dynamodb,
     events_client,
+    s3_client,
+    s3_delete_all_versions,
     s3_resource,
     sqs_resource,
 )

 __all__ = [
-    "temp_s3_bucket_session_fixture",
-    "temp_s3_bucket_fixture",
-    "temp_event_bus_fixture",
-    "temp_queue_fixture",
-    "temp_fifo_queue_fixture",
     "clone_schema",
     "temp_dynamodb_table",
+    "temp_event_bus_fixture",
+    "temp_fifo_queue_fixture",
+    "temp_queue_fixture",
+    "temp_s3_bucket_fixture",
+    "temp_s3_bucket_session_fixture",
+    "temp_versioned_s3_bucket_fixture",
 ]


@@ -38,11 +41,13 @@ def temp_s3_bucket_session_fixture() -> Generator[Bucket, None, None]:

     bucket_name = f"temp-{petname.generate()}"
     bucket = resource.create_bucket(
-        Bucket=bucket_name, CreateBucketConfiguration=CreateBucketConfigurationTypeDef(LocationConstraint="eu-west-2")
+        Bucket=bucket_name,
+        CreateBucketConfiguration=CreateBucketConfigurationTypeDef(LocationConstraint="eu-west-2"),
     )
     yield bucket

-    bucket.objects.all().delete()
+    s3_delete_all_versions(bucket.name, "", dry_run=False)

     bucket.delete()
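
The teardown change is needed because the session bucket can now have versioning enabled (see the `temp_versioned_s3_bucket` fixture below), and the fixture never switches it back off: on a versioned bucket, `bucket.objects.all().delete()` would only stack delete markers, and S3 refuses to delete a bucket that still holds versions or markers. A rough sketch of the kind of sweep this implies -- illustrative only, not the helper's actual implementation:

    import boto3


    def purge_all_versions(bucket_name: str) -> None:
        """Remove every object version and delete marker so the bucket itself can be deleted."""
        bucket = boto3.resource("s3").Bucket(bucket_name)
        # The object_versions collection covers both concrete versions and delete markers.
        bucket.object_versions.delete()
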


@@ -61,6 +66,29 @@ def temp_s3_bucket_fixture(session_temp_s3_bucket: Bucket) -> Bucket:
     return bucket


+@pytest.fixture(name="temp_versioned_s3_bucket")
+def temp_versioned_s3_bucket_fixture(session_temp_s3_bucket: Bucket) -> Bucket:
+    """
+    yields a temporary s3 bucket for use in unit tests
+
+    Returns:
+        Bucket: a temporary empty s3 bucket
+    """
+    bucket = session_temp_s3_bucket
+
+    s3_client().put_bucket_versioning(
+        Bucket=bucket.name,
+        VersioningConfiguration={
+            "MFADelete": "Disabled",
+            "Status": "Enabled",
+        },
+    )
+
+    bucket.objects.all().delete()
+
+    return bucket
+
+
 @pytest.fixture(name="temp_event_bus")
 def temp_event_bus_fixture() -> Generator[Tuple[Queue, str], None, None]:
     """
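Any test that takes `temp_versioned_s3_bucket` gets the shared bucket, emptied and with versioning switched on, so repeated puts to one key accumulate distinct versions. A minimal usage sketch (hypothetical test, mirroring the ones added in tests/aws_tests.py):

    from mypy_boto3_s3.service_resource import Bucket


    def test_two_puts_two_versions(temp_versioned_s3_bucket: Bucket):
        # On a versioned bucket, overwriting a key adds a new version rather than replacing it.
        temp_versioned_s3_bucket.put_object(Key="k.txt", Body=b"v1")
        temp_versioned_s3_bucket.put_object(Key="k.txt", Body=b"v2")

        versions = list(temp_versioned_s3_bucket.object_versions.filter(Prefix="k.txt"))
        assert len(versions) == 2
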
471 changes: 264 additions & 207 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions pyproject.toml
@@ -14,9 +14,9 @@ repository = "https://github.com/NHSDigital/nhs-aws-helpers"
 [tool.poetry.dependencies]
 # core dependencies
 python = ">=3.8,<4.0"
-boto3 = "^1.35.29"
-boto3-stubs = {extras = ["s3", "ssm", "secretsmanager", "dynamodb", "stepfunctions", "sqs", "lambda", "logs", "ses", "sns", "events", "kms", "firehose", "athena"], version = "^1.35.29"}
-botocore-stubs = "^1.35.30"
+boto3 = "^1.35.93"
+boto3-stubs = {extras = ["s3", "ssm", "secretsmanager", "dynamodb", "stepfunctions", "sqs", "lambda", "logs", "ses", "sns", "events", "kms", "firehose", "athena"], version = "^1.35.93"}
+botocore-stubs = "^1.35.93"


 [tool.setuptools.package-data]
@@ -81,7 +81,6 @@ lint.select = [
 ]
 src = ["."]
 lint.ignore = [
-    "PT004"
 ]
 exclude = [
     ".git",
131 changes: 129 additions & 2 deletions tests/aws_tests.py
@@ -1,21 +1,25 @@
 import logging
 import os
-from typing import Any, List
+from typing import Any, List, cast
 from uuid import uuid4

 import pytest
 from botocore.config import Config
 from botocore.exceptions import ClientError
-from mypy_boto3_s3.service_resource import Bucket
+from mypy_boto3_s3.service_resource import Bucket, ObjectSummary, ObjectVersion
 from pytest_httpserver import HTTPServer

 from nhs_aws_helpers import (
     dynamodb_retry_backoff,
     post_create_client,
     register_config_default,
     register_retry_handler,
+    s3_build_uri,
     s3_client,
+    s3_delete_keys,
+    s3_get_all_keys,
     s3_list_folders,
+    s3_ls,
     s3_resource,
     s3_upload_multipart_from_copy,
     transaction_cancellation_reasons,
@@ -59,6 +63,129 @@ def fail_me():
     assert len(calls) == 3


+def test_s3_delete_keys_all(temp_s3_bucket: Bucket):
+
+    keys = [str(key) for key in range(1003)]
+
+    for key in keys:
+        temp_s3_bucket.put_object(Key=key, Body=f"Some data {uuid4().hex}".encode())
+
+    deleted_keys = s3_delete_keys(keys, temp_s3_bucket.name)
+
+    remaining_keys = list(s3_get_all_keys(temp_s3_bucket.name, ""))
+
+    assert len(deleted_keys) == len(keys)
+    assert set(deleted_keys) == set(keys)
+    assert len(remaining_keys) == 0
+
+
+def test_s3_delete_keys_partial(temp_s3_bucket: Bucket):
+
+    keys = [str(key) for key in range(1003)]
+
+    for key in keys:
+        temp_s3_bucket.put_object(Key=key, Body=f"Some data {uuid4().hex}".encode())
+
+    leaving_keys = [keys.pop(i) for i in (1002, 456, 3)]
+
+    deleted_keys = s3_delete_keys(keys, temp_s3_bucket.name)
+
+    remaining_keys = list(s3_get_all_keys(temp_s3_bucket.name, ""))
+
+    assert len(deleted_keys) == len(keys)
+    assert set(deleted_keys) == set(keys)
+    assert len(remaining_keys) == len(leaving_keys)
+    assert set(remaining_keys) == set(leaving_keys)
+
+
+@pytest.mark.parametrize(
+    "keys",
+    [(), ("a",), ("a/b",), ("a/b", "a/c")],
+)
+def test_s3_get_all_keys(temp_s3_bucket: Bucket, keys: List[str]):
+    for key in keys:
+        temp_s3_bucket.put_object(Key=key, Body=f"Some data {uuid4().hex}".encode())
+
+    result_keys = list(s3_get_all_keys(temp_s3_bucket.name, ""))
+
+    assert len(result_keys) == len(keys)
+    assert set(result_keys) == set(keys)
+
+
+@pytest.mark.parametrize(
+    "keys",
+    [(), ("a",), ("a/b",), ("a/b", "a/c")],
+)
+def test_s3_get_all_keys_under_prefix(temp_s3_bucket: Bucket, keys: List[str]):
+    expected_folder = uuid4().hex
+
+    # Keys that shouldn't be included
+    for key in ("x", "y/z"):
+        temp_s3_bucket.put_object(Key=key, Body=f"Some data {uuid4().hex}".encode())
+
+    for key in keys:
+        temp_s3_bucket.put_object(Key=f"{expected_folder}/{key}", Body=f"Some data {uuid4().hex}".encode())
+
+    result_keys = list(s3_get_all_keys(temp_s3_bucket.name, expected_folder))
+
+    assert len(result_keys) == len(keys)
+    assert set(result_keys) == {f"{expected_folder}/{key}" for key in keys}
+
+
+def test_s3_ls(temp_s3_bucket: Bucket):
+    expected_folder = uuid4().hex
+    temp_s3_bucket.put_object(Key=f"{expected_folder}/1/a.txt", Body=f"Some data {uuid4().hex}".encode())
+    temp_s3_bucket.put_object(Key=f"{expected_folder}/1/b.txt", Body=f"Some data {uuid4().hex}".encode())
+    temp_s3_bucket.put_object(Key=f"{expected_folder}/2/c.txt", Body=f"Some data {uuid4().hex}".encode())
+
+    files = cast(List[ObjectSummary], list(s3_ls(s3_build_uri(temp_s3_bucket.name, expected_folder))))
+
+    assert len(files) == 3
+    assert {f.key for f in files} == {
+        f"{expected_folder}/1/a.txt",
+        f"{expected_folder}/1/b.txt",
+        f"{expected_folder}/2/c.txt",
+    }
+
+
+def test_s3_ls_versioning_on_non_versioned_bucket(temp_s3_bucket: Bucket):
+    expected_folder = uuid4().hex
+    temp_s3_bucket.put_object(Key=f"{expected_folder}/1/a.txt", Body=f"Some data {uuid4().hex}".encode())
+    temp_s3_bucket.put_object(Key=f"{expected_folder}/1/b.txt", Body=f"Some data {uuid4().hex}".encode())
+    temp_s3_bucket.put_object(Key=f"{expected_folder}/2/c.txt", Body=f"Some data {uuid4().hex}".encode())
+    temp_s3_bucket.put_object(Key=f"{expected_folder}/2/c.txt", Body=f"Some new data {uuid4().hex}".encode())
+
+    files = cast(List[ObjectVersion], list(s3_ls(s3_build_uri(temp_s3_bucket.name, expected_folder), versioning=True)))
+
+    assert len(files) == 3
+    assert {f.key for f in files} == {
+        f"{expected_folder}/1/a.txt",
+        f"{expected_folder}/1/b.txt",
+        f"{expected_folder}/2/c.txt",
+    }
+    assert len({f.id for f in files}) == 1
+
+
+def test_s3_ls_versioning(temp_versioned_s3_bucket: Bucket):
+    expected_folder = uuid4().hex
+    temp_versioned_s3_bucket.put_object(Key=f"{expected_folder}/1/a.txt", Body=f"Some data {uuid4().hex}".encode())
+    temp_versioned_s3_bucket.put_object(Key=f"{expected_folder}/1/b.txt", Body=f"Some data {uuid4().hex}".encode())
+    temp_versioned_s3_bucket.put_object(Key=f"{expected_folder}/2/c.txt", Body=f"Some data {uuid4().hex}".encode())
+    temp_versioned_s3_bucket.put_object(Key=f"{expected_folder}/2/c.txt", Body=f"Some new data {uuid4().hex}".encode())
+
+    files = cast(
+        List[ObjectVersion], list(s3_ls(s3_build_uri(temp_versioned_s3_bucket.name, expected_folder), versioning=True))
+    )
+
+    assert len(files) == 4
+    assert {f.key for f in files} == {
+        f"{expected_folder}/1/a.txt",
+        f"{expected_folder}/1/b.txt",
+        f"{expected_folder}/2/c.txt",
+    }
+    assert len({f.id for f in files}) == 4
+
+
 def test_s3_list_folders_root(temp_s3_bucket: Bucket):
     expected_folder = uuid4().hex

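One detail the new delete tests pin down: 1003 keys deliberately overshoots the 1,000-key ceiling of S3's DeleteObjects API, so `s3_delete_keys` presumably has to batch its requests. A sketch of the chunking pattern such a helper needs (illustrative, not the library's code):

    from itertools import islice
    from typing import Iterable, Iterator, List


    def in_batches(keys: Iterable[str], size: int = 1000) -> Iterator[List[str]]:
        """Yield lists of at most `size` keys -- DeleteObjects accepts 1000 per call."""
        iterator = iter(keys)
        while batch := list(islice(iterator, size)):
            yield batch


    # 1003 keys -> one full batch of 1000 plus a remainder of 3, as the tests exercise
    assert [len(b) for b in in_batches(str(i) for i in range(1003))] == [1000, 3]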