Skip to content

Commit ae458ec

Browse files
committed
Merge branch 'fix/p1-confidencealerts' into 'develop'
Fix Confidence alerts for Pattern1. See merge request genaiic-reusable-assets/engagement-artifacts/genaiic-idp-accelerator!219
2 parents 7d9daad + 619c8fe commit ae458ec

File tree

2 files changed

+85
-74
lines changed

2 files changed

+85
-74
lines changed

patterns/pattern-1/src/processresults_function/index.py

Lines changed: 82 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
ssm_client = boto3.client('ssm')
3232
bedrock_client = boto3.client('bedrock-data-automation')
3333
SAGEMAKER_A2I_REVIEW_PORTAL_URL = os.environ.get('SAGEMAKER_A2I_REVIEW_PORTAL_URL', '')
34+
enable_hitl = os.environ.get('ENABLE_HITL', 'false').lower()
3435

3536
def get_confidence_threshold_from_config(document: Document) -> float:
3637
"""
@@ -43,9 +44,8 @@ def get_confidence_threshold_from_config(document: Document) -> float:
4344
float: The confidence threshold as a decimal (0.0-1.0)
4445
"""
4546
try:
46-
config = get_config(document)
47-
assessment_config = config.get('assessment', {})
48-
threshold_value = float(assessment_config.get('default_confidence_threshold', 0.8))
47+
config = get_config()
48+
threshold_value = float(config['assessment']['default_confidence_threshold'])
4949

5050
# Validate that the threshold is in the expected 0.0-1.0 range
5151
if threshold_value < 0.0 or threshold_value > 1.0:
@@ -870,12 +870,17 @@ def process_segments(
870870
bp_confidence = custom_output["matched_blueprint"]["confidence"]
871871

872872
# Check if any key-value or blueprint confidence is below threshold
873-
low_confidence = any(
874-
kv['confidence'] < confidence_threshold
875-
for page_num in page_indices
876-
for kv in pagespecific_details['key_value_details'].get(str(page_num), [])
877-
) or float(bp_confidence) < confidence_threshold
873+
if enable_hitl == 'true':
874+
low_confidence = any(
875+
kv['confidence'] < confidence_threshold
876+
for page_num in page_indices
877+
for kv in pagespecific_details['key_value_details'].get(str(page_num), [])
878+
) or float(bp_confidence) < confidence_threshold
879+
else:
880+
low_confidence = None
878881

882+
logger.info(f"low_confidence: {low_confidence}")
883+
879884
item.update({
880885
"page_array": page_indices,
881886
"hitl_triggered": low_confidence,
@@ -894,7 +899,7 @@ def process_segments(
894899
)
895900

896901
if low_confidence:
897-
hitl_triggered = True
902+
hitl_triggered = low_confidence
898903
metrics.put_metric('HITLTriggered', 1)
899904
for page_number in page_indices:
900905
page_str = str(page_number)
@@ -922,6 +927,10 @@ def process_segments(
922927
"hitl_corrected_result": custom_decimal_output
923928
})
924929
else:
930+
if enable_hitl == 'true':
931+
std_hitl = 'true'
932+
else:
933+
std_hitl = None
925934
# Process standard output if no custom output match
926935
std_bucket, std_key = parse_s3_path(segment['standard_output_path'])
927936
std_output = download_decimal(std_bucket, std_key)
@@ -931,7 +940,7 @@ def process_segments(
931940
page_array = list(range(start_page, end_page + 1))
932941
item.update({
933942
"page_array": page_array,
934-
"hitl_triggered": True,
943+
"hitl_triggered": std_hitl,
935944
"extraction_bp_name": "None",
936945
"extracted_result": std_output
937946
})
@@ -941,30 +950,31 @@ def process_segments(
941950
record_number=record_number,
942951
bp_match=segment.get('custom_output_status'),
943952
extraction_bp_name="None",
944-
hitl_triggered=True,
953+
hitl_triggered=std_hitl,
945954
page_array=page_array,
946955
review_portal_url=SAGEMAKER_A2I_REVIEW_PORTAL_URL
947956
)
948957

949-
hitl_triggered = True
950-
for page_number in range(start_page, end_page + 1):
951-
ImageUri = f"s3://{output_bucket}/{object_key}/pages/{page_number}/image.jpg"
952-
try:
953-
human_loop_response = start_human_loop(
954-
execution_id=execution_id,
955-
kv_pairs=[],
956-
source_image_uri=ImageUri,
957-
bounding_boxes=[],
958-
blueprintName="",
959-
bp_confidence=0.00,
960-
confidenceThreshold=confidence_threshold,
961-
page_id=page_number,
962-
page_indices=page_array,
963-
record_number=record_number
964-
)
965-
logger.info(f"Triggered human loop for page {page_number}: {human_loop_response}")
966-
except Exception as e:
967-
logger.error(f"Failed to start human loop for page {page_number}: {str(e)}")
958+
hitl_triggered = std_hitl
959+
if enable_hitl == 'true':
960+
for page_number in range(start_page, end_page + 1):
961+
ImageUri = f"s3://{output_bucket}/{object_key}/pages/{page_number}/image.jpg"
962+
try:
963+
human_loop_response = start_human_loop(
964+
execution_id=execution_id,
965+
kv_pairs=[],
966+
source_image_uri=ImageUri,
967+
bounding_boxes=[],
968+
blueprintName="",
969+
bp_confidence=0.00,
970+
confidenceThreshold=confidence_threshold,
971+
page_id=page_number,
972+
page_indices=page_array,
973+
record_number=record_number
974+
)
975+
logger.info(f"Triggered human loop for page {page_number}: {human_loop_response}")
976+
except Exception as e:
977+
logger.error(f"Failed to start human loop for page {page_number}: {str(e)}")
968978

969979
document.hitl_metadata.append(hitl_metadata)
970980

@@ -1102,53 +1112,51 @@ def handler(event, context):
11021112

11031113
# Process HITL if enabled
11041114
hitl_triggered = "false"
1105-
enable_hitl = os.environ.get('ENABLE_HITL', 'false').lower() == 'true'
11061115

1107-
if enable_hitl:
1108-
try:
1109-
# Use the confidence threshold already calculated above
1110-
metdatafile_path = '/'.join(bda_result_prefix.split('/')[:-1])
1111-
job_metadata_key = f'{metdatafile_path}/job_metadata.json'
1112-
execution_id = event.get("execution_arn", "").split(':')[-1]
1113-
logger.info(f"HITL execution ID: {execution_id}")
1116+
try:
1117+
# Use the confidence threshold already calculated above
1118+
metdatafile_path = '/'.join(bda_result_prefix.split('/')[:-1])
1119+
job_metadata_key = f'{metdatafile_path}/job_metadata.json'
1120+
execution_id = event.get("execution_arn", "").split(':')[-1]
1121+
logger.info(f"HITL execution ID: {execution_id}")
11141122

1115-
try:
1116-
jobmetadata_file = s3_client.get_object(Bucket=bda_result_bucket, Key=job_metadata_key)
1117-
job_metadata = json.loads(jobmetadata_file['Body'].read())
1118-
if 'output_metadata' in job_metadata:
1119-
output_metadata = job_metadata['output_metadata']
1120-
if isinstance(output_metadata, list):
1121-
for asset in output_metadata:
1122-
document, hitl_result = process_segments(
1123-
input_bucket,
1124-
output_bucket,
1125-
object_key,
1126-
asset.get('segment_metadata', []),
1127-
confidence_threshold,
1128-
execution_id,
1129-
document
1130-
)
1131-
if hitl_result:
1132-
hitl_triggered = "true"
1133-
elif isinstance(output_metadata, dict):
1134-
for asset_id, asset in output_metadata.items():
1135-
document, hitl_result = process_segments(
1136-
input_bucket,
1137-
output_bucket,
1138-
object_key,
1139-
asset.get('segment_metadata', []),
1140-
confidence_threshold,
1141-
execution_id,
1142-
document
1143-
)
1144-
if hitl_result:
1145-
hitl_triggered = "true"
1146-
else:
1147-
logger.error("Unexpected output_metadata format in job_metadata.json")
1148-
except Exception as e:
1149-
logger.error(f"Error processing job_metadata.json: {str(e)}")
1123+
try:
1124+
jobmetadata_file = s3_client.get_object(Bucket=bda_result_bucket, Key=job_metadata_key)
1125+
job_metadata = json.loads(jobmetadata_file['Body'].read())
1126+
if 'output_metadata' in job_metadata:
1127+
output_metadata = job_metadata['output_metadata']
1128+
if isinstance(output_metadata, list):
1129+
for asset in output_metadata:
1130+
document, hitl_result = process_segments(
1131+
input_bucket,
1132+
output_bucket,
1133+
object_key,
1134+
asset.get('segment_metadata', []),
1135+
confidence_threshold,
1136+
execution_id,
1137+
document
1138+
)
1139+
if hitl_result:
1140+
hitl_triggered = "true"
1141+
elif isinstance(output_metadata, dict):
1142+
for asset_id, asset in output_metadata.items():
1143+
document, hitl_result = process_segments(
1144+
input_bucket,
1145+
output_bucket,
1146+
object_key,
1147+
asset.get('segment_metadata', []),
1148+
confidence_threshold,
1149+
execution_id,
1150+
document
1151+
)
1152+
if hitl_result:
1153+
hitl_triggered = "true"
1154+
else:
1155+
logger.error("Unexpected output_metadata format in job_metadata.json")
11501156
except Exception as e:
1151-
logger.error(f"Error in HITL processing: {str(e)}")
1157+
logger.error(f"Error processing job_metadata.json: {str(e)}")
1158+
except Exception as e:
1159+
logger.error(f"Error in HITL processing: {str(e)}")
11521160

11531161
# Record metrics for processed pages
11541162
metrics.put_metric('ProcessedDocuments', 1)

patterns/pattern-1/template.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,7 @@ Resources:
599599
BDA_PROJECT_ARN: !Ref BDAProjectArn
600600
WORKING_BUCKET: !Ref WorkingBucket
601601
SAGEMAKER_A2I_REVIEW_PORTAL_URL: !Ref SageMakerA2IReviewPortalURL
602+
CONFIGURATION_TABLE_NAME: !Ref ConfigurationTable
602603
LoggingConfig:
603604
LogGroup: !Ref ProcessResultsFunctionLogGroup
604605
Policies:
@@ -611,6 +612,8 @@ Resources:
611612
BucketName: !Ref OutputBucket
612613
- DynamoDBCrudPolicy:
613614
TableName: !Ref BDAMetadataTable
615+
- DynamoDBCrudPolicy:
616+
TableName: !Ref ConfigurationTable
614617
- Statement:
615618
- Effect: Allow
616619
Action: cloudwatch:PutMetricData

0 commit comments

Comments (0)