Commit 41a8b39

typed metadata model

1 parent ab0fd51 · commit 41a8b39
File tree

1 file changed (+23, -32)

lib/idp_common_pkg/idp_common/assessment/granular_service.py

Lines changed: 23 additions & 32 deletions
@@ -30,8 +30,10 @@
     X_AWS_IDP_CONFIDENCE_THRESHOLD,
     X_AWS_IDP_DOCUMENT_TYPE,
 )
+from idp_common.extraction.models import ExtractionData, ExtractionMetadata
 from idp_common.models import Document, Status
-from idp_common.utils import check_token_limit, grid_overlay
+from idp_common.utils import check_token_limit
+from idp_common.utils.grid_overlay import add_ruler_edges

 logger = logging.getLogger(__name__)

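Note: the new ExtractionData / ExtractionMetadata models live in idp_common/extraction/models.py, which this diff does not show. Based solely on how the fields are used in the hunks below, a minimal Pydantic sketch might look like the following; the exact field types and defaults are assumptions:

from typing import Any, Optional

from pydantic import BaseModel, Field


class ExtractionMetadata(BaseModel):
    # Assessment bookkeeping written back after processing (see the final hunk).
    assessment_time_seconds: Optional[float] = None
    granular_assessment_used: bool = False
    assessment_tasks_total: Optional[int] = None
    assessment_tasks_successful: Optional[int] = None
    assessment_tasks_failed: Optional[int] = None


class ExtractionData(BaseModel):
    # inference_result was previously read with .get("inference_result", {}).
    inference_result: dict[str, Any] = Field(default_factory=dict)
    explainability_info: list[Any] = Field(default_factory=list)
    # A default metadata object lets callers assign fields without a None check.
    metadata: ExtractionMetadata = Field(default_factory=ExtractionMetadata)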
@@ -880,16 +882,10 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
             Document: Updated Document object with assessment results appended to extraction results
         """
         # Check if assessment is enabled in typed configuration
-        enabled = self.config.assessment.enabled
-        if not enabled:
+        if not self.config.assessment.enabled:
             logger.info("Assessment is disabled via configuration")
             return document

-        # Validate input document
-        if not document:
-            logger.error("No document provided")
-            return document
-
         if not document.sections:
             logger.error("Document has no sections to process")
             document.errors.append("Document has no sections to process")
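The inlined check relies on self.config.assessment already being a typed (Pydantic) object, so enabled arrives as a real bool rather than a raw config string; the separate if not document: guard is dropped, presumably because the typed signature already requires a Document. The config class itself is not part of this diff; a sketch consistent with the attributes used in this file (defaults are invented for illustration):

from pydantic import BaseModel


class AssessmentConfig(BaseModel):
    # Attributes below are exactly those referenced via self.config.assessment.
    enabled: bool = True
    model: str = ""
    temperature: float = 0.0
    max_tokens: int = 4096
    system_prompt: str = ""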
@@ -942,8 +938,9 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
         try:
             # Read existing extraction results
             t0 = time.time()
-            extraction_data = s3.get_json_content(section.extraction_result_uri)
-            extraction_results = extraction_data.get("inference_result", {})
+            extraction_data_dict = s3.get_json_content(section.extraction_result_uri)
+            extraction_data = ExtractionData.model_validate(extraction_data_dict)
+            extraction_results = extraction_data.inference_result

             # Skip assessment if no extraction results found
             if not extraction_results:
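With model_validate, malformed payloads now fail loudly at the I/O boundary instead of surfacing later as missing keys. A minimal usage sketch, with a hypothetical URI, assuming get_json_content returns a parsed dict:

raw = s3.get_json_content("s3://example-bucket/sections/1/result.json")  # hypothetical URI
data = ExtractionData.model_validate(raw)  # raises pydantic.ValidationError on schema mismatch
results = data.inference_result            # typed attribute access replaces .get() chains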
@@ -1008,12 +1005,6 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
             t4 = time.time()
             logger.info(f"Time taken to read raw OCR results: {t4 - t3:.2f} seconds")

-            # Get assessment configuration (type-safe, Pydantic handles conversions)
-            model_id = self.config.assessment.model
-            temperature = self.config.assessment.temperature
-            max_tokens = self.config.assessment.max_tokens
-            system_prompt = self.config.assessment.system_prompt
-
             # Get schema for this document class
             class_schema = self._get_class_schema(class_label)
             if not class_schema:
@@ -1064,7 +1055,7 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
             # Apply grid overlay to page images for assessment
             grid_page_images = []
             for page_img in page_images:
-                grid_img = grid_overlay.add_grid_overlay(page_img)
+                grid_img = add_ruler_edges(page_img)
                 grid_page_images.append(grid_img)

             # Execute tasks using Strands-based parallel executor
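The rename from grid_overlay.add_grid_overlay to add_ruler_edges suggests the overlay now draws coordinate rulers along the image edges rather than a full grid, but the implementation is not part of this diff. Purely as a hypothetical illustration of the idea, using Pillow:

# HYPOTHETICAL sketch -- not the actual add_ruler_edges implementation.
from PIL import Image, ImageDraw


def add_ruler_edges(img: Image.Image, step: int = 100) -> Image.Image:
    out = img.copy()
    draw = ImageDraw.Draw(out)
    width, height = out.size
    for x in range(0, width, step):
        draw.line([(x, 0), (x, 10)], fill="red", width=2)  # tick on top edge
        draw.text((x + 2, 12), str(x), fill="red")         # pixel coordinate label
    for y in range(0, height, step):
        draw.line([(0, y), (10, y)], fill="red", width=2)  # tick on left edge
        draw.text((12, y + 2), str(y), fill="red")
    return out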
@@ -1081,10 +1072,10 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
                     extraction_results=extraction_results,
                     page_images=grid_page_images,
                     sorted_page_ids=sorted_page_ids,
-                    model_id=model_id,
-                    system_prompt=system_prompt,
-                    temperature=temperature,
-                    max_tokens=max_tokens,
+                    model_id=self.config.assessment.model,
+                    system_prompt=self.config.assessment.system_prompt,
+                    temperature=self.config.assessment.temperature,
+                    max_tokens=self.config.assessment.max_tokens,
                     max_concurrent=self.max_workers,
                 )
             )
@@ -1243,21 +1234,21 @@ def process_document_section(self, document: Document, section_id: str) -> Docum
                 f"Document will be marked as failed without retry."
             )

-            # Update the existing extraction result with enhanced assessment data
-            extraction_data["explainability_info"] = [enhanced_assessment_data]
-            extraction_data["metadata"] = extraction_data.get("metadata", {})
-            extraction_data["metadata"]["assessment_time_seconds"] = total_duration
-            extraction_data["metadata"]["granular_assessment_used"] = True
-            extraction_data["metadata"]["assessment_tasks_total"] = len(tasks)
-            extraction_data["metadata"]["assessment_tasks_successful"] = len(
-                successful_tasks
-            )
-            extraction_data["metadata"]["assessment_tasks_failed"] = len(failed_tasks)
+            # Update the existing extraction result with enhanced assessment data (typed)
+            extraction_data.explainability_info = [enhanced_assessment_data]
+            extraction_data.metadata.assessment_time_seconds = total_duration
+            extraction_data.metadata.granular_assessment_used = True
+            extraction_data.metadata.assessment_tasks_total = len(tasks)
+            extraction_data.metadata.assessment_tasks_successful = len(successful_tasks)
+            extraction_data.metadata.assessment_tasks_failed = len(failed_tasks)

             # Write the updated result back to S3
             bucket, key = utils.parse_s3_uri(section.extraction_result_uri)
             s3.write_content(
-                extraction_data, bucket, key, content_type="application/json"
+                extraction_data.model_dump(mode="json"),
+                bucket,
+                key,
+                content_type="application/json",
             )

             # Update the section in the document with confidence threshold alerts
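Two things happen in this hunk: the defensive extraction_data.get("metadata", {}) initialization disappears, since a typed model can guarantee a metadata object exists, and the write path serializes via model_dump(mode="json"), which in Pydantic v2 coerces field values to JSON-safe primitives before upload:

payload = extraction_data.model_dump(mode="json")
# payload is a plain dict of JSON-serializable values, so e.g.
# json.dumps(payload) succeeds even if the model later gains
# datetime or enum fields.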
