Commit 1c3397d

typed metadata model
1 parent 2bec34e commit 1c3397d

File tree

1 file changed (+23 −32 lines)

lib/idp_common_pkg/idp_common/assessment/granular_service.py

Lines changed: 23 additions & 32 deletions
@@ -30,8 +30,10 @@
     X_AWS_IDP_CONFIDENCE_THRESHOLD,
     X_AWS_IDP_DOCUMENT_TYPE,
 )
+from idp_common.extraction.models import ExtractionData, ExtractionMetadata
 from idp_common.models import Document, Status
-from idp_common.utils import check_token_limit, grid_overlay
+from idp_common.utils import check_token_limit
+from idp_common.utils.grid_overlay import add_ruler_edges
 
 logger = logging.getLogger(__name__)
 
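Note: the `ExtractionData` / `ExtractionMetadata` models imported here are what the rest of this diff builds on. Their real definitions live in `idp_common/extraction/models.py` and are not part of this diff; the following is only a minimal sketch, inferred from how the fields are used in the hunks below.

```python
# Hypothetical sketch only -- the actual models live in
# idp_common/extraction/models.py and may differ.
from typing import Any, Optional

from pydantic import BaseModel, Field


class ExtractionMetadata(BaseModel):
    """Typed replacement for the former extraction_data["metadata"] dict."""

    assessment_time_seconds: Optional[float] = None
    granular_assessment_used: Optional[bool] = None
    assessment_tasks_total: Optional[int] = None
    assessment_tasks_successful: Optional[int] = None
    assessment_tasks_failed: Optional[int] = None


class ExtractionData(BaseModel):
    """Typed replacement for the raw extraction-result JSON payload."""

    inference_result: dict[str, Any] = Field(default_factory=dict)
    explainability_info: list[Any] = Field(default_factory=list)
    metadata: ExtractionMetadata = Field(default_factory=ExtractionMetadata)
```

Defaulting `metadata` with a factory would preserve the old `extraction_data.get("metadata", {})` fallback that this commit removes further down: a payload without a metadata block still validates, and the attribute assignments later in the diff have an object to write to.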
@@ -869,16 +871,10 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
             Document: Updated Document object with assessment results appended to extraction results
         """
         # Check if assessment is enabled in typed configuration
-        enabled = self.config.assessment.enabled
-        if not enabled:
+        if not self.config.assessment.enabled:
             logger.info("Assessment is disabled via configuration")
             return document
 
-        # Validate input document
-        if not document:
-            logger.error("No document provided")
-            return document
-
         if not document.sections:
             logger.error("Document has no sections to process")
             document.errors.append("Document has no sections to process")
@@ -931,8 +927,9 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
         try:
             # Read existing extraction results
             t0 = time.time()
-            extraction_data = s3.get_json_content(section.extraction_result_uri)
-            extraction_results = extraction_data.get("inference_result", {})
+            extraction_data_dict = s3.get_json_content(section.extraction_result_uri)
+            extraction_data = ExtractionData.model_validate(extraction_data_dict)
+            extraction_results = extraction_data.inference_result
 
             # Skip assessment if no extraction results found
             if not extraction_results:
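Note: `model_validate` is Pydantic v2's dict-to-model constructor, so a malformed extraction payload now fails at the read boundary with a field-level error instead of surfacing later as a `KeyError`. A usage sketch against the hypothetical model above:

```python
from pydantic import ValidationError

payload = {"inference_result": {"invoice_total": "123.45"}}

data = ExtractionData.model_validate(payload)
assert data.inference_result["invoice_total"] == "123.45"
# Sections missing from the payload fall back to their defaults.
assert data.metadata.assessment_time_seconds is None

try:
    ExtractionData.model_validate({"inference_result": "not-a-dict"})
except ValidationError as exc:
    print(exc)  # names the offending field instead of failing downstream
```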
@@ -997,12 +994,6 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
             t4 = time.time()
             logger.info(f"Time taken to read raw OCR results: {t4 - t3:.2f} seconds")
 
-            # Get assessment configuration (type-safe, Pydantic handles conversions)
-            model_id = self.config.assessment.model
-            temperature = self.config.assessment.temperature
-            max_tokens = self.config.assessment.max_tokens
-            system_prompt = self.config.assessment.system_prompt
-
             # Get schema for this document class
             class_schema = self._get_class_schema(class_label)
             if not class_schema:
@@ -1053,7 +1044,7 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
             # Apply grid overlay to page images for assessment
             grid_page_images = []
             for page_img in page_images:
-                grid_img = grid_overlay.add_grid_overlay(page_img)
+                grid_img = add_ruler_edges(page_img)
                 grid_page_images.append(grid_img)
 
             # Execute tasks using Strands-based parallel executor
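Note: `add_ruler_edges` is now imported directly rather than reached through the `grid_overlay` module namespace, and it replaces `add_grid_overlay` at this call site. Its implementation is not shown in this diff; purely as an illustration of what a ruler-edge overlay could do (tick marks along the image borders instead of a full grid), a hypothetical Pillow sketch:

```python
# Hypothetical sketch only -- the real add_ruler_edges lives in
# idp_common/utils/grid_overlay.py and may differ entirely.
from PIL import Image, ImageDraw


def add_ruler_edges_sketch(img: Image.Image, step: int = 50) -> Image.Image:
    """Draw coordinate tick marks along all four image edges."""
    out = img.copy()
    draw = ImageDraw.Draw(out)
    w, h = out.size
    for x in range(0, w, step):  # ticks along the top and bottom edges
        draw.line([(x, 0), (x, 8)], fill="red")
        draw.line([(x, h - 8), (x, h)], fill="red")
    for y in range(0, h, step):  # ticks along the left and right edges
        draw.line([(0, y), (8, y)], fill="red")
        draw.line([(w - 8, y), (w, y)], fill="red")
    return out
```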
@@ -1070,10 +1061,10 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
                     extraction_results=extraction_results,
                     page_images=grid_page_images,
                     sorted_page_ids=sorted_page_ids,
-                    model_id=model_id,
-                    system_prompt=system_prompt,
-                    temperature=temperature,
-                    max_tokens=max_tokens,
+                    model_id=self.config.assessment.model,
+                    system_prompt=self.config.assessment.system_prompt,
+                    temperature=self.config.assessment.temperature,
+                    max_tokens=self.config.assessment.max_tokens,
                     max_concurrent=self.max_workers,
                 )
             )
@@ -1232,21 +1223,21 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
                 f"Document will be marked as failed without retry."
             )
 
-            # Update the existing extraction result with enhanced assessment data
-            extraction_data["explainability_info"] = [enhanced_assessment_data]
-            extraction_data["metadata"] = extraction_data.get("metadata", {})
-            extraction_data["metadata"]["assessment_time_seconds"] = total_duration
-            extraction_data["metadata"]["granular_assessment_used"] = True
-            extraction_data["metadata"]["assessment_tasks_total"] = len(tasks)
-            extraction_data["metadata"]["assessment_tasks_successful"] = len(
-                successful_tasks
-            )
-            extraction_data["metadata"]["assessment_tasks_failed"] = len(failed_tasks)
+            # Update the existing extraction result with enhanced assessment data (typed)
+            extraction_data.explainability_info = [enhanced_assessment_data]
+            extraction_data.metadata.assessment_time_seconds = total_duration
+            extraction_data.metadata.granular_assessment_used = True
+            extraction_data.metadata.assessment_tasks_total = len(tasks)
+            extraction_data.metadata.assessment_tasks_successful = len(successful_tasks)
+            extraction_data.metadata.assessment_tasks_failed = len(failed_tasks)
 
             # Write the updated result back to S3
             bucket, key = utils.parse_s3_uri(section.extraction_result_uri)
             s3.write_content(
-                extraction_data, bucket, key, content_type="application/json"
+                extraction_data.model_dump(mode="json"),
+                bucket,
+                key,
+                content_type="application/json",
             )
 
             # Update the section in the document with confidence threshold alerts
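Note: `model_dump(mode="json")` is the write-side counterpart of the `model_validate` call at the read side: it turns the typed model back into JSON-safe primitives before `s3.write_content` serializes it. A small round-trip sketch using the hypothetical model above:

```python
data = ExtractionData.model_validate(
    {"inference_result": {"invoice_total": "123.45"}}
)
data.metadata.granular_assessment_used = True

# mode="json" yields only JSON-serializable primitives (str, int, float,
# bool, None, dict, list), so the result can go straight to a JSON writer.
payload = data.model_dump(mode="json")
assert payload["metadata"]["granular_assessment_used"] is True
```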
