|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# SPDX-License-Identifier: MIT-0 |
| 3 | + |
| 4 | +""" |
| 5 | +Lambda function to provide examples with ground truth data based on S3 Vectors lookup. |
| 6 | +
|
| 7 | +Key Features Demonstrated: |
| 8 | +- Dynamically retrieve similar examples based on document content using vector similarity search |
| 9 | +- Provide few-shot examples to improve extraction accuracy through example-based prompting |
| 10 | +- Leverage S3 Vectors for efficient similarity search across large example datasets |
| 11 | +- Integrate multimodal embeddings using Amazon Nova models for image-based similarity |
| 12 | +- Customize example selection based on document characteristics and business rules |
| 13 | +""" |
| 14 | + |
| 15 | +import json |
| 16 | +import logging |
| 17 | +import base64 |
| 18 | +import boto3 |
| 19 | +import os |
| 20 | + |
| 21 | +from idp_common import bedrock, s3 |
| 22 | + |
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Required configuration is read once at import time so that a misconfigured
# deployment fails on cold start rather than on the first request.
try:
    S3VECTOR_BUCKET = os.environ["S3VECTOR_BUCKET"]
    S3VECTOR_INDEX = os.environ["S3VECTOR_INDEX"]
    S3VECTOR_DIMENSIONS = int(os.environ["S3VECTOR_DIMENSIONS"])
    MODEL_ID = os.environ["MODEL_ID"]
    TOP_K = int(os.environ["TOP_K"])
except (KeyError, ValueError) as e:
    # KeyError: a variable is missing; ValueError: a numeric variable is not an integer.
    logger.error(f"Failed to parse environment variables: {e}")
    raise

# Initialize clients
s3vectors = boto3.client("s3vectors")
bedrock_client = bedrock.BedrockClient()
| 40 | + |
def lambda_handler(event, context):
    """
    Process a document to find similar examples using S3 Vectors similarity search.

    Only ``image_content`` drives the lookup; ``class_label`` and
    ``document_texts`` are read for logging only.

    Input event:
    {
        "class_label": "<class_label>",
        "document_texts": ["<document_text_1>", "<document_text_2>", ...],
        "image_content": ["<base64_image_content_1>", "<base64_image_content_2>", ...]
    }

    Return format:
    [
        {
            "attributes_prompt": "expected attributes are: ...",
            "class_prompt": "This is an example of the class 'invoice'",
            "distance": 0.892344521145,
            "image_content": ["<base64_image_content_1>", "<base64_image_content_2>", ...]
        }
    ]

    Raises:
        Exception: Any failure is logged and re-raised wrapped in a generic
            ``Exception``; the original error is attached as ``__cause__``.
    """
    try:
        logger.info("=== DYNAMIC FEW-SHOT LAMBDA INVOKED ===")
        # The event carries base64 page images, so only serialize it when
        # DEBUG output will actually be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Complete input event: {json.dumps(event, indent=2)}")

        # Validate input ("or []" also covers keys explicitly set to null)
        class_label = event.get("class_label")
        document_texts = event.get("document_texts") or []
        image_content = event.get("image_content") or []

        logger.info("=== INPUT VALUES ===")
        logger.info(f"Class label: {class_label if class_label else 'Not specified'}")
        logger.info(f"Document texts: {len(document_texts)}")
        logger.info(f"Image content: {len(image_content)}")

        # Decode input data
        image_data = _decode_images(image_content)

        # Find similar items using S3 vectors lookup from image similarity
        result = _s3vectors_find_similar_items(image_data)

        # Log complete output structure
        logger.info("=== OUTPUT ANALYSIS ===")
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Complete result: {json.dumps(result, indent=2)}")
        logger.info(f"Output items: {len(result)}")

        logger.info("=== DYNAMIC FEW-SHOT LAMBDA COMPLETED ===")
        return result

    except Exception as e:
        logger.error("=== DYNAMIC FEW-SHOT LAMBDA ERROR ===")
        logger.error(f"Error type: {type(e).__name__}")
        logger.error(f"Error message: {str(e)}")
        # The event may not be a dict (that can be the very reason we are
        # here), so never let this diagnostic line raise on its own.
        event_keys = list(event.keys()) if isinstance(event, dict) else "Unknown"
        logger.error(f"Input event keys: {event_keys}")
        # In demo, we'll fail gracefully with detailed error info
        raise Exception(f"Dynamic few-shot Lambda failed: {str(e)}") from e
| 98 | + |
def _decode_images(image_content):
    """Base64 decode image content to bytes.

    Args:
        image_content: Iterable of base64-encoded strings (or bytes); ``None``
            is treated as "no images".

    Returns:
        List of decoded ``bytes`` objects, in input order.
    """
    return [base64.b64decode(image_base64) for image_base64 in image_content or []]
| 106 | + |
def _encode_images(image_content):
    """Base64 encode image content to JSON-serializable strings.

    Args:
        image_content: Iterable of ``bytes`` objects; ``None`` is treated as
            "no images".

    Returns:
        List of base64 text strings, in input order.
    """
    return [
        base64.b64encode(image_bytes).decode("utf-8")
        for image_bytes in image_content or []
    ]
| 114 | + |
def _s3vectors_find_similar_items(image_data):
    """Find similar items for input.

    Runs one similarity query per page image, merges the matches by vector
    key, and converts every match that has a non-empty ``attributesPrompt``
    into a result object.

    Args:
        image_data: List of page images as bytes.

    Returns:
        List of result dicts as produced by ``_extract_metadata``.
    """
    # find similar items based on image similarity only
    similar_items = {}
    for page_image in image_data:
        # Query with the single page, not the whole list of pages.
        matches = _s3vectors_find_similar_items_from_image(page_image)
        _merge_examples(similar_items, matches)

    # create result set
    result = []
    for key, example in similar_items.items():
        # "metadata" may be stored as None when a match carried none.
        metadata = example.get("metadata") or {}
        attributes_prompt = metadata.get("attributesPrompt")

        # Only process this example if it has a non-empty attributesPrompt
        if not attributes_prompt or not attributes_prompt.strip():
            logger.info(f"Skipping example with empty attributesPrompt: {key}")
            continue

        result.append(_extract_metadata(metadata, example.get("distance")))

    return result
| 141 | + |
def _s3vectors_find_similar_items_from_image(page_image):
    """Search for similar items using image query.

    Args:
        page_image: Image passed to the embedding call as ``image_source``
            (presumably raw image bytes - confirm against the caller).

    Returns:
        The ``"vectors"`` entry of the ``query_vectors`` response.
    """
    # Embed the page with the configured model and dimensionality.
    query_vector = {
        "float32": bedrock_client.generate_embedding(
            image_source=page_image,
            model_id=MODEL_ID,
            dimensions=S3VECTOR_DIMENSIONS,
        )
    }
    # Ask the index for the TOP_K entries, including distance and metadata.
    matches = s3vectors.query_vectors(
        vectorBucketName=S3VECTOR_BUCKET,
        indexName=S3VECTOR_INDEX,
        queryVector=query_vector,
        topK=TOP_K,
        returnDistance=True,
        returnMetadata=True,
    )
    return matches["vectors"]
| 158 | + |
def _merge_examples(examples, new_examples):
    """
    Merge in-place new examples into the result dict, avoiding duplicates.

    An example whose key is already present keeps the smaller of the two
    distances and takes the metadata of the newer match; unseen keys are
    inserted. A match without a ``"distance"`` is treated as distance 1.0.

    Args:
        examples: Dict of existing examples, keyed by vector key (mutated)
        new_examples: List of new examples to be merged
    """
    for new_example in new_examples:
        key = new_example["key"]
        new_distance = new_example.get("distance", 1.0)

        if key in examples:
            # update example
            existing_distance = examples[key].get("distance", 1.0)
            examples[key]["distance"] = min(new_distance, existing_distance)
            examples[key]["metadata"] = new_example.get("metadata")
        else:
            # insert example
            examples[key] = {
                "distance": new_distance,
                "metadata": new_example.get("metadata"),
            }
| 182 | + |
def _extract_metadata(metadata, distance=None):
    """Create result object from S3 vectors metadata.

    Args:
        metadata: Metadata dict of a matched vector. Reads the keys
            ``attributesPrompt``, ``classPrompt`` and, optionally,
            ``imagePath``.
        distance: Distance reported for the match. Optional so that callers
            that only have the metadata at hand still work; the result then
            carries ``None``.

    Returns:
        Dict with ``attributes_prompt``, ``class_prompt`` and ``distance``.
        ``image_content`` (list of base64 strings) is added only when the
        metadata has a non-empty ``imagePath``.
    """
    # Result object attributes
    attributes = {
        "attributes_prompt": metadata.get("attributesPrompt"),
        "class_prompt": metadata.get("classPrompt"),
        "distance": distance,
    }

    image_path = metadata.get("imagePath")
    if image_path:
        image_data = _get_image_data_from_s3_path(image_path)
        attributes["image_content"] = _encode_images(image_data)

    return attributes
| 199 | + |
def _get_image_data_from_s3_path(image_path):
    """
    Load images from image path

    Loading is best-effort: a file that cannot be read is logged as a warning
    and skipped, so the result may hold fewer images than the path lists.

    Args:
        image_path: Path to image file, directory, or S3 prefix

    Returns:
        List of images (bytes)
    """
    # Get list of image files from the path (supports directories/prefixes)
    image_files = _get_image_files_from_s3_path(image_path)
    image_content = []

    # Process each image file
    for image_file_path in image_files:
        try:
            if not image_file_path.startswith("s3://"):
                # Report the offending file, not the prefix it was listed from.
                raise ValueError(
                    f"Invalid file path {image_file_path} - expecting S3 path"
                )
            # Direct S3 URI
            image_content.append(s3.get_binary_content(image_file_path))
        except Exception as e:
            logger.warning(f"Failed to load image {image_file_path}: {e}")

    return image_content
| 232 | + |
def _get_image_files_from_s3_path(image_path):
    """
    Get list of image files from an S3 path.

    Args:
        image_path: S3 URI of a single image file or of a prefix

    Returns:
        ``[image_path]`` when the URI ends in a known image extension
        (matched case-insensitively); otherwise whatever
        ``s3.list_images_from_path`` returns for the prefix.

    Raises:
        ValueError: If ``image_path`` is not an ``s3://`` URI.
    """
    image_extensions = (
        ".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif", ".webp",
    )

    # Handle S3 URIs
    if not image_path.startswith("s3://"):
        raise ValueError(
            f"Invalid file path {image_path} - expecting S3 URI"
        )

    # Direct S3 file; lower() so that e.g. "scan.PNG" is not mistaken for a prefix
    if image_path.lower().endswith(image_extensions):
        return [image_path]

    # S3 prefix - list all images
    return s3.list_images_from_path(image_path)
0 commit comments