Skip to content

Commit d02d4e7

Browse files
committed
Merge branch 'feat/multimodal_page_boundary_detection' into 'develop'
Feat/multimodal page boundary detection: Add boundary metadata transfer in multimodal page boundary classification See merge request genaiic-reusable-assets/engagement-artifacts/genaiic-idp-accelerator!238
2 parents 7908d48 + ed537f8 commit d02d4e7

File tree

10 files changed

+2822
-197
lines changed

10 files changed

+2822
-197
lines changed

.gitignore

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,4 @@ __pycache__
1717
.ruff_cache
1818
.kiro
1919
rvl_cdip_*
20-
21-
# IDE specific files
22-
.idea/
23-
20+
notebooks/examples/data

config_library/pattern-2/lending-package-sample/config_multimodal_page_boundary.yaml

Lines changed: 1425 additions & 0 deletions
Large diffs are not rendered by default.

lib/idp_common_pkg/idp_common/classification/service.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class ClassificationService:
4646
# Classification method options
4747
MULTIMODAL_PAGE_LEVEL = "multimodalPageLevelClassification"
4848
TEXTBASED_HOLISTIC = "textbasedHolisticClassification"
49+
MULTIMODAL_PAGE_BOUNDARY = "multimodalPageBoundaryClassification"
4950

5051
def __init__(
5152
self,
@@ -132,6 +133,8 @@ def __init__(
132133
# Log classification method
133134
if self.classification_method == self.TEXTBASED_HOLISTIC:
134135
logger.info("Using textbased holistic packet classification method")
136+
elif self.classification_method == self.MULTIMODAL_PAGE_BOUNDARY:
137+
logger.info("Using multimodal page boundary classification method")
135138
else:
136139
# Default to multimodal page-level classification if value is invalid
137140
if self.classification_method != self.MULTIMODAL_PAGE_LEVEL:
@@ -678,16 +681,21 @@ def classify_page_bedrock(
678681
)
679682
if isinstance(classification_data, dict):
680683
doc_type = classification_data.get("class", "")
681-
logger.debug(
684+
document_boundary = classification_data.get(
685+
"document_boundary", "continue"
686+
)
687+
logger.info(
682688
f"Parsed classification response as {detected_format}: {classification_data}"
683689
)
684690
else:
685691
# If parsing failed, try to extract classification directly from text
686692
doc_type = self._extract_class_from_text(classification_text)
693+
document_boundary = "continue"
687694
except Exception as e:
688695
logger.warning(f"Failed to parse structured data from response: {e}")
689696
# Try to extract classification directly from text
690697
doc_type = self._extract_class_from_text(classification_text)
698+
document_boundary = "continue"
691699

692700
# Validate classification against known document types
693701
if not doc_type:
@@ -710,7 +718,10 @@ def classify_page_bedrock(
710718
classification=DocumentClassification(
711719
doc_type=doc_type,
712720
confidence=1.0, # Default confidence
713-
metadata={"metering": metering},
721+
metadata={
722+
"metering": metering,
723+
"document_boundary": str(document_boundary).lower(),
724+
},
714725
),
715726
image_uri=image_uri,
716727
text_uri=text_uri,
@@ -803,7 +814,10 @@ def classify_page_sagemaker(
803814
classification=DocumentClassification(
804815
doc_type=doc_type,
805816
confidence=1.0, # Default confidence since SageMaker doesn't provide it
806-
metadata={"metering": metering},
817+
metadata={
818+
"metering": metering,
819+
"document_boundary": "continue",
820+
},
807821
),
808822
image_uri=image_uri,
809823
text_uri=text_uri,
@@ -1199,10 +1213,15 @@ def classify_document(self, document: Document) -> Document:
11991213
)
12001214
return self.holistic_classify_document(document)
12011215

1202-
# Default to page-by-page classification
1216+
# Page-level classification (with or without boundary detection)
12031217
t0 = time.time()
1218+
method_desc = (
1219+
"page boundary"
1220+
if self.classification_method == self.MULTIMODAL_PAGE_BOUNDARY
1221+
else "page-by-page"
1222+
)
12041223
logger.info(
1205-
f"Classifying document with {len(document.pages)} pages using page-by-page method with {self.backend} backend"
1224+
f"Classifying document with {len(document.pages)} pages using {method_desc} method with {self.backend} backend"
12061225
)
12071226

12081227
try:
@@ -1230,6 +1249,19 @@ def classify_document(self, document: Document) -> Document:
12301249
page_id
12311250
].confidence = cached_result.classification.confidence
12321251

1252+
# Copy metadata (including boundary information) to the page
1253+
if hasattr(document.pages[page_id], "metadata"):
1254+
document.pages[
1255+
page_id
1256+
].metadata = cached_result.classification.metadata
1257+
else:
1258+
# If the page doesn't have a metadata attribute, add it
1259+
setattr(
1260+
document.pages[page_id],
1261+
"metadata",
1262+
cached_result.classification.metadata,
1263+
)
1264+
12331265
# Merge cached metering data
12341266
page_metering = cached_result.classification.metadata.get(
12351267
"metering", {}
@@ -1278,6 +1310,19 @@ def classify_document(self, document: Document) -> Document:
12781310
page_id
12791311
].confidence = page_result.classification.confidence
12801312

1313+
# Copy metadata (including boundary information) to the page
1314+
if hasattr(document.pages[page_id], "metadata"):
1315+
document.pages[
1316+
page_id
1317+
].metadata = page_result.classification.metadata
1318+
else:
1319+
# If the page doesn't have a metadata attribute, add it
1320+
setattr(
1321+
document.pages[page_id],
1322+
"metadata",
1323+
page_result.classification.metadata,
1324+
)
1325+
12811326
# Merge metering data
12821327
page_metering = page_result.classification.metadata.get(
12831328
"metering", {}
@@ -1360,7 +1405,13 @@ def classify_document(self, document: Document) -> Document:
13601405
current_pages = [sorted_results[0]]
13611406

13621407
for result in sorted_results[1:]:
1363-
if result.classification.doc_type == current_type:
1408+
boundary = result.classification.metadata.get(
1409+
"document_boundary", "continue"
1410+
).lower()
1411+
if (
1412+
result.classification.doc_type == current_type
1413+
and boundary != "start"
1414+
):
13641415
current_pages.append(result)
13651416
else:
13661417
# Create a new section with the current group of pages
@@ -1528,7 +1579,10 @@ def _group_consecutive_pages(
15281579
current_pages = [sorted_results[0]]
15291580

15301581
for result in sorted_results[1:]:
1531-
if result.classification.doc_type == current_type:
1582+
boundary = result.classification.metadata.get(
1583+
"document_boundary", "continue"
1584+
).lower()
1585+
if result.classification.doc_type == current_type and boundary != "start":
15321586
current_pages.append(result)
15331587
else:
15341588
# Create a section with the current group

lib/idp_common_pkg/tests/unit/classification/test_classification_service.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ def test_classify_page_bedrock_success(
255255
assert result.classification.doc_type == "invoice"
256256
assert result.classification.confidence == 1.0
257257
assert result.classification.metadata["metering"] == {"tokens": 100}
258+
assert result.classification.metadata["document_boundary"] == "continue"
258259
assert result.image_uri == "s3://bucket/image.jpg"
259260
assert result.text_uri == "s3://bucket/text.txt"
260261

@@ -801,3 +802,34 @@ def test_holistic_classify_document_multiple_segments(
801802
assert result.pages["1"].classification == "invoice"
802803
assert result.pages["2"].classification == "receipt"
803804
assert result.pages["3"].classification == "receipt"
805+
806+
def test_group_consecutive_pages_with_boundary(self, service):
807+
"""Pages with boundary flag start new sections even with same doc type."""
808+
results = [
809+
PageClassification(
810+
page_id="1",
811+
classification=DocumentClassification(
812+
doc_type="invoice",
813+
metadata={"document_boundary": "start"},
814+
),
815+
),
816+
PageClassification(
817+
page_id="2",
818+
classification=DocumentClassification(
819+
doc_type="invoice",
820+
metadata={"document_boundary": "continue"},
821+
),
822+
),
823+
PageClassification(
824+
page_id="3",
825+
classification=DocumentClassification(
826+
doc_type="invoice",
827+
metadata={"document_boundary": "start"},
828+
),
829+
),
830+
]
831+
832+
sections = service._group_consecutive_pages(results)
833+
assert len(sections) == 2
834+
assert [p.page_id for p in sections[0].pages] == ["1", "2"]
835+
assert [p.page_id for p in sections[1].pages] == ["3"]

0 commit comments

Comments
 (0)