Commit 6596fa6

Author: Taniya Mathur (committed)
Add test set bucket pattern match, observability improvements, and caching for test results
1 parent ee3c034 commit 6596fa6

File tree: 12 files changed (+409 / -120 lines)


docs/test-studio.md

Lines changed: 20 additions & 4 deletions
@@ -34,8 +34,21 @@ The Test Studio consists of two main tabs:
 
 #### TestResultsResolver Lambda
 - **Location**: `src/lambda/test_results_resolver/index.py`
-- **Purpose**: Handles GraphQL queries for test results and comparisons
-- **Features**: Result retrieval, comparison logic, metrics aggregation
+- **Purpose**: Handles GraphQL queries for test results and comparisons, plus asynchronous cache updates
+- **Features**:
+  - Result retrieval with cached metrics
+  - Comparison logic and metrics aggregation
+  - Dual event handling (GraphQL + SQS)
+  - Asynchronous cache update processing
+  - Progress-aware status updates
+
+#### TestResultCacheUpdateQueue
+- **Type**: AWS SQS Queue
+- **Purpose**: Decouples heavy metric calculations from synchronous API calls
+- **Features**:
+  - Encrypted message storage
+  - 15-minute visibility timeout for long-running calculations
+  - Automatic retry handling
 
 ### GraphQL Schema
 - **Location**: `src/api/schema.graphql`
@@ -77,7 +90,9 @@ components/
 ## Test Sets
 
 ### Creating Test Sets
-1. **Pattern-based**: Define file patterns (e.g., `*.pdf`)
+1. **Pattern-based**: Define file patterns (e.g., `*.pdf`) with bucket type selection
+   - **Input Bucket**: Scan main processing bucket for matching files
+   - **Test Set Bucket**: Scan dedicated test set bucket for matching files
 2. **Zip Upload**: Upload zip containing `input/` and `baseline/` folders
 3. **Direct Upload**: Files uploaded directly to TestSetBucket are auto-detected
 
@@ -126,7 +141,8 @@ my-test-set/
 ## Key Features
 
 ### Test Set Management
-- Reusable collections with file patterns
+- Reusable collections with file patterns across multiple buckets
+- Dual bucket support (Input Bucket and Test Set Bucket)
 - Zip upload with automatic extraction
 - Direct upload detection via dual polling
 - File structure validation with error reporting

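Note: the encryption and 15-minute visibility timeout called out above map onto standard SQS queue attributes. The infrastructure template itself is not part of this commit, so the following boto3 snippet is only a sketch of an equivalently configured queue; the queue name and attribute values are assumptions.

```python
import boto3

sqs = boto3.client("sqs")

# Sketch of a queue configured as described in the docs above; the real
# template is not in this commit, so name and values are illustrative.
response = sqs.create_queue(
    QueueName="TestResultCacheUpdateQueue",
    Attributes={
        "VisibilityTimeout": "900",      # 15 minutes, covering long-running metric calculations
        "SqsManagedSseEnabled": "true",  # encrypted message storage
    },
)
print(response["QueueUrl"])
```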
lib/idp_common_pkg/tests/unit/test_test_set_resolver.py

Lines changed: 8 additions & 3 deletions
@@ -73,7 +73,12 @@ def test_add_test_set_structure(self, mock_boto3, mock_datetime, mock_uuid):
         mock_boto3.return_value = mock_sqs
 
         with patch.object(test_set_index.db_client, "put_item") as mock_put:
-            args = {"name": "test", "filePattern": "*.pdf", "fileCount": 5}
+            args = {
+                "name": "test",
+                "filePattern": "*.pdf",
+                "fileCount": 5,
+                "bucketType": "input",
+            }
             result = test_set_index.add_test_set(args)
 
             mock_put.assert_called_once()
@@ -126,8 +131,8 @@ def test_list_input_bucket_files(self):
         with patch.object(test_set_index, "find_matching_files") as mock_find:
             mock_find.return_value = ["file1.pdf", "file2.pdf"]
 
-            args = {"filePattern": "*.pdf"}
-            result = test_set_index.list_input_bucket_files(args)
+            args = {"filePattern": "*.pdf", "bucketType": "input"}
+            result = test_set_index.list_bucket_files(args)
 
             mock_find.assert_called_once_with("test-bucket", "*.pdf")
             assert result == ["file1.pdf", "file2.pdf"]

src/api/schema.graphql

Lines changed: 2 additions & 2 deletions
@@ -431,7 +431,7 @@ type Mutation {
     @aws_iam
   startTestRun(input: TestRunInput!): TestRun @aws_cognito_user_pools
   deleteTests(testRunIds: [String!]!): Boolean! @aws_cognito_user_pools
-  addTestSet(name: String!, filePattern: String!, fileCount: Int!): TestSet @aws_cognito_user_pools
+  addTestSet(name: String!, filePattern: String!, bucketType: String!, fileCount: Int!): TestSet @aws_cognito_user_pools
   addTestSetFromUpload(input: TestSetUploadInput!): TestSetUploadResponse @aws_cognito_user_pools
   deleteTestSets(testSetIds: [String!]!): Boolean! @aws_cognito_user_pools
 }
@@ -475,7 +475,7 @@ type Query @aws_cognito_user_pools @aws_iam {
   getTestRunStatus(testRunId: String!): TestRunStatus @aws_cognito_user_pools
   compareTestRuns(testRunIds: [String!]!): TestRunComparison @aws_cognito_user_pools
   getTestSets: [TestSet] @aws_cognito_user_pools
-  listInputBucketFiles(filePattern: String!): [String] @aws_cognito_user_pools
+  listBucketFiles(bucketType: String!, filePattern: String!): [String] @aws_cognito_user_pools
   validateTestFileName(fileName: String!): TestSetValidationResponse @aws_cognito_user_pools
 }

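Callers of the renamed query now select the bucket to scan explicitly. Since `listBucketFiles` returns a plain `[String]`, no selection set is required; a hypothetical client-side operation (the operation name and argument values below are illustrative, not from this commit) could look like:

```python
# Hypothetical GraphQL operation against the updated schema; operation
# name and argument values are illustrative assumptions.
LIST_BUCKET_FILES = """
query ListTestSetCandidates {
  listBucketFiles(bucketType: "testset", filePattern: "*.pdf")
}
"""
```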
src/lambda/test_results_resolver/index.py

Lines changed: 69 additions & 10 deletions
@@ -10,6 +10,8 @@
 
 import boto3
 
+sqs = boto3.client('sqs')
+
 
 # Custom JSON encoder to handle Decimal objects from DynamoDB
 class DecimalEncoder(json.JSONEncoder):
@@ -24,7 +26,13 @@ def default(self, obj):
 dynamodb = boto3.resource('dynamodb')
 
 def handler(event, context):
-    """GraphQL resolver for test results queries"""
+    """Handle both GraphQL resolver and SQS events"""
+
+    # Check if this is an SQS event
+    if 'Records' in event:
+        return handle_cache_update_request(event, context)
+
+    # Otherwise handle as GraphQL resolver
     field_name = event['info']['fieldName']
 
     if field_name == 'getTestRuns':
@@ -46,6 +54,52 @@ def handler(event, context):
 
     raise ValueError(f"Unknown field: {field_name}")
 
+def handle_cache_update_request(event, context):
+    """Process SQS messages to calculate and cache test result metrics"""
+
+    for record in event['Records']:
+        try:
+            message = json.loads(record['body'])
+            test_run_id = message['testRunId']
+
+            logger.info(f"Processing cache update for test run: {test_run_id}")
+
+            # Calculate metrics
+            aggregated_metrics = _aggregate_test_run_metrics(test_run_id)
+
+            # Cache the metrics
+            metrics_to_cache = {
+                'overallAccuracy': aggregated_metrics.get('overall_accuracy'),
+                'weightedOverallScores': aggregated_metrics.get('weighted_overall_scores', []),
+                'averageConfidence': aggregated_metrics.get('average_confidence'),
+                'accuracyBreakdown': aggregated_metrics.get('accuracy_breakdown', {}),
+                'totalCost': aggregated_metrics.get('total_cost', 0),
+                'costBreakdown': aggregated_metrics.get('cost_breakdown', {})
+            }
+
+            table = dynamodb.Table(os.environ['TRACKING_TABLE'])
+            table.update_item(
+                Key={'PK': f'testrun#{test_run_id}', 'SK': 'metadata'},
+                UpdateExpression='SET testRunResult = :metrics',
+                ExpressionAttributeValues={':metrics': float_to_decimal(metrics_to_cache)}
+            )
+
+            logger.info(f"Successfully cached metrics for test run: {test_run_id}")
+
+        except Exception as e:
+            logger.error(f"Failed to process cache update for {record.get('body', 'unknown')}: {e}")
+            # Don't raise - let other messages in batch continue processing
+
+def float_to_decimal(obj):
+    """Convert float values to Decimal for DynamoDB storage"""
+    if isinstance(obj, float):
+        return Decimal(str(obj))
+    elif isinstance(obj, dict):
+        return {k: float_to_decimal(v) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [float_to_decimal(v) for v in obj]
+    return obj
+
 def compare_test_runs(test_run_ids):
     """Compare multiple test runs"""
     logger.info(f"Comparing test runs: {test_run_ids}")
@@ -181,15 +235,6 @@ def get_test_results(test_run_id):
     try:
         logger.info(f"Caching metrics for test run: {test_run_id}")
 
-        def float_to_decimal(obj):
-            if isinstance(obj, float):
-                return Decimal(str(obj))
-            elif isinstance(obj, dict):
-                return {k: float_to_decimal(v) for k, v in obj.items()}
-            elif isinstance(obj, list):
-                return [float_to_decimal(v) for v in obj]
-            return obj
-
         # Cache only static metrics
         metrics_to_cache = {
             'overallAccuracy': aggregated_metrics.get('overall_accuracy'),
@@ -424,6 +469,20 @@ def get_test_run_status(test_run_id):
             }
         )
         logger.info(f"Successfully updated test run {test_run_id} status to {overall_status}")
+
+        # Queue metric calculation for completed test runs
+        if overall_status in ['COMPLETE', 'PARTIAL_COMPLETE'] and not item.get('testRunResult'):
+            try:
+                queue_url = os.environ.get('TEST_RESULT_CACHE_UPDATE_QUEUE_URL')
+                if queue_url:
+                    sqs.send_message(
+                        QueueUrl=queue_url,
+                        MessageBody=json.dumps({'testRunId': test_run_id})
+                    )
+                    logger.info(f"Queued cache update for test run: {test_run_id}")
+            except Exception as e:
+                logger.warning(f"Failed to queue cache update for {test_run_id}: {e}")
+
     except Exception as e:
         logger.error(f"Failed to auto-update test run {test_run_id} status: {e}")
 

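Because the handler now dispatches on the presence of a top-level `Records` key, one Lambda can sit behind both the AppSync resolver and the SQS event source mapping. A minimal sketch of the two event shapes it distinguishes (all payload values below are invented for illustration):

```python
import json

# Illustrative event shapes only; values are made up for this sketch.
sqs_event = {
    "Records": [
        {"body": json.dumps({"testRunId": "run-123"})},  # routed to handle_cache_update_request
    ]
}

graphql_event = {
    "info": {"fieldName": "getTestRuns"},  # routed through the GraphQL field dispatch
    "arguments": {},
}
```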
src/lambda/test_set_file_copier/index.py

Lines changed: 72 additions & 7 deletions
@@ -24,17 +24,26 @@ def handler(event, context):
 
     test_set_id = message['testSetId']
     file_pattern = message['filePattern']
+    bucket_type = message['bucketType']
     tracking_table = message['trackingTable']
 
     # Get environment variables
     input_bucket = os.environ['INPUT_BUCKET']
     test_set_bucket = os.environ['TEST_SET_BUCKET']
     baseline_bucket = os.environ['BASELINE_BUCKET']
 
-    logger.info(f"Processing test set {test_set_id} with pattern '{file_pattern}'")
+    # Determine source bucket based on bucket type
+    if bucket_type == 'input':
+        source_bucket = input_bucket
+    elif bucket_type == 'testset':
+        source_bucket = test_set_bucket
+    else:
+        raise ValueError(f"Invalid bucket type: {bucket_type}")
 
-    # Find matching files in input bucket
-    matching_files = find_matching_files(input_bucket, file_pattern)
+    logger.info(f"Processing test set {test_set_id} with pattern '{file_pattern}' from {bucket_type} bucket")
+
+    # Find matching files in source bucket
+    matching_files = find_matching_files(source_bucket, file_pattern)
 
     if not matching_files:
         raise ValueError(f"No files found matching pattern: {file_pattern}")
@@ -45,9 +54,25 @@ def handler(event, context):
     missing_baselines = []
     for file_key in matching_files:
         try:
+            if bucket_type == 'testset':
+                # For testset bucket, baseline is in the same bucket under baseline/ path
+                # Extract test set name from file path (assuming format: test_set_name/input/file)
+                path_parts = file_key.split('/')
+                if len(path_parts) >= 3 and path_parts[1] == 'input':
+                    test_set_name = path_parts[0]
+                    file_name = path_parts[2]
+                    baseline_prefix = f"{test_set_name}/baseline/{file_name}/"
+                    baseline_check_bucket = source_bucket
+                else:
+                    missing_baselines.append(file_key)
+                    continue
+            else:
+                # For input bucket, baseline is in separate baseline bucket
+                baseline_prefix = f"{file_key}/"
+                baseline_check_bucket = baseline_bucket
+
             # Check if baseline folder exists by listing objects with prefix
-            baseline_prefix = f"{file_key}/"
-            response = s3.list_objects_v2(Bucket=baseline_bucket, Prefix=baseline_prefix, MaxKeys=1)
+            response = s3.list_objects_v2(Bucket=baseline_check_bucket, Prefix=baseline_prefix, MaxKeys=1)
 
             if 'Contents' not in response or len(response['Contents']) == 0:
                 missing_baselines.append(file_key)
@@ -60,10 +85,13 @@ def handler(event, context):
         raise ValueError(f"Missing baseline folders for: {', '.join(missing_baselines)}")
 
     # Copy input files to test set bucket
-    _copy_files_to_test_set(input_bucket, test_set_bucket, test_set_id, 'input', matching_files)
+    _copy_files_to_test_set(source_bucket, test_set_bucket, test_set_id, 'input', matching_files)
 
     # Copy baseline folders to test set bucket
-    _copy_files_to_test_set(baseline_bucket, test_set_bucket, test_set_id, 'baseline', matching_files)
+    if bucket_type == 'testset':
+        _copy_baseline_from_testset(source_bucket, test_set_bucket, test_set_id, matching_files)
+    else:
+        _copy_files_to_test_set(baseline_bucket, test_set_bucket, test_set_id, 'baseline', matching_files)
 
     logger.info(f"Copied {len(matching_files)} input files and {len(matching_files)} baseline folders")
 
@@ -121,6 +149,43 @@ def _copy_files_to_test_set(source_bucket, dest_bucket, test_set_id, folder_type
 
         logger.info(f"Copied {folder_type} file: {source_key} -> {dest_bucket}/{dest_key}")
 
+def _copy_baseline_from_testset(source_bucket, dest_bucket, test_set_id, files):
+    """Copy baseline files from testset bucket where baselines are in test_set/baseline/ path"""
+
+    for file_key in files:
+        # Extract test set name and file name from path (format: test_set_name/input/file_name)
+        path_parts = file_key.split('/')
+        if len(path_parts) >= 3 and path_parts[1] == 'input':
+            source_test_set_name = path_parts[0]
+            file_name = path_parts[2]
+
+            # Source baseline path in testset bucket
+            source_baseline_prefix = f"{source_test_set_name}/baseline/{file_name}/"
+            # Destination baseline path
+            dest_baseline_prefix = f"{test_set_id}/baseline/{file_name}/"
+
+            # List all objects in the source baseline folder
+            paginator = s3.get_paginator('list_objects_v2')
+            pages = paginator.paginate(Bucket=source_bucket, Prefix=source_baseline_prefix)
+
+            for page in pages:
+                if 'Contents' in page:
+                    for obj in page['Contents']:
+                        source_key = obj['Key']
+                        # Replace the source baseline prefix with dest baseline prefix
+                        dest_key = source_key.replace(source_baseline_prefix, dest_baseline_prefix, 1)
+
+                        # Copy file
+                        s3.copy_object(
+                            CopySource={'Bucket': source_bucket, 'Key': source_key},
+                            Bucket=dest_bucket,
+                            Key=dest_key
+                        )
+
+                        logger.info(f"Copied testset baseline file: {source_key} -> {dest_bucket}/{dest_key}")
+        else:
+            logger.warning(f"Unexpected file path format for testset baseline: {file_key}")
+
 def _update_test_set_status(tracking_table, test_set_id, status, error=None):
     """Update test set status in tracking table"""
     table = dynamodb.Table(tracking_table)  # type: ignore

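Both this Lambda and the test set resolver lean on `find_matching_files`, whose body is unchanged by this commit and therefore absent from the diff. A plausible sketch, assuming glob-style matching over listed S3 keys (the actual implementation may differ):

```python
import fnmatch
import boto3

s3 = boto3.client('s3')

def find_matching_files(bucket, pattern):
    """Sketch of the helper called above (not part of this diff): return all
    object keys in the bucket matching a glob pattern such as '*.pdf'."""
    matches = []
    paginator = s3.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket):
        for obj in page.get('Contents', []):
            if fnmatch.fnmatch(obj['Key'], pattern):
                matches.append(obj['Key'])
    return matches
```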
src/lambda/test_set_resolver/index.py

Lines changed: 16 additions & 7 deletions
@@ -34,8 +34,8 @@ def handler(event, context):
         return delete_test_sets(event['arguments'])
     elif field_name == 'getTestSets':
         return get_test_sets()
-    elif field_name == 'listInputBucketFiles':
-        return list_input_bucket_files(event['arguments'])
+    elif field_name == 'listBucketFiles':
+        return list_bucket_files(event['arguments'])
     elif field_name == 'validateTestFileName':
         return validate_test_file_name(event['arguments'])
     else:
@@ -137,6 +137,7 @@ def add_test_set(args):
         MessageBody=json.dumps({
             'testSetId': test_set_id,
             'filePattern': args['filePattern'],
+            'bucketType': args['bucketType'],
             'trackingTable': os.environ['TRACKING_TABLE']
         })
     )
@@ -451,14 +452,22 @@ def _create_test_set_tracking_entry(test_set_id, name, file_count, status, error
         logger.error(f"Error creating tracking entry for {test_set_id}: {str(e)}")
 
 
-def list_input_bucket_files(args):
-    logger.info(f"Listing files with pattern: {args['filePattern']}")
+def list_bucket_files(args):
+    logger.info(f"Listing files with pattern: {args['filePattern']} from bucket type: {args['bucketType']}")
 
     file_pattern = args['filePattern']
-    input_bucket = os.environ['INPUT_BUCKET']
+    bucket_type = args['bucketType']
 
-    files = find_matching_files(input_bucket, file_pattern)
-    logger.info(f"Found {len(files)} matching files")
+    # Determine which bucket to use based on bucket type
+    if bucket_type == 'input':
+        bucket = os.environ['INPUT_BUCKET']
+    elif bucket_type == 'testset':
+        bucket = os.environ['TEST_SET_BUCKET']
+    else:
+        raise Exception(f"Invalid bucket type: {bucket_type}")
+
+    files = find_matching_files(bucket, file_pattern)
+    logger.info(f"Found {len(files)} matching files in {bucket_type} bucket")
 
     return files

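For completeness, a hypothetical AppSync event exercising the renamed field against this handler (argument values are illustrative):

```python
# Hypothetical resolver event; with bucketType 'testset' the handler lists
# matches from TEST_SET_BUCKET instead of INPUT_BUCKET.
event = {
    "info": {"fieldName": "listBucketFiles"},
    "arguments": {"bucketType": "testset", "filePattern": "*.pdf"},
}
# handler(event, None) -> e.g. ["my-test-set/input/invoice-001.pdf", ...]
```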