Skip to content

Commit 1e71750

Browse files
author
Daniel Lorch
committed
feat: dynamic-few shot Lambda using S3 Vectors
1 parent 96571e8 commit 1e71750

File tree

5 files changed

+837
-0
lines changed

5 files changed

+837
-0
lines changed
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: MIT-0
3+
4+
"""
5+
Lambda function to provide examples with ground truth data based on S3 Vectors lookup.
6+
7+
Key Features Demonstrated:
8+
- Dynamically retrieve similar examples based on document content using vector similarity search
9+
- Provide few-shot examples to improve extraction accuracy through example-based prompting
10+
- Leverage S3 Vectors for efficient similarity search across large example datasets
11+
- Integrate multimodal embeddings using Amazon Nova models for image-based similarity
12+
- Customize example selection based on document characteristics and business rules
13+
"""
14+
15+
import json
16+
import logging
17+
import base64
18+
import boto3
19+
import os
20+
21+
from idp_common import bedrock, s3
22+
23+
logger = logging.getLogger(__name__)
24+
logger.setLevel(logging.INFO)
25+
26+
# Parse environment variables with error handling
27+
try:
28+
S3VECTOR_BUCKET = os.environ['S3VECTOR_BUCKET']
29+
S3VECTOR_INDEX = os.environ['S3VECTOR_INDEX']
30+
S3VECTOR_DIMENSIONS = int(os.environ['S3VECTOR_DIMENSIONS'])
31+
MODEL_ID = os.environ['MODEL_ID']
32+
TOP_K = int(os.environ['TOP_K'])
33+
except (KeyError, ValueError, IndexError) as e:
34+
logger.error(f"Failed to parse environment variables: {e}")
35+
raise
36+
37+
# Initialize clients
38+
s3vectors = boto3.client('s3vectors')
39+
bedrock_client = bedrock.BedrockClient()
40+
41+
def lambda_handler(event, context):
42+
"""
43+
Process a document to find similar examples using S3 Vectors similarity search.
44+
45+
Input event:
46+
{
47+
"class_label": "<class_label>",
48+
"document_texts": ["<document_text_1>", "<document_text_2>", ...],
49+
"image_content": ["<base64_image_content_1>", "<base64_image_content_2>", ...]
50+
}
51+
52+
Return format:
53+
[
54+
{
55+
"attributes_prompt": "expected attributes are: ...",
56+
"class_prompt": "This is an example of the class 'invoice'",
57+
"distance": 0.892344521145,
58+
"image_content": ["<base64_image_content_1>", "<base64_image_content_2>", ...]
59+
}
60+
]
61+
"""
62+
63+
try:
64+
logger.info("=== DYNAMIC FEW-SHOT LAMBDA INVOKED ===")
65+
logger.debug(f"Complete input event: {json.dumps(event, indent=2)}")
66+
67+
# Validate input
68+
class_label = event.get("class_label")
69+
document_texts = event.get("document_texts", [])
70+
image_content = event.get("image_content", [])
71+
72+
logger.info(f"=== INPUT VALUES ===")
73+
logger.info(f"Class label: {class_label if class_label else 'Not specified'}")
74+
logger.info(f"Document texts: {len(document_texts)}")
75+
logger.info(f"Image content: {len(image_content)}")
76+
77+
# Decode input data
78+
image_data = _decode_images(image_content)
79+
80+
# Find similar items using S3 vectors lookup from image similarity
81+
result = _s3vectors_find_similar_items(image_data)
82+
83+
# Log complete output structure
84+
logger.info(f"=== OUTPUT ANALYSIS ===")
85+
logger.debug(f"Complete result: {json.dumps(result, indent=2)}")
86+
logger.info(f"Output items: {len(result)}")
87+
88+
logger.info("=== DYNAMIC FEW-SHOT LAMBDA COMPLETED ===")
89+
return result
90+
91+
except Exception as e:
92+
logger.error(f"=== DYNAMIC FEW-SHOT LAMBDA ERROR ===")
93+
logger.error(f"Error type: {type(e).__name__}")
94+
logger.error(f"Error message: {str(e)}")
95+
logger.error(f"Input event keys: {list(event.keys()) if 'event' in locals() else 'Unknown'}")
96+
# In demo, we'll fail gracefully with detailed error info
97+
raise Exception(f"Dynamic few-shot Lambda failed: {str(e)}")
98+
99+
def _decode_images(image_content):
100+
"""Base64 decode image content to bytes"""
101+
result = []
102+
for image_base64 in image_content:
103+
image_data = base64.b64decode(image_base64)
104+
result.append(image_data)
105+
return result
106+
107+
def _encode_images(image_content):
108+
"""Base64 encode image content to JSON-serializable string"""
109+
result = []
110+
for image_bytes in image_content:
111+
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
112+
result.append(image_base64)
113+
return result
114+
115+
def _s3vectors_find_similar_items(image_data):
116+
"""Find similar items for input"""
117+
118+
# find similar items based on image similarity only
119+
similar_items = {}
120+
for page_image in image_data:
121+
result = _s3vectors_find_similar_items_from_image(image_data)
122+
_merge_examples(similar_items, result)
123+
124+
# create result set
125+
result = []
126+
for key, example in similar_items.items():
127+
metadata = example.get("metadata", {})
128+
attributes_prompt = metadata.get("attributesPrompt")
129+
130+
# Only process this example if it has a non-empty attributesPrompt
131+
if not attributes_prompt or not attributes_prompt.strip():
132+
logger.info(
133+
f"Skipping example with empty attributesPrompt: {key}"
134+
)
135+
continue
136+
137+
attributes = _extract_metadata(metadata)
138+
result.append(attributes)
139+
140+
return result
141+
142+
def _s3vectors_find_similar_items_from_image(page_image):
143+
"""Search for similar items using image query"""
144+
embedding = bedrock_client.generate_embedding(
145+
image_source=page_image,
146+
model_id=MODEL_ID,
147+
dimensions=S3VECTOR_DIMENSIONS,
148+
)
149+
response = s3vectors.query_vectors(
150+
vectorBucketName=S3VECTOR_BUCKET,
151+
indexName=S3VECTOR_INDEX,
152+
queryVector={"float32": embedding},
153+
topK=TOP_K,
154+
returnDistance=True,
155+
returnMetadata=True
156+
)
157+
return response["vectors"]
158+
159+
def _merge_examples(examples, new_examples):
160+
"""
161+
Merge in-place new examples into the result list, avoiding duplicates.
162+
163+
Args:
164+
examples: Dict of existing examples
165+
new_examples: List of new examples to be merged
166+
"""
167+
for new_example in new_examples:
168+
key = new_example["key"]
169+
new_distance = new_example.get("distance", 1.0)
170+
171+
# update example
172+
if combined_examples.get(key):
173+
existing_distance = combined_examples[key].get("distance", 1.0)
174+
examples[key]["distance"] = min(new_distance, existing_distance)
175+
examples[key]["metadata"] = new_example.get("metadata")
176+
# insert example
177+
else:
178+
examples[key] = {
179+
"distance": new_distance,
180+
"metadata": new_example.get("metadata")
181+
}
182+
183+
def _extract_metadata(metadata, distance):
184+
"""Create result object from S3 vectors metadata"""
185+
# Result object attributes
186+
attributes = {
187+
"attributes_prompt": metadata.get("attributesPrompt"),
188+
"class_prompt": metadata.get("classPrompt"),
189+
"distance": distance,
190+
}
191+
192+
image_path = metadata.get("imagePath")
193+
if image_path:
194+
image_data = _get_image_data_from_s3_path(image_path)
195+
encoded_images = _encode_images(image_data)
196+
attributes["image_content"] = encoded_images
197+
198+
return attributes
199+
200+
def _get_image_data_from_s3_path(image_path):
201+
"""
202+
Load images from image path
203+
204+
Args:
205+
image_path: Path to image file, directory, or S3 prefix
206+
207+
Returns:
208+
List of images (bytes)
209+
"""
210+
# Get list of image files from the path (supports directories/prefixes)
211+
image_files = _get_image_files_from_s3_path(image_path)
212+
image_content = []
213+
214+
# Process each image file
215+
for image_file_path in image_files:
216+
try:
217+
# Load image content
218+
if image_file_path.startswith("s3://"):
219+
# Direct S3 URI
220+
image_bytes = s3.get_binary_content(image_file_path)
221+
else:
222+
raise ValueError(
223+
f"Invalid file path {image_path} - expecting S3 path"
224+
)
225+
226+
image_content.append(image_bytes)
227+
except Exception as e:
228+
logger.warning(f"Failed to load image {image_file_path}: {e}")
229+
continue
230+
231+
return image_content
232+
233+
def _get_image_files_from_s3_path(image_path):
234+
"""
235+
Get list of image files from an S3 path.
236+
237+
Args:
238+
image_path: Path to image file, directory, or S3 prefix
239+
240+
Returns:
241+
List of image file paths/URIs sorted by filename
242+
"""
243+
# Handle S3 URIs
244+
if not image_path.startswith("s3://"):
245+
raise ValueError(
246+
f"Invalid file path {image_path} - expecting S3 URI"
247+
)
248+
249+
# Check if it's a direct file or a prefix
250+
if image_path.endswith(
251+
(".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif", ".webp")
252+
):
253+
# Direct S3 file
254+
return [image_path]
255+
else:
256+
# S3 prefix - list all images
257+
return s3.list_images_from_path(image_path)

0 commit comments

Comments
 (0)