From 3be1ed988b17b62801a2f1dddb48b76fd3065f9e Mon Sep 17 00:00:00 2001 From: Ian Lee Date: Wed, 11 Feb 2026 17:09:08 -0800 Subject: [PATCH] Update CloudTrail validate-logs for full key query range --- .../customizations/cloudtrail/validation.py | 12 +++- .../cloudtrail/test_validation.py | 64 +++++++++++-------- 2 files changed, 47 insertions(+), 29 deletions(-) diff --git a/awscli/customizations/cloudtrail/validation.py b/awscli/customizations/cloudtrail/validation.py index 50ca8a7d21db..468f1ca8fbd7 100644 --- a/awscli/customizations/cloudtrail/validation.py +++ b/awscli/customizations/cloudtrail/validation.py @@ -322,7 +322,9 @@ def load_all_digest_keys_in_range( s3_digest_files_prefix = self._create_digest_prefix(start_date, prefix) client = self._client_provider.get_client(bucket) paginator = client.get_paginator('list_objects') - page_iterator = paginator.paginate(Bucket=bucket, Marker=marker, Prefix=s3_digest_files_prefix) + page_iterator = paginator.paginate( + Bucket=bucket, Marker=marker, Prefix=s3_digest_files_prefix + ) key_filter = page_iterator.search('Contents[*].Key') # Create a target start end end date target_start_date = format_date(normalize_date(start_date)) @@ -466,7 +468,7 @@ def _create_digest_prefix(self, start_date, key_prefix): template = 'AWSLogs/' template_params = { 'account_id': self.account_id, - 'source_region': self.trail_source_region + 'source_region': self.trail_source_region, } if self.organization_id: template += '{organization_id}/' @@ -580,7 +582,11 @@ def traverse_digests(self, start_date, end_date=None, is_backfill=False): # For regular digests, pre-load public keys. For backfill, start with empty dict public_keys = ( - {} if is_backfill else self._load_public_keys(start_date, end_date) + {} + if is_backfill + else self._load_public_keys( + start_date, end_date + timedelta(hours=2) + ) ) yield from self._traverse_digest_chain( diff --git a/tests/unit/customizations/cloudtrail/test_validation.py b/tests/unit/customizations/cloudtrail/test_validation.py index 60f79bc12987..99aed598f47e 100644 --- a/tests/unit/customizations/cloudtrail/test_validation.py +++ b/tests/unit/customizations/cloudtrail/test_validation.py @@ -690,16 +690,18 @@ def test_calls_list_objects_correctly(self): mock_search = mock_paginate.return_value.search mock_search.return_value = [] provider = self._get_mock_provider(s3_client) - provider.load_digest_keys_in_range( - '1', 'prefix', START_DATE, END_DATE) - marker = ('prefix/AWSLogs/{account}/CloudTrail-Digest/us-east-1/' - '2014/08/09/{account}_CloudTrail-Digest_us-east-1_foo_' - 'us-east-1_20140809T235900Z.json.gz') + provider.load_digest_keys_in_range('1', 'prefix', START_DATE, END_DATE) + marker = ( + 'prefix/AWSLogs/{account}/CloudTrail-Digest/us-east-1/' + '2014/08/09/{account}_CloudTrail-Digest_us-east-1_foo_' + 'us-east-1_20140809T235900Z.json.gz' + ) prefix = 'prefix/AWSLogs/{account}/CloudTrail-Digest/us-east-1' mock_paginate.assert_called_once_with( Bucket='1', Marker=marker.format(account=TEST_ACCOUNT_ID), - Prefix=prefix.format(account=TEST_ACCOUNT_ID)) + Prefix=prefix.format(account=TEST_ACCOUNT_ID), + ) def test_calls_list_objects_correctly_org_trails(self): s3_client = mock.Mock() @@ -731,52 +733,60 @@ def test_calls_list_objects_correctly_org_trails(self): Bucket='1', Marker=marker.format( member_account=TEST_ORGANIZATION_ACCOUNT_ID, - organization_id=TEST_ORGANIZATION_ID + organization_id=TEST_ORGANIZATION_ID, ), Prefix=prefix.format( member_account=TEST_ORGANIZATION_ACCOUNT_ID, - organization_id=TEST_ORGANIZATION_ID - ) + organization_id=TEST_ORGANIZATION_ID, + ), ) def test_create_digest_prefix_without_key_prefix(self): mock_s3_client_provider = mock.Mock() provider = DigestProvider( - mock_s3_client_provider, TEST_ACCOUNT_ID, 'foo', 'us-east-1') + mock_s3_client_provider, TEST_ACCOUNT_ID, 'foo', 'us-east-1' + ) prefix = provider._create_digest_prefix(START_DATE, None) - expected = 'AWSLogs/{account}/CloudTrail-Digest/us-east-1'.format( - account=TEST_ACCOUNT_ID) + expected = f'AWSLogs/{TEST_ACCOUNT_ID}/CloudTrail-Digest/us-east-1' self.assertEqual(expected, prefix) def test_create_digest_prefix_with_key_prefix(self): mock_s3_client_provider = mock.Mock() provider = DigestProvider( - mock_s3_client_provider, TEST_ACCOUNT_ID, 'foo', 'us-east-1') + mock_s3_client_provider, TEST_ACCOUNT_ID, 'foo', 'us-east-1' + ) prefix = provider._create_digest_prefix(START_DATE, 'my-prefix') - expected = 'my-prefix/AWSLogs/{account}/CloudTrail-Digest/us-east-1'.format( - account=TEST_ACCOUNT_ID) + expected = ( + f'my-prefix/AWSLogs/{TEST_ACCOUNT_ID}/CloudTrail-Digest/us-east-1' + ) self.assertEqual(expected, prefix) def test_create_digest_prefix_org_trail(self): mock_s3_client_provider = mock.Mock() provider = DigestProvider( - mock_s3_client_provider, TEST_ORGANIZATION_ACCOUNT_ID, - 'foo', 'us-east-1', 'us-east-1', TEST_ORGANIZATION_ID) + mock_s3_client_provider, + TEST_ORGANIZATION_ACCOUNT_ID, + 'foo', + 'us-east-1', + 'us-east-1', + TEST_ORGANIZATION_ID, + ) prefix = provider._create_digest_prefix(START_DATE, None) - expected = 'AWSLogs/{org}/{account}/CloudTrail-Digest/us-east-1'.format( - org=TEST_ORGANIZATION_ID, - account=TEST_ORGANIZATION_ACCOUNT_ID) + expected = f'AWSLogs/{TEST_ORGANIZATION_ID}/{TEST_ORGANIZATION_ACCOUNT_ID}/CloudTrail-Digest/us-east-1' self.assertEqual(expected, prefix) def test_create_digest_prefix_org_trail_with_key_prefix(self): mock_s3_client_provider = mock.Mock() provider = DigestProvider( - mock_s3_client_provider, TEST_ORGANIZATION_ACCOUNT_ID, - 'foo', 'us-east-1', 'us-east-1', TEST_ORGANIZATION_ID) + mock_s3_client_provider, + TEST_ORGANIZATION_ACCOUNT_ID, + 'foo', + 'us-east-1', + 'us-east-1', + TEST_ORGANIZATION_ID, + ) prefix = provider._create_digest_prefix(START_DATE, 'custom-prefix') - expected = 'custom-prefix/AWSLogs/{org}/{account}/CloudTrail-Digest/us-east-1'.format( - org=TEST_ORGANIZATION_ID, - account=TEST_ORGANIZATION_ACCOUNT_ID) + expected = f'custom-prefix/AWSLogs/{TEST_ORGANIZATION_ID}/{TEST_ORGANIZATION_ACCOUNT_ID}/CloudTrail-Digest/us-east-1' self.assertEqual(expected, prefix) def test_ensures_digest_has_proper_metadata(self): @@ -996,7 +1006,9 @@ def test_ensures_public_keys_are_loaded(self): digest_iter = traverser.traverse_digests(start_date, end_date) with self.assertRaises(RuntimeError): next(digest_iter) - key_provider.get_public_keys.assert_called_with(start_date, end_date) + key_provider.get_public_keys.assert_called_with( + start_date, end_date + timedelta(hours=2) + ) def test_ensures_public_key_is_found(self): start_date = START_DATE