185 changes: 147 additions & 38 deletions lambdas/services/base/s3_service.py
@@ -2,6 +2,7 @@
from datetime import datetime, timedelta, timezone
from io import BytesIO
from typing import Any, Mapping
from urllib import parse
from xml.sax.saxutils import escape

import boto3
from botocore.client import Config as BotoConfig
@@ -18,6 +19,9 @@ class S3Service:
EXPIRED_SESSION_WARNING = "Expired session, creating a new role session"
S3_PREFIX = "s3://"

DEFAULT_AUTODELETE_TAG_KEY = "autodelete"
DEFAULT_AUTODELETE_TAG_VALUE = "true"

def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
@@ -43,51 +47,120 @@ def __init__(self, custom_aws_role=None):
self.custom_aws_role, "s3", config=self.config
)

def _refresh_custom_session_if_needed(self) -> None:
if not self.custom_client:
return
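# Re-assume the role when the temporary credentials are within 10 minutes of expiry.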
if datetime.now(timezone.utc) > self.expiration_time - timedelta(minutes=10):
logger.info(S3Service.EXPIRED_SESSION_WARNING)
self.custom_client, self.expiration_time = self.iam_service.assume_role(
self.custom_aws_role, "s3", config=self.config
)

@staticmethod
def build_tagging_query(tags: Mapping[str, str] | None) -> str:
"""
The PutObject "Tagging" parameter (the x-amz-tagging header) expects a URL-encoded querystring, e.g. "autodelete=true&foo=bar".
"""
if not tags:
return ""
return parse.urlencode(dict(tags))

@staticmethod
def ensure_autodelete_tag(
tags: Mapping[str, str] | None,
tag_key: str = DEFAULT_AUTODELETE_TAG_KEY,
tag_value: str = DEFAULT_AUTODELETE_TAG_VALUE,
) -> dict[str, str]:
out = dict(tags or {})
out.setdefault(tag_key, tag_value)
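# setdefault preserves any caller-supplied value, so callers can still override
# the default, e.g. by passing {"autodelete": "false"}.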
return out


# The S3 location can be just an s3_object_key, or include a directory prefix in
# the form {{directory}}/{{s3_object_key}}.
def create_upload_presigned_url(self, s3_bucket_name: str, s3_object_location: str):
"""
Backwards-compatible wrapper for presigned POST without enforced tags.
"""
if self.custom_client:
if datetime.now(timezone.utc) > self.expiration_time - timedelta(
minutes=10
):
logger.info(S3Service.EXPIRED_SESSION_WARNING)
self.custom_client, self.expiration_time = self.iam_service.assume_role(
self.custom_aws_role, "s3", config=self.config
)
self._refresh_custom_session_if_needed()
return self.custom_client.generate_presigned_post(
s3_bucket_name,
s3_object_location,
Fields=None,
Conditions=None,
ExpiresIn=self.presigned_url_expiry,
)

def create_put_presigned_url(self, s3_bucket_name: str, file_key: str):
if self.custom_client:
if datetime.now(timezone.utc) > self.expiration_time - timedelta(
minutes=10
):
logger.info(S3Service.EXPIRED_SESSION_WARNING)
self.custom_client, self.expiration_time = self.iam_service.assume_role(
self.custom_aws_role, "s3", config=self.config
)
logger.info("Generating presigned URL")
return self.custom_client.generate_presigned_url(
"put_object",
Params={"Bucket": s3_bucket_name, "Key": file_key},
ExpiresIn=self.presigned_url_expiry,
)
return None

def create_upload_presigned_post(
self,
s3_bucket_name: str,
s3_object_location: str,
tags: Mapping[str, str] | None = None,
require_autodelete: bool = False,
):
if not self.custom_client:
return None

self._refresh_custom_session_if_needed()

final_tags = (
self.ensure_autodelete_tag(tags) if require_autodelete else dict(tags or {})
)

fields: dict[str, Any] = {}
conditions: list[Any] = []

if final_tags:
# The POST policy "tagging" form field expects an XML <Tagging> document,
# not the URL-encoded form used for PUT's x-amz-tagging header.
tag_set = "".join(
f"<Tag><Key>{escape(k)}</Key><Value>{escape(v)}</Value></Tag>"
for k, v in final_tags.items()
)
tagging_xml = f"<Tagging><TagSet>{tag_set}</TagSet></Tagging>"
fields["tagging"] = tagging_xml
conditions.append({"tagging": tagging_xml})
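# S3 rejects the POST if the form omits a field referenced in the policy
# conditions, so clients must echo the returned "tagging" field verbatim.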

return self.custom_client.generate_presigned_post(
s3_bucket_name,
s3_object_location,
Fields=fields or None,
Conditions=conditions or None,
ExpiresIn=self.presigned_url_expiry,
)

def create_put_presigned_url(
self,
s3_bucket_name: str,
file_key: str,
tags: Mapping[str, str] | None = None,
require_autodelete: bool = False,
extra_params: Mapping[str, Any] | None = None,
):
if not self.custom_client:
return None

self._refresh_custom_session_if_needed()

final_tags = (
self.ensure_autodelete_tag(tags) if require_autodelete else dict(tags or {})
)

params: dict[str, Any] = {"Bucket": s3_bucket_name, "Key": file_key}

if final_tags:
params["Tagging"] = self.build_tagging_query(final_tags)
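# boto3 signs Tagging as the x-amz-tagging header, so the client making the PUT
# must send that header with exactly this value or the signature check fails.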

if extra_params:
params.update(extra_params)

logger.info("Generating presigned URL")
return self.custom_client.generate_presigned_url(
"put_object",
Params=params,
ExpiresIn=self.presigned_url_expiry,
)

def create_download_presigned_url(self, s3_bucket_name: str, file_key: str):
if self.custom_client:
if datetime.now(timezone.utc) > self.expiration_time - timedelta(
minutes=10
):
logger.info(S3Service.EXPIRED_SESSION_WARNING)
self.custom_client, self.expiration_time = self.iam_service.assume_role(
self.custom_aws_role, "s3", config=self.config
)
self._refresh_custom_session_if_needed()
logger.info("Generating presigned URL")
return self.custom_client.generate_presigned_url(
"get_object",
@@ -143,11 +216,9 @@ def copy_across_bucket(
if_none_match,
False,
)
else:
raise e
else:
logger.error(f"Copy failed: {e}")
raise e
raise
logger.error(f"Copy failed: {e}")
raise

def delete_object(
self, s3_bucket_name: str, file_key: str, version_id: str | None = None
@@ -159,6 +230,34 @@ def delete_object(
Bucket=s3_bucket_name, Key=file_key, VersionId=version_id
)

def delete_object_hard(self, s3_bucket_name: str, file_key: str) -> None:
"""
Permanently deletes ALL versions and delete markers for the given key.
"""
try:
paginator = self.client.get_paginator("list_object_versions")
to_delete: list[dict[str, str]] = []

for page in paginator.paginate(Bucket=s3_bucket_name, Prefix=file_key):
for v in page.get("Versions", []):
if v.get("Key") == file_key:
to_delete.append({"Key": file_key, "VersionId": v["VersionId"]})
for m in page.get("DeleteMarkers", []):
if m.get("Key") == file_key:
to_delete.append({"Key": file_key, "VersionId": m["VersionId"]})

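# delete_objects accepts at most 1,000 keys per request, hence the chunking.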
for i in range(0, len(to_delete), 1000):
chunk = to_delete[i : i + 1000]
if chunk:
self.client.delete_objects(
Bucket=s3_bucket_name,
Delete={"Objects": chunk, "Quiet": True},
)
except ClientError as e:
logger.error(f"Hard delete failed for s3://{s3_bucket_name}/{file_key}: {e}")
raise


def create_object_tag(
self, s3_bucket_name: str, file_key: str, tag_key: str, tag_value: str
):
@@ -202,7 +301,7 @@ def file_exist_on_s3(self, s3_bucket_name: str, file_key: str) -> bool:
):
return False
logger.error(str(e), {"Result": "Failed to check if file exists on s3"})
raise e
raise

def list_all_objects(self, bucket_name: str) -> list[dict]:
s3_paginator = self.client.get_paginator("list_objects_v2")
@@ -236,20 +335,30 @@ def upload_file_obj(
s3_bucket_name: str,
file_key: str,
extra_args: Mapping[str, Any] | None = None,
require_autodelete: bool = False,
tags: Mapping[str, str] | None = None,
):
try:
final_extra_args: dict[str, Any] = dict(extra_args or {})

if require_autodelete:
final_tags = self.ensure_autodelete_tag(tags)
final_extra_args["Tagging"] = self.build_tagging_query(final_tags)
elif tags:
final_extra_args["Tagging"] = self.build_tagging_query(tags)
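# upload_fileobj accepts "Tagging" via ExtraArgs and applies it as the
# x-amz-tagging header, using the same URL-encoded form as put_object.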

self.client.upload_fileobj(
Fileobj=file_obj,
Bucket=s3_bucket_name,
Key=file_key,
ExtraArgs=extra_args or {},
ExtraArgs=final_extra_args,
)
logger.info(f"Uploaded file object to s3://{s3_bucket_name}/{file_key}")
except ClientError as e:
logger.error(
f"Failed to upload file object to s3://{s3_bucket_name}/{file_key} - {e}"
)
raise e
raise

def save_or_create_file(self, source_bucket: str, file_key: str, body: bytes):
return self.client.put_object(
18 changes: 9 additions & 9 deletions lambdas/services/fhir_document_reference_service_base.py
@@ -52,6 +52,7 @@ def _store_binary_in_s3(
file_obj=binary_file,
s3_bucket_name=document_reference.s3_bucket_name,
file_key=document_reference.s3_upload_key,
require_autodelete=True,
)
logger.info(
f"Successfully stored binary content in S3: {document_reference.s3_upload_key}"
@@ -77,7 +78,9 @@ def _create_s3_presigned_url(self, document_reference: DocumentReference) -> str
"""Create a pre-signed URL for uploading a file"""
try:
response = self.s3_service.create_put_presigned_url(
document_reference.s3_bucket_name, document_reference.s3_upload_key
s3_bucket_name=document_reference.s3_bucket_name,
file_key=document_reference.s3_upload_key,
require_autodelete=False,
)
logger.info(
f"Successfully created pre-signed URL for {document_reference.s3_upload_key}"
@@ -138,10 +141,9 @@ def _get_document_reference(self, document_id: str, table) -> DocumentReference:
if len(documents) > 0:
logger.info("Document found for given id")
return documents[0]
else:
raise FhirDocumentReferenceException(
f"Did not find any documents for document ID {document_id}"
)
raise FhirDocumentReferenceException(
f"Did not find any documents for document ID {document_id}"
)

def _determine_document_type(self, fhir_doc: FhirDocumentReference) -> SnomedCode:
"""Determine the document type based on SNOMED code in the FHIR document"""
@@ -190,13 +192,10 @@ def _create_fhir_response(
presigned_url: str,
) -> str:
"""Create a FHIR response document"""

if presigned_url:
attachment_url = presigned_url
else:
document_retrieve_endpoint = os.getenv(
"DOCUMENT_RETRIEVE_ENDPOINT_APIM", ""
)
document_retrieve_endpoint = os.getenv("DOCUMENT_RETRIEVE_ENDPOINT_APIM", "")
attachment_url = (
document_retrieve_endpoint
+ "/"
@@ -252,6 +251,7 @@ def _handle_document_save(
presigned_url = self._create_s3_presigned_url(document_reference)
except FhirDocumentReferenceException:
raise DocumentRefException(500, LambdaError.InternalServerError)

try:
# Save document reference to DynamoDB
self._save_document_reference_to_dynamo(dynamo_table, document_reference)
1 change: 1 addition & 0 deletions lambdas/services/post_document_review_service.py
@@ -154,6 +154,7 @@ def create_review_document_upload_presigned_url(
presign_url_response = self.s3_service.create_put_presigned_url(
s3_bucket_name=self.staging_bucket,
file_key=file_key,
require_autodelete=False,
)
presigned_id = f"upload/{upload_id}"
deletion_date = datetime.now(timezone.utc)