diff --git a/lib/idp_common_pkg/idp_common/bedrock/README.md b/lib/idp_common_pkg/idp_common/bedrock/README.md index 58c5bd64..b0a67cf7 100644 --- a/lib/idp_common_pkg/idp_common/bedrock/README.md +++ b/lib/idp_common_pkg/idp_common/bedrock/README.md @@ -73,6 +73,47 @@ embedding = client.generate_embedding( # Use embedding for vector search, clustering, etc. ``` +Amazon Titan Multimodal Embeddings support both text and image at the same time. The resulting embeddings vector averages the text embeddings and image embeddings vectors. + +```python +from idp_common.bedrock.client import BedrockClient + +with open("/path/to/document.png", "rb") as image_file: + image_data = image_file.read() + +client = BedrockClient() +embedding = client.generate_embedding( + text="This document contains information about loan applications.", + image_source=image_data, + model_id="amazon.titan-embed-image-v1" +) +``` + +The image source can also be an S3 URI: + +```python +from idp_common.bedrock.client import BedrockClient + +client = BedrockClient() +embedding = client.generate_embedding( + image_source="s3://bucket/key", + model_id="amazon.titan-embed-image-v1" +) +``` + +Amazon Nova Multimodal Embeddings with 3072 dimension size: + +```python +from idp_common.bedrock.client import BedrockClient + +client = BedrockClient() +embedding = client.generate_embedding( + image_source="s3://bucket/key", + model_id="amazon.nova-2-multimodal-embeddings-v1:0", + dimensions=3072 +) +``` + ## Prompt Caching with CachePoint Prompt caching is a powerful feature in Amazon Bedrock that significantly reduces response latency for workloads with repetitive contexts. The Bedrock client provides built-in support for this via the `<<CACHEPOINT>>` tag. 
diff --git a/lib/idp_common_pkg/idp_common/bedrock/client.py b/lib/idp_common_pkg/idp_common/bedrock/client.py index 3f19ffe5..7e14d4bb 100644 --- a/lib/idp_common_pkg/idp_common/bedrock/client.py +++ b/lib/idp_common_pkg/idp_common/bedrock/client.py @@ -16,6 +16,7 @@ import copy import random import socket +import base64 from typing import Dict, Any, List, Optional, Union, Tuple, Type from botocore.config import Config from botocore.exceptions import ( @@ -26,7 +27,6 @@ ) from urllib3.exceptions import ReadTimeoutError as Urllib3ReadTimeoutError - # Dummy exception classes for requests timeouts if requests is not available class _RequestsReadTimeout(Exception): """Fallback exception class when requests library is not available.""" @@ -711,22 +711,35 @@ def get_guardrail_config(self) -> Optional[Dict[str, str]]: def generate_embedding( self, - text: str, + text: str = "", + image_source: Optional[Union[str, bytes]] = None, model_id: str = "amazon.titan-embed-text-v1", + dimensions: int = 1024, max_retries: Optional[int] = None, ) -> List[float]: """ - Generate an embedding vector for the given text using Amazon Bedrock. + Generate an embedding vector for the given text or image_source using Amazon Bedrock. + At least one of text or the image is required to generate the embedding. + For Titan Multimodal embedding models, you can include both to create an embeddings query vector that averages the resulting text embeddings and image embeddings vectors. + For Nova Multimodal embedding models, exactly one of text or the image must be present, but not both. 
Args: text: The text to generate embeddings for + image_source: The image to generate embeddings for (can be either an S3 URI (s3://bucket/key) or raw image bytes) model_id: The embedding model ID to use (default: amazon.titan-embed-text-v1) max_retries: Optional override for the instance's max_retries setting + dimensions: Length of the output embeddings vector Returns: List of floats representing the embedding vector """ - if not text or not isinstance(text, str): + # requires PIL + from idp_common.image import ( + prepare_image, + prepare_bedrock_image_attachment + ) + + if (not text or not isinstance(text, str)) and (not image_source): # Return an empty vector for empty input return [] @@ -741,12 +754,61 @@ def generate_embedding( # Normalize whitespace and prepare the input text normalized_text = " ".join(text.split()) + # Convert image to base64 + if image_source: + image_bytes = prepare_image(image_source) + image_base64 = base64.b64encode(image_bytes).decode('utf-8') + + dimensions = int(dimensions) + # Prepare the request body based on the model - if "amazon.titan-embed" in model_id: - request_body = json.dumps({"inputText": normalized_text}) + payload_body: Dict[str, Any] = {} + + if "amazon.titan-embed-text" in model_id: + if not normalized_text: + raise ValueError( + "Amazon Titan Text models require a text parameter to generate embeddings for." + ) + payload_body = { + "inputText": normalized_text, + "dimensions": dimensions, + } + elif "amazon.titan-embed-image" in model_id: + payload_body = { + "embeddingConfig": { + "outputEmbeddingLength": dimensions, + } + } + if normalized_text: + payload_body["inputText"] = normalized_text + if image_base64: + payload_body["inputImage"] = image_base64 + elif "amazon.nova-2-multimodal-embeddings" in model_id: + if normalized_text and image_source: + raise ValueError( + "Amazon Nova Multimodal Embedding models require exactly one of text or image parameter, but noth both at the same time." 
+ ) + payload_body = { + "taskType": "SINGLE_EMBEDDING", + "singleEmbeddingParams": { + "embeddingPurpose": "GENERIC_INDEX", + "embeddingDimension": dimensions, + } + } + if normalized_text: + payload_body["singleEmbeddingParams"]["text"] = {"truncationMode": "END", "value": normalized_text} + if image_source: + payload_body["singleEmbeddingParams"].update(prepare_bedrock_image_attachment(image_bytes)) # detect image format + payload_body["singleEmbeddingParams"]["image"]["source"]["bytes"] = image_base64 else: # Default format for other models - request_body = json.dumps({"text": normalized_text}) + if not normalized_text: + raise ValueError( + "Default format requires a text parameter to generate embeddings for." + ) + payload_body = {"text": normalized_text} + + request_body = json.dumps(payload_body) # Call the recursive embedding function return self._generate_embedding_with_retry( @@ -805,6 +867,10 @@ def _generate_embedding_with_retry( # Handle different response formats based on the model if "amazon.titan-embed" in model_id: embedding = response_body.get("embedding", []) + elif "amazon.titan-embed-image" in model_id: + embedding = response_body.get("embedding", []) + elif "amazon.nova-2-multimodal-embeddings" in model_id: + embedding = response_body["embeddings"][0]["embedding"] else: # Default extraction format embedding = response_body.get("embedding", []) diff --git a/lib/idp_common_pkg/idp_common/extraction/service.py b/lib/idp_common_pkg/idp_common/extraction/service.py index 3cd83a9f..2bf21e77 100644 --- a/lib/idp_common_pkg/idp_common/extraction/service.py +++ b/lib/idp_common_pkg/idp_common/extraction/service.py @@ -10,6 +10,7 @@ from __future__ import annotations +import base64 import json import logging import os @@ -433,6 +434,53 @@ def _make_json_serializable(self, obj: Any) -> Any: # Convert non-serializable objects to string representation return str(obj) + def _convert_image_uris_to_bytes_in_content( + self, content: list[dict[str, Any]] + ) 
-> list[dict[str, Any]]: + """ + Convert image URIs back to bytes in content array after Lambda processing. + + Args: + content: Content array from Lambda that may contain image URIs + + Returns: + Content array with image bytes restored + """ + converted_content = [] + + for item in content: + if "image_uri" in item: + image_uri = item["image_uri"] + + # Load image content + if image_uri.startswith("s3://"): + # Direct S3 URI + logger.info(f"Retrieving image {image_uri}") + image_bytes = s3.get_binary_content(image_uri) + else: + raise ValueError( + f"Invalid file path {image_uri} - expecting S3 path" + ) + + converted_item = image.prepare_bedrock_image_attachment(image_bytes) + elif "image_base64" in item: + image_base64 = item["image_base64"] + + # Decode image content + image_bytes = base64.b64decode(image_base64) + + converted_item = image.prepare_bedrock_image_attachment(image_bytes) + elif "image" in item: + # Keep existing image objects as-is + converted_item = item.copy() + else: + # Keep non-image items as-is + converted_item = item.copy() + + converted_content.append(converted_item) + + return converted_content + def _invoke_custom_prompt_lambda( self, lambda_arn: str, payload: dict[str, Any] ) -> dict[str, Any]: @@ -486,6 +534,13 @@ def _invoke_custom_prompt_lambda( logger.error(error_msg) raise Exception(error_msg) + # Convert image URIs to bytes in the response + result["task_prompt_content"] = ( + self._convert_image_uris_to_bytes_in_content( + result["task_prompt_content"] + ) + ) + return result except Exception as e: diff --git a/patterns/pattern-2/template.yaml b/patterns/pattern-2/template.yaml index 6605c8af..2bd9c364 100644 --- a/patterns/pattern-2/template.yaml +++ b/patterns/pattern-2/template.yaml @@ -1026,7 +1026,7 @@ Resources: order: 7 custom_prompt_lambda_arn: type: string - description: "(Optional) ARN of a Lambda function to generate custom extraction prompts. Function name must start with 'GENAIIDP-'. 
If not provided, default prompts will be used. The Lambda function receives the complete config, prompt placeholders, default task prompt content, and serialized document, and returns custom system_prompt and task_prompt_content. Example: arn:${AWS::Partition}:lambda:us-east-1:123456789012:function:GENAIIDP-my-extractor" + description: !Sub "(Optional) ARN of a Lambda function to generate custom extraction prompts. Function name must start with 'GENAIIDP-'. If not provided, default prompts will be used. The Lambda function receives the complete config, prompt placeholders, default task prompt content, and serialized document, and returns custom system_prompt and task_prompt_content. Example: arn:${AWS::Partition}:lambda:us-east-1:123456789012:function:GENAIIDP-my-extractor" order: 8 assessment: order: 5 diff --git a/plugins/dynamic-few-shot-lambda/.gitignore b/plugins/dynamic-few-shot-lambda/.gitignore new file mode 100644 index 00000000..f3c07f0d --- /dev/null +++ b/plugins/dynamic-few-shot-lambda/.gitignore @@ -0,0 +1 @@ +datasets/ \ No newline at end of file diff --git a/plugins/dynamic-few-shot-lambda/README.md b/plugins/dynamic-few-shot-lambda/README.md new file mode 100644 index 00000000..ec58b3f6 --- /dev/null +++ b/plugins/dynamic-few-shot-lambda/README.md @@ -0,0 +1,454 @@ +# Dynamic Few-Shot Prompting Lambda - Complete Guide + +This directory contains the **complete implementation** of the dynamic few-shot prompting Lambda function for GenAI IDP Accelerator. This Lambda function integrates with Pattern 2 extraction as a custom prompt generator, dynamically retrieving similar examples using S3 Vectors similarity search to improve extraction accuracy. 
+ +## 🎯 Overview + +The dynamic few-shot prompting Lambda function allows you to: + +- **Dynamically retrieve similar examples** based on document content using vector similarity search +- **Automatically inject few-shot examples** into extraction prompts using the `{FEW_SHOT_EXAMPLES}` placeholder +- **Leverage S3 Vectors** for efficient similarity search across large example datasets +- **Integrate multimodal embeddings** using Amazon Nova models for image-based similarity +- **Seamlessly integrate** with existing IDP extraction workflows as a custom prompt Lambda + +## 📁 Files in This Directory + +- **`src/GENAIIDP-dynamic-few-shot.py`** - Dynamic few-shot Lambda function with S3 Vectors lookup +- **`src/requirements.txt`** - Python dependencies for the Lambda function +- **`template.yml`** - CloudFormation SAM template to deploy the Lambda function +- **`README.md`** - This comprehensive documentation and guide + +## 🏗️ Architecture + +```mermaid +flowchart TD + A[IDP Document Processing] --> B{Custom Prompt Lambda ARN configured?} + B -->|No| C[Use Default Task Prompt] + B -->|Yes| D[Invoke Dynamic Few-Shot Lambda] + + subgraph "Lambda Function: GENAIIDP-dynamic-few-shot" + D --> E[Receive IDP Context & Placeholders] + E --> F[Extract Document Images from DOCUMENT_IMAGE] + F --> G[Generate Nova Multimodal Embeddings] + G --> H[Query S3 Vectors Index] + H --> I[Filter by Distance Threshold] + I --> J[Merge & Deduplicate Results] + J --> K[Load Example Images from S3] + K --> L[Build Prompt Content Array] + L --> M[Replace FEW_SHOT_EXAMPLES Placeholder] + end + + M --> N[Return Modified Task Prompt Content] + C --> O[Continue with Bedrock Extraction] + N --> O + + subgraph "Input Payload" + P[config: IDP Configuration] + Q[prompt_placeholders: DOCUMENT_TEXT, DOCUMENT_CLASS, etc.] 
+ R[default_task_prompt_content: Original prompt] + S[serialized_document: Document metadata] + end + + subgraph "Output Payload" + T[system_prompt: Unchanged] + U[task_prompt_content: Array with Prompt segments and Example images] + end + + D -.-> P + D -.-> Q + D -.-> R + D -.-> S + + N -.-> T + N -.-> U + + subgraph "S3 Vectors Infrastructure" + X[Vector Bucket: Encrypted storage] + Y[Vector Index: 3072-dim cosine similarity] + Z[Metadata: classPrompt, attributesPrompt, imagePath] + end + + H -.-> X + H -.-> Y + H -.-> Z +``` + +## Quick Start + +### Step 1: Deploy the Dynamic Few-Shot Stack + +```bash +# Navigate to the dynamic-few-shot-lambda directory +cd plugins/dynamic-few-shot-lambda + +# Deploy using AWS SAM +sam deploy --guided +``` + +### Step 2: Get the Lambda ARN + +After deployment, get the ARN from CloudFormation outputs: + +```bash +aws cloudformation describe-stacks \ + --stack-name GENAIIDP-dynamic-few-shot-stack \ + --query 'Stacks[0].Outputs[?OutputKey==`DynamicFewShotFunctionArn`].OutputValue' \ + --output text +``` + +### Step 3: Populate the Examples Dataset + +Use the [fewshot_dataset_import.ipynb](notebooks/fewshot_dataset_import.ipynb) notebook to import a dataset into S3 Vectors, or manually upload your example documents and metadata to the S3 bucket and vector index created by the stack. + +### Step 4: Configure IDP to Use Dynamic Few-Shot + +Add the Lambda ARN to your IDP extraction configuration: + +```yaml +extraction: + custom_prompt_lambda_arn: "arn:aws:lambda:region:account:function:GENAIIDP-dynamic-few-shot" +``` + +**Important**: Your extraction task prompt must include the `{FEW_SHOT_EXAMPLES}` placeholder where you want the dynamic examples to be inserted. + +### Step 5: Run the Demo Notebook + +0. Run `notebooks/examples` steps 0, 1, 2 +1. Open `notebooks/examples/step3_extraction_with_custom_lambda.ipynb`. In section 3, set `DEMO_LAMBDA_ARN` to `arn:aws:lambda:region:account:function:GENAIIDP-dynamic-few-shot` +2. 
Run all cells to see the comparison + +## Lambda Interface + +### Input Payload Structure + +The Lambda receives the full IDP context as a custom prompt Lambda: + +```json +{ + "config": { + "extraction": {...}, + "classes": [...], + ... + }, + "prompt_placeholders": { + "DOCUMENT_TEXT": "Full OCR text from all pages", + "DOCUMENT_CLASS": "invoice", + "ATTRIBUTE_NAMES_AND_DESCRIPTIONS": "LineItems: List of line items in the invoice...", + "DOCUMENT_IMAGE": ["s3://bucket/document/page1.jpg", "s3://bucket/document/page2.jpg"] + }, + "default_task_prompt_content": [ + {"text": "Resolved default task prompt..."}, + {"image_uri": "s3://..."}, // if images present + {"cachePoint": true} // if cache points present + ], + "serialized_document": { + "id": "document-123", + "input_bucket": "my-bucket", + "pages": {...}, + "sections": [...], + ... + } +} +``` + +### Output Payload Structure + +The Lambda returns modified prompt content with dynamic few-shot examples: + +```json +{ + "system_prompt": "Custom system prompt text", + "task_prompt_content": [ + {"text": "Extract the following attributes from this invoice document:\n\nLineItems: List of line items in the invoice...\n\n"}, + {"text": "expected attributes are:\n \"invoice_number\": \"INV-2024-001\",\n \"total_amount\": \"$1,250.00\""}, + {"image_uri": "s3://examples-bucket/invoices/example-001/page1.jpg"}, + {"text": "\n\n<<CACHEPOINT>>\n\nDocument content:\nINVOICE\nInvoice #: INV-2024-002..."} + ] +} +``` + +## Core Functionality + +### 1. Custom Prompt Integration + +The Lambda integrates with IDP's custom prompt system by: +- Receiving the full extraction context and configuration +- Processing the `{FEW_SHOT_EXAMPLES}` placeholder in task prompts +- Returning modified prompt content with dynamically retrieved examples + +### 2. 
Vector Similarity Search + +The Lambda uses Amazon Nova multimodal embeddings to find similar examples: + +```python +# Generate embedding from document image +embedding = bedrock_client.generate_embedding( + image_source=page_image, + model_id=MODEL_ID, + dimensions=S3VECTOR_DIMENSIONS, +) + +# Query S3 Vectors for similar examples +response = s3vectors.query_vectors( + vectorBucketName=S3VECTOR_BUCKET, + indexName=S3VECTOR_INDEX, + queryVector={"float32": embedding}, + topK=TOP_K, + returnDistance=True, + returnMetadata=True +) +``` + +### 3. Example Merging and Deduplication + +Multiple document images are processed and results are merged to avoid duplicates: + +```python +def _merge_examples(examples, new_examples): + """Merge examples, keeping the best similarity score for duplicates""" + for new_example in new_examples: + key = new_example["key"] + new_distance = new_example.get("distance", 1.0) + + if examples.get(key): + existing_distance = examples[key].get("distance", 1.0) + examples[key]["distance"] = min(new_distance, existing_distance) +``` + +### 4. Prompt Content Building + +The Lambda builds structured prompt content handling multiple placeholders: + +```python +def _build_prompt_content(prompt_template, substitutions, image_content): + """ + Build prompt content array handling FEW_SHOT_EXAMPLES and DOCUMENT_IMAGE placeholders. + + Handles: + - {FEW_SHOT_EXAMPLES}: Inserts few-shot examples from S3 Vectors + - {DOCUMENT_IMAGE}: Inserts images at specific location + - Regular text placeholders: DOCUMENT_TEXT, DOCUMENT_CLASS, etc. + """ +``` + +## Configuration + +### Environment Variables + +The Lambda function uses these environment variables (set by the CloudFormation template): + +- `S3VECTOR_BUCKET` - Name of the S3 Vectors bucket +- `S3VECTOR_INDEX` - Name of the S3 Vectors index +- `S3VECTOR_DIMENSIONS` - Embedding dimensions (e.g. `3072` for Nova Multimodal Embedding model) +- `MODEL_ID` - Bedrock model ID for embeddings (e.g. 
`amazon.nova-2-multimodal-embeddings-v1:0`) +- `TOP_K` - Number of similar examples to retrieve (default: 3) +- `THRESHOLD` - Maximum distance threshold for filtering results (default: 0.5) +- `LOG_LEVEL` - Logging level (default: INFO) + +### S3 Vectors Configuration + +The stack creates: +- **Vector Bucket**: Encrypted S3 bucket for vector storage +- **Vector Index**: Cosine similarity index with 3072 dimensions +- **Metadata Configuration**: Stores `classPrompt`, `attributesPrompt`, and `imagePath` as non-filterable metadata keys + +## Monitoring and Troubleshooting + +### CloudWatch Logs + +Monitor the Lambda function logs: +- `/aws/lambda/GENAIIDP-dynamic-few-shot` - Dynamic few-shot Lambda logs + +### Key Log Messages + +**Successful Operation:** +``` +=== DYNAMIC FEW-SHOT LAMBDA INVOKED === +=== EXTRACTION CONFIG === +Model: anthropic.claude-3-5-sonnet-20241022-v2:0 +=== HANDLE INPUT DOCUMENT === +=== OUTPUT ANALYSIS === +Output keys: ['system_prompt', 'task_prompt_content'] +Task prompt content items: 5 +=== DYNAMIC FEW-SHOT LAMBDA COMPLETED === +``` + +**Error Conditions:** +``` +Failed to parse environment variables: KeyError('S3VECTOR_BUCKET') +Skipping example with empty attributesPrompt: example_key +Skipping example with distance 0.8 above threshold 0.5: example_key +Invalid file path /local/path - expecting S3 URI +``` + +### Performance Monitoring + +Key metrics to monitor: +- **Lambda Duration**: Time to retrieve and process examples +- **S3 Vectors Query Time**: Vector similarity search performance +- **Example Count**: Number of examples returned per request +- **Error Rate**: Failed example retrievals + +## Example Dataset Structure + +### Vector Metadata Format + +Each vector in the S3 Vectors index should have metadata: + +```json +{ + "classLabel": "invoice", + "classPrompt": "This is an example of the class 'invoice'", + "attributesPrompt": "Expected attributes are: invoice_number [Unique identifier], invoice_date [Invoice date], 
total_amount [Total amount]...", + "imagePath": "s3://examples-bucket/invoices/example-001/" +} +``` + +### Image Storage Structure + +Example images should be stored in S3 with paths referenced in metadata: + +``` +s3://examples-bucket/ +├── invoices/ +│ ├── example-001/ +│ │ ├── page-1.jpg +│ │ └── page-2.jpg +│ └── example-002/ +│ └── invoice.png +└── receipts/ + ├── example-003/ + │ └── receipt.jpg + └── example-004/ + └── receipt.png +``` + +## Production Considerations + +### 1. Example Dataset Management + +- **Quality Control**: Ensure high-quality, representative examples +- **Regular Updates**: Keep examples current with document variations +- **Metadata Consistency**: Maintain consistent attribute descriptions +- **Image Optimization**: Use appropriate image formats and sizes + +### 2. Performance Optimization + +```python +# Cache frequently accessed examples +# Optimize vector dimensions for your use case +# Use appropriate TOP_K values (typically 2-5) +# Consider batch processing for multiple documents +``` + +### 3. Security Considerations + +- **Access Control**: Restrict access to example datasets +- **Data Privacy**: Ensure examples don't contain sensitive information +- **Encryption**: Use appropriate encryption for stored examples +- **Audit Logging**: Log example usage for compliance + +### 4. 
Cost Optimization + +- **Vector Index Size**: Monitor storage costs for large example sets +- **Embedding Generation**: Optimize frequency of embedding updates +- **Lambda Memory**: Right-size memory allocation based on usage +- **S3 Storage Classes**: Use appropriate storage classes for examples + +## Deployment Options + +### Option 1: AWS SAM (Recommended) +```bash +sam build +sam deploy --guided +``` + +### Option 2: AWS CLI +```bash +# Package and deploy +aws cloudformation package \ + --template-file template.yml \ + --s3-bucket your-deployment-bucket \ + --output-template-file packaged-template.yml + +aws cloudformation deploy \ + --template-file packaged-template.yml \ + --stack-name GENAIIDP-dynamic-few-shot-stack \ + --capabilities CAPABILITY_IAM +``` + +## Cleanup + +To remove the dynamic few-shot resources: + +```bash +# Delete the CloudFormation stack +aws cloudformation delete-stack --stack-name GENAIIDP-dynamic-few-shot-stack + +# Note: S3 buckets with retention policy will be retained +``` + +## Integration with IDP + +### Configuration in IDP Stack + +Add the dynamic few-shot Lambda ARN to your IDP extraction configuration: + +```yaml +extraction: + custom_prompt_lambda_arn: "arn:aws:lambda:region:account:function:GENAIIDP-dynamic-few-shot" +``` + +### Required Task Prompt Configuration + +**Critical**: Your extraction task prompt must include the `{FEW_SHOT_EXAMPLES}` placeholder where you want the dynamic examples to be inserted. The Lambda specifically looks for this placeholder and replaces it with retrieved examples. + +### Expected Behavior + +When configured: +1. IDP processes document and extracts images/text +2. IDP invokes the dynamic few-shot Lambda with full extraction context +3. Lambda generates embeddings from document images using Amazon Nova +4. Lambda queries S3 Vectors to find similar examples +5. Lambda loads example images and metadata from S3 +6. 
Lambda builds modified prompt content with examples inserted at the `{FEW_SHOT_EXAMPLES}` location +7. IDP uses the modified prompt content for Bedrock extraction +8. Bedrock uses the dynamic examples to improve extraction accuracy + +### Prompt Flow Example + +**Original Task Prompt:** +``` +Extract attributes from this invoice: +{ATTRIBUTE_NAMES_AND_DESCRIPTIONS} +{FEW_SHOT_EXAMPLES} +<<CACHEPOINT>> +Document: {DOCUMENT_TEXT} +``` + +**After Lambda Processing:** +``` +Extract attributes from this invoice: +invoice_number [Unique identifier]... + +expected attributes are: + "invoice_number": "INV-2024-001", + "total_amount": "$1,250.00" +[Example image content] + +<<CACHEPOINT>> +Document: INVOICE #INV-2024-002... +``` + +## Next Steps + +After deploying the dynamic few-shot Lambda: + +1. **Populate example dataset** with representative documents +2. **Test similarity search** with sample documents +3. **Monitor performance** and adjust TOP_K as needed +4. **Integrate with IDP** using the Lambda ARN +5. **Evaluate accuracy improvements** with few-shot examples + +The dynamic few-shot Lambda enables powerful few-shot learning capabilities while leveraging efficient vector similarity search for dynamic example selection. 
\ No newline at end of file diff --git a/plugins/dynamic-few-shot-lambda/notebooks/fcc_invoices_dataset_import.ipynb b/plugins/dynamic-few-shot-lambda/notebooks/fcc_invoices_dataset_import.ipynb new file mode 100644 index 00000000..2dc1fdce --- /dev/null +++ b/plugins/dynamic-few-shot-lambda/notebooks/fcc_invoices_dataset_import.ipynb @@ -0,0 +1,761 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FCC Invoices Dataset Import to S3 Vector store\n", + "\n", + "This notebook demonstrates how to import the FCC invoices (REALKIE) dataset into S3 Vectors for use with the dynamic few-shot Lambda function.\n", + "\n", + "The FCC invoices dataset contains invoice documents that can be used as few-shot examples for document extraction tasks.\n", + "\n", + "## Process Overview:\n", + "\n", + "1. **Load FCC Invoices Dataset** - Sync and load the dataset using load_dataset()\n", + "2. **Generate Embeddings** - Create multimodal embeddings using Amazon Nova\n", + "3. **Upload to S3 Vectors** - Store embeddings and metadata in S3 Vectors index\n", + "4. **Verify Import** - Test similarity search functionality\n", + "\n", + "> **Note**: This notebook requires AWS credentials with permissions for Bedrock, S3, and S3 Vectors services." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
Install Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's make sure that modules are autoreloaded\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "ROOTDIR=\"../../../\"\n", + "# First uninstall existing package (to ensure we get the latest version)\n", + "%pip uninstall -y idp_common\n", + "\n", + "# Install the IDP common package with all components in development mode\n", + "%pip install -q -e \"{ROOTDIR}/lib/idp_common_pkg[dev, all]\"\n", + "\n", + "# Note: We can also install specific components like:\n", + "# %pip install -q -e \"{ROOTDIR}/lib/idp_common_pkg[ocr,classification,extraction,evaluation]\"\n", + "\n", + "# Check installed version\n", + "%pip show idp_common | grep -E \"Version|Location\"\n", + "\n", + "# Install required packages\n", + "%pip install -q pillow tqdm pandas datasets matplotlib\n", + "\n", + "# Optionally use a .env file for environment variables\n", + "try:\n", + " from dotenv import load_dotenv\n", + " load_dotenv() \n", + "except ImportError:\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Import Libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import subprocess\n", + "from pathlib import Path\n", + "from typing import Dict, List, Any\n", + "from tqdm import tqdm\n", + "import pandas as pd\n", + "import io\n", + "\n", + "import boto3\n", + "from datasets import load_dataset\n", + "\n", + "# Import IDP common modules\n", + "from idp_common import bedrock\n", + "\n", + "print(\"Libraries imported successfully\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Configure S3 Vectors and Bedrock" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configuration - Update these values from the IDP stack in CloudFormation Resources tab\n", + "GENAIIDP_S3_WORKING_BUCKET = \"\" # From IDP stack Resources tab\n", + "\n", + "S3_VECTORS_BUCKET = \"genaiidp-dynamic-few-shot\"\n", + "S3_VECTORS_INDEX = \"documents\"\n", + "EMBEDDING_MODEL_ID = \"amazon.nova-2-multimodal-embeddings-v1:0\"\n", + "EMBEDDING_DIMENSIONS = 3072\n", + "\n", + "# Initialize clients\n", + "s3vectors_client = boto3.client('s3vectors')\n", + "s3_client = boto3.client('s3')\n", + "bedrock_client = bedrock.BedrockClient()\n", + "\n", + "print(f\"Configured for dataset S3 Bucket: {GENAIIDP_S3_WORKING_BUCKET}\")\n", + "print(f\"Configured for S3 Vectors bucket: {S3_VECTORS_BUCKET}\")\n", + "print(f\"Configured for S3 Vectors index: {S3_VECTORS_INDEX}\")\n", + "print(f\"Using embedding model: {EMBEDDING_MODEL_ID}\")\n", + "print(f\"Using embedding dimensions: {EMBEDDING_DIMENSIONS}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 
Load FCC Invoices Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sync FCC invoices dataset from S3\n", + "print(\"Syncing FCC invoices dataset from S3...\")\n", + "\n", + "# Configuration for this dataset\n", + "CLASS_LABEL = 'Invoice'\n", + "\n", + "# Create datasets directory\n", + "dataset_root_dir = Path('../datasets')\n", + "dataset_root_dir.mkdir(exist_ok=True)\n", + "\n", + "# Dataset directory\n", + "dataset_dir = dataset_root_dir / 'fcc_invoices'\n", + "\n", + "# Sync dataset from S3 using AWS CLI with Wasabi endpoint\n", + "if not dataset_dir.exists() or not any(dataset_dir.iterdir()):\n", + " print(\"Syncing dataset from S3...\")\n", + " sync_command = [\n", + " 'aws', 's3', 'sync',\n", + " 's3://project-fruitfly/fcc_invoices',\n", + " str(dataset_dir),\n", + " '--endpoint-url=https://s3.us-east-2.wasabisys.com',\n", + " '--no-sign-request'\n", + " ]\n", + " \n", + " try:\n", + " result = subprocess.run(sync_command, capture_output=True, text=True, check=True)\n", + " print(f\"Dataset synced successfully to {dataset_dir}\")\n", + " print(f\"Sync output: {result.stdout}\")\n", + " except subprocess.CalledProcessError as e:\n", + " print(f\"Error syncing dataset: {e}\")\n", + " print(f\"Error output: {e.stderr}\")\n", + " raise\n", + "else:\n", + " print(f\"Using existing dataset at {dataset_dir}\")\n", + "\n", + "# Load the training dataset using load_dataset\n", + "print(\"Loading training dataset...\")\n", + "try:\n", + " # Load dataset from local directory\n", + " dataset = load_dataset('csv', data_dir=str(dataset_dir), split='train')\n", + " print(f\"Loaded dataset with {len(dataset)} samples\")\n", + " \n", + " # Show sample information\n", + " if len(dataset) > 0:\n", + " sample = dataset[0]\n", + " print(f\"Sample keys: {list(sample.keys())}\")\n", + " if 'image' in sample:\n", + " print(f\"Sample image size: {sample['image'].size}\")\n", + " \n", + "except Exception as 
import json
import mimetypes
import re
from typing import Dict, Optional

# Two-letter day codes come first so the regex prefers "Th"/"Su" over "T"/"S".
_DAY_TOKEN_RE = re.compile(r"Th|Su|[MTWFS1-7]")
_DAY_CODES = {
    'M': 'M', 'T': 'T', 'W': 'W', 'Th': 'Th', 'F': 'F', 'S': 'S', 'Su': 'Su',
    '1': 'M', '2': 'T', '3': 'W', '4': 'Th', '5': 'F', '6': 'S', '7': 'Su',
}
# Labels whose `start` values are closer than this are treated as one line item.
_LINE_ITEM_GROUP_RANGE = 1000
_LINE_ITEM_PREFIX = 'Line Item - '

# Dataset label name -> key in the parsed result.
_TEXT_FIELDS = {
    'Agency': 'Agency',
    'Advertiser': 'Advertiser',
    'Payment Terms': 'PaymentTerms',
}
_AMOUNT_FIELDS = {
    'Gross Total': 'GrossTotal',
    'Net Amount Due': 'NetAmountDue',
    'Agency Commission': 'AgencyCommission',
}


def upload_image_to_s3(image_bytes: bytes, s3_key: str, content_type: Optional[str] = None) -> str:
    """Upload image bytes to the working bucket and return the object's S3 URI.

    Args:
        image_bytes: Raw image content.
        s3_key: Object key inside GENAIIDP_S3_WORKING_BUCKET.
        content_type: Optional MIME type. When omitted it is guessed from the
            key's extension, falling back to 'image/jpeg'.

    Returns:
        The "s3://bucket/key" URI of the uploaded object.
    """
    if content_type is None:
        guessed_type, _ = mimetypes.guess_type(s3_key)
        if guessed_type and guessed_type.startswith('image/'):
            content_type = guessed_type
        else:
            content_type = 'image/jpeg'

    s3_client.put_object(
        Bucket=GENAIIDP_S3_WORKING_BUCKET,
        Key=s3_key,
        Body=image_bytes,
        ContentType=content_type,
    )
    return f"s3://{GENAIIDP_S3_WORKING_BUCKET}/{s3_key}"


def load_csv_labels():
    """Load train.csv (labels and metadata) from the dataset directory.

    Returns:
        The loaded DataFrame, or None when the file is missing or unreadable.
    """
    csv_path = dataset_dir / 'train.csv'
    if not csv_path.exists():
        print(f"CSV file not found at {csv_path}")
        return None

    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        print(f"Error loading CSV: {e}")
        return None

    print(f"Loaded CSV with {len(df)} rows")
    return df


def match_image_to_csv_row(image_path: str, csv_df: "Optional[pd.DataFrame]"):
    """Return the first CSV row whose 'image_files' value mentions the image's file name.

    Args:
        image_path: Path of the image; only its file name is used for matching.
        csv_df: DataFrame returned by load_csv_labels(), or None.

    Returns:
        The matching row, or None when nothing matches.
    """
    if csv_df is None:
        return None

    image_name = Path(image_path).name

    for _, row in csv_df.iterrows():
        image_files_str = row.get('image_files', '')
        # Empty CSV cells are loaded by pandas as NaN (a float), which does not
        # support the `in` operator - skip anything that is not a string.
        if isinstance(image_files_str, str) and image_name in image_files_str:
            return row

    return None


def get_image_bytes_from_file(image_path):
    """Read an image file and return its raw bytes."""
    with open(image_path, 'rb') as f:
        return f.read()


def create_sample_attributes_prompt() -> str:
    """Return a static example attributes prompt matching the FCC invoices schema.

    Used as a fallback when a sample has no parsable ground-truth labels.
    """
    attributes_prompt = """expected attributes are:
        "Agency": "Great American Media",
        "Advertiser": "ISS/HOUSE MAJ PAC", 
        "GrossTotal": 94700.00,
        "PaymentTerms": "Cash In Advance",
        "AgencyCommission": 14205.00,
        "NetAmountDue": 80495.00,
        "LineItems": [
            {
                "LineItemDescription": "TODAY IN FLORIDA @9PM",
                "LineItemStartDate": "10/18/2016", 
                "LineItemEndDate": null,
                "LineItemDays": ["T"],
                "LineItemRate": 500.00
            },
            {
                "LineItemDescription": "CH 7 NEWS @ 10PM",
                "LineItemStartDate": "10/18/2016",
                "LineItemEndDate": null, 
                "LineItemDays": ["T"],
                "LineItemRate": 3200.00
            }
        ]
    """.strip()
    return attributes_prompt


def _parse_amount(text):
    """Convert a currency string such as "$1,234.50" to float.

    Values that are not strings, or that do not parse as a number, are returned
    unchanged so no ground-truth information is lost.
    """
    if not isinstance(text, str):
        return text
    try:
        return float(text.replace(',', '').replace('$', ''))
    except ValueError:
        return text


def _parse_days(text) -> list:
    """Extract the distinct day codes from a days string, in order of appearance.

    Recognises M, T, W, Th, F, S, Su and the digits 1-7; any other character
    (for example '-') is ignored.
    """
    if not isinstance(text, str):
        return []
    days = []
    for token in _DAY_TOKEN_RE.findall(text):
        day = _DAY_CODES[token]
        if day not in days:
            days.append(day)
    return days


def parse_ground_truth_labels(labels_json_str: str) -> Optional[Dict]:
    """Parse ground-truth labels from the dataset into the expected extraction schema.

    Args:
        labels_json_str: JSON string holding a list of label objects with
            'label', 'text' and 'start' entries.

    Returns:
        A dict with the top-level invoice fields and a 'LineItems' list, or None
        when the input is not valid JSON or is not a JSON list.
    """
    try:
        labels = json.loads(labels_json_str)
    except (json.JSONDecodeError, TypeError):
        return None
    if not isinstance(labels, list):
        return None

    result = {
        "Agency": None,
        "Advertiser": None,
        "GrossTotal": None,
        "PaymentTerms": None,
        "AgencyCommission": None,
        "NetAmountDue": None,
        "LineItems": []
    }

    # Line item fields are grouped by proximity of their `start` value
    # (presumably a character offset in the document text - confirm against the dataset).
    line_items = {}

    for label in labels:
        if not isinstance(label, dict):
            continue
        label_type = label.get('label', '')
        text = label.get('text', '')
        if not isinstance(label_type, str):
            continue

        if label_type in _TEXT_FIELDS:
            result[_TEXT_FIELDS[label_type]] = text
        elif label_type in _AMOUNT_FIELDS:
            result[_AMOUNT_FIELDS[label_type]] = _parse_amount(text)
        elif label_type.startswith(_LINE_ITEM_PREFIX):
            field_name = label_type[len(_LINE_ITEM_PREFIX):]
            start_pos = label.get('start') or 0

            # Attach the field to the nearest existing group within range,
            # otherwise open a new group keyed by this label's position.
            in_range = [
                key for key in line_items
                if abs(start_pos - key) < _LINE_ITEM_GROUP_RANGE
            ]
            if in_range:
                group_key = min(in_range, key=lambda key: abs(start_pos - key))
            else:
                group_key = start_pos
                line_items[group_key] = {}
            item = line_items[group_key]

            if field_name == 'Description':
                item['LineItemDescription'] = text
            elif field_name == 'Start Date':
                item['LineItemStartDate'] = text
            elif field_name == 'End Date':
                item['LineItemEndDate'] = text if text else None
            elif field_name == 'Rate':
                item['LineItemRate'] = _parse_amount(text)
            elif field_name == 'Days':
                item['LineItemDays'] = _parse_days(text)

    result['LineItems'] = list(line_items.values())

    return result


def create_metadata(s3_image_uri: str, sample_data: Optional[Dict] = None) -> Dict:
    """Create the metadata stored with an S3 Vectors entry.

    Args:
        s3_image_uri: S3 URI of the uploaded example image.
        sample_data: Optional dict; when it has a 'labels' entry that parses,
            the attributes prompt is built from those ground-truth labels,
            otherwise the static sample prompt is used.

    Returns:
        Dict with classLabel, classPrompt, attributesPrompt and imagePath.
    """
    class_prompt = f"This is an example of the class '{CLASS_LABEL}'"

    parsed_labels = None
    if sample_data and 'labels' in sample_data:
        parsed_labels = parse_ground_truth_labels(sample_data['labels'])

    if parsed_labels:
        attributes_prompt = f"expected attributes are: {json.dumps(parsed_labels, indent=2)}"
    else:
        attributes_prompt = create_sample_attributes_prompt()

    return {
        "classLabel": CLASS_LABEL,
        "classPrompt": class_prompt,
        "attributesPrompt": attributes_prompt,
        "imagePath": s3_image_uri,
    }


print("Helper functions defined")
import json

# Process a subset of the dataset (adjust as needed)
MAX_SAMPLES = 250  # Adjust this number based on your needs
BATCH_SIZE = 15  # Adjust this number based on your needs

# Load the CSV labels (this contains the image_files information)
csv_df = load_csv_labels()
if csv_df is None:
    print("Failed to load CSV data. Exiting.")
    raise Exception("CSV loading failed")

samples_to_process = min(MAX_SAMPLES, len(csv_df))
print(f"Processing {samples_to_process} samples from FCC invoices CSV data...")

vectors_to_upload = []  # vectors buffered for the next put_vectors call
pending_samples = []    # sample indices the buffered vectors belong to
failed_samples = []


def upload_pending_vectors(description: str) -> None:
    """Upload the buffered vectors to S3 Vectors and empty the buffer.

    If the upload fails, every sample in the buffer is recorded as failed so the
    final summary reflects what actually reached the index. The buffer is
    cleared either way so one bad batch cannot grow without bound.

    Args:
        description: Human-readable batch name used in the progress messages
            ("batch" or "final batch").
    """
    if not vectors_to_upload:
        return

    print(f"\nUploading {description} of {len(vectors_to_upload)} vectors...")
    try:
        response = s3vectors_client.put_vectors(
            vectorBucketName=S3_VECTORS_BUCKET,
            indexName=S3_VECTORS_INDEX,
            vectors=vectors_to_upload
        )
        status = response.get('ResponseMetadata', {}).get('HTTPStatusCode')
        print(f"{description.capitalize()} upload response: {status}")
    except Exception as e:
        print(f"\nFailed to upload {description}: {e}")
        failed_samples.extend(pending_samples)
    finally:
        vectors_to_upload.clear()
        pending_samples.clear()


for i in tqdm(range(samples_to_process), desc="Processing samples"):
    try:
        csv_row = csv_df.iloc[i]

        # Get image files from the CSV row (empty cells arrive as NaN, not str)
        image_files_str = csv_row.get('image_files', '')
        if not isinstance(image_files_str, str) or not image_files_str:
            print(f"No image files found for sample {i}")
            failed_samples.append(i)
            continue

        # Parse the image files array (it's stored as a JSON string)
        try:
            image_files = json.loads(image_files_str)
        except json.JSONDecodeError:
            print(f"Failed to parse image_files for sample {i}: {image_files_str}")
            failed_samples.append(i)
            continue

        # Use the first image file (or you could process all images)
        if not image_files:
            print(f"Empty image_files array for sample {i}")
            failed_samples.append(i)
            continue

        image_file_path = image_files[0]
        full_image_path = dataset_root_dir / image_file_path

        if not full_image_path.exists():
            print(f"Image file not found: {full_image_path}")
            failed_samples.append(i)
            continue

        # Load image file as bytes
        image_bytes = get_image_bytes_from_file(full_image_path)

        # Upload image to S3
        s3_key = f"fcc_invoices/sample_{i:06d}.jpg"
        s3_image_uri = upload_image_to_s3(image_bytes, s3_key)

        # Generate embedding
        embedding = bedrock_client.generate_embedding(
            image_source=image_bytes,
            model_id=EMBEDDING_MODEL_ID,
            dimensions=EMBEDDING_DIMENSIONS
        )

        # Create metadata using the CSV row data
        sample_data = {'labels': csv_row.get('labels')}
        metadata = create_metadata(s3_image_uri, sample_data)

        # Prepare vector for upload
        vector_entry = {
            "key": f"fcc_invoices_sample_{i:06d}",
            "data": {"float32": embedding},
            "metadata": metadata
        }
    except Exception as e:
        print(f"\nFailed to process sample {i}: {e}")
        failed_samples.append(i)
        continue
    else:
        vectors_to_upload.append(vector_entry)
        pending_samples.append(i)

    # Upload in batches to avoid memory issues
    if len(vectors_to_upload) >= BATCH_SIZE:
        upload_pending_vectors("batch")

# Upload remaining vectors
upload_pending_vectors("final batch")

print("\nImport completed!")
print(f"Successfully processed: {samples_to_process - len(failed_samples)} samples from CSV data")
print(f"Failed samples: {len(failed_samples)}")
if failed_samples:
    # Show the first 10; only add an ellipsis when something was actually cut off
    ellipsis = "..." if len(failed_samples) > 10 else ""
    print(f"Failed sample indices: {failed_samples[:10]}{ellipsis}")
import json

# Start from a clean slate so nothing left over from an earlier cell (for
# example a put_vectors `response`) can be mistaken for search results.
test_sample_index = 0
test_image_bytes = None
test_embedding = None
response = None


def format_distance(distance, precision: int) -> str:
    """Format a numeric distance with the given precision, or 'N/A' when it is missing."""
    if isinstance(distance, (int, float)):
        return f"{distance:.{precision}f}"
    return 'N/A'


# Load test split for similarity search verification
test_dataset = load_dataset('csv', data_dir=str(dataset_dir), split='test')

if test_dataset is not None and len(test_dataset) > 0:
    print(f"Loaded test dataset with {len(test_dataset)} samples")

    # Use the first sample from test split
    test_csv_row = test_dataset[test_sample_index]

    # Get test image from CSV row
    test_image_files_str = test_csv_row.get('image_files', '')
    if not test_image_files_str:
        print("No image_files found in test sample")
    else:
        try:
            test_image_files = json.loads(test_image_files_str)
            first_test_image = test_image_files[0] if test_image_files else None
        except (json.JSONDecodeError, TypeError, IndexError, KeyError) as e:
            print(f"Failed to parse test image_files: {e}")
        else:
            if first_test_image is None:
                print("Empty image_files array in test sample")
            else:
                test_image_path = dataset_root_dir / first_test_image
                if test_image_path.exists():
                    test_image_bytes = get_image_bytes_from_file(test_image_path)
                    print(f"Loaded test image: {first_test_image}")
                else:
                    print(f"Test image file not found: {test_image_path}")
else:
    print("Test split is empty or could not be loaded")

if test_image_bytes is not None:
    print(f"\nTesting similarity search with test sample {test_sample_index}...")

    # Generate embedding for test image
    test_embedding = bedrock_client.generate_embedding(
        image_source=test_image_bytes,
        model_id=EMBEDDING_MODEL_ID,
        dimensions=EMBEDDING_DIMENSIONS
    )
else:
    print("No test image available for similarity search verification.")

if test_embedding is not None:
    # Query S3 Vectors for similar examples
    response = s3vectors_client.query_vectors(
        vectorBucketName=S3_VECTORS_BUCKET,
        indexName=S3_VECTORS_INDEX,
        queryVector={"float32": test_embedding},
        topK=5,
        returnDistance=True,
        returnMetadata=True
    )

    found_vectors = response.get('vectors', [])
    print(f"\nFound {len(found_vectors)} similar examples:")
    for rank, vector in enumerate(found_vectors, start=1):
        metadata = vector.get('metadata') or {}
        attributes_prompt = str(metadata.get('attributesPrompt', 'N/A'))

        print(f"  {rank}. Key: {vector.get('key', 'N/A')}")
        print(f"     Distance: {format_distance(vector.get('distance'), 4)}")
        print(f"     Class Label: {metadata.get('classLabel', 'N/A')}")
        print(f"     Class Prompt: {metadata.get('classPrompt', 'N/A')}")
        print(f"     Attributes Prompt: {attributes_prompt[:100]}...")  # Truncate for readability
        print(f"     Image Path: {metadata.get('imagePath', 'N/A')}")
        print()
else:
    print("Skipping similarity search - no test embedding available.")

# Display source image and found similar images
similar_vectors = response.get('vectors', []) if response else []

if test_image_bytes is not None and similar_vectors:
    # Imported lazily so the cell still runs when there is nothing to plot
    import matplotlib.pyplot as plt
    from PIL import Image as PILImage
    import io

    # Source image plus the top 3 similar images
    top_vectors = similar_vectors[:3]
    num_similar = len(top_vectors)
    total_images = 1 + num_similar

    # total_images is always >= 2 here, so `axes` is an array of Axes
    fig, axes = plt.subplots(1, total_images, figsize=(5 * total_images, 6))

    # Display source image
    source_img = PILImage.open(io.BytesIO(test_image_bytes))
    axes[0].imshow(source_img)
    axes[0].set_title(f'Source Image (Test Sample {test_sample_index})', fontsize=12, fontweight='bold')
    axes[0].axis('off')

    # Display similar images
    for position, vector in enumerate(top_vectors, start=1):
        ax = axes[position]
        distance_label = format_distance(vector.get('distance'), 3)
        image_s3_path = (vector.get('metadata') or {}).get('imagePath', '')

        if not image_s3_path:
            ax.text(0.5, 0.5, 'No image path', ha='center', va='center',
                    transform=ax.transAxes)
        else:
            try:
                # Take bucket and key from the stored URI rather than assuming
                # the image lives in the working bucket.
                bucket_name, _, s3_key = image_s3_path.removeprefix('s3://').partition('/')
                response_obj = s3_client.get_object(Bucket=bucket_name, Key=s3_key)
                image_data = response_obj['Body'].read()
                similar_img = PILImage.open(io.BytesIO(image_data))
                ax.imshow(similar_img)
            except Exception as e:
                # If the image can't be loaded from S3, show a placeholder
                print(f'Error displaying similar image {position}: {e}')
                ax.text(0.5, 0.5, f'Image not available\n{str(e)[:50]}...',
                        ha='center', va='center', transform=ax.transAxes)

        ax.set_title(f'Similar #{position}\nDistance: {distance_label}', fontsize=10)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    print(f'\nDisplayed source image and top {num_similar} similar images from the vector store.')

else:
    print('No images to display - either no test image was loaded or no similar images were found.')
    if test_image_bytes is None:
        print('Reason: No test image available')
    elif response is None:
        print('Reason: No similarity search was performed')
    else:
        print('Reason: No similar images found in vector store')
Summary and Next Steps" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=== Few-shot Dataset Import Summary ===\")\n", + "print(f\"✅ Dataset: FCC Invoices (REALKIE)\")\n", + "print(f\"✅ Samples processed: {samples_to_process - len(failed_samples) if 'samples_to_process' in locals() and 'failed_samples' in locals() else 'N/A'}\")\n", + "print(f\"✅ S3 Vectors Bucket: {S3_VECTORS_BUCKET}\")\n", + "print(f\"✅ S3 Vectors Index: {S3_VECTORS_INDEX}\")\n", + "print(f\"✅ Images stored in: s3://{GENAIIDP_S3_WORKING_BUCKET}/fcc_invoices/\")\n", + "print(f\"✅ Embedding Model: {EMBEDDING_MODEL_ID}\")\n", + "print(f\"✅ Similarity search verified\")\n", + "\n", + "print(\"\\n=== Next Steps ===\")\n", + "print(\"1. ✅ Updated attributes mapping to match actual FCC invoices dataset structure\")\n", + "print(\"2. ✅ Added ground truth label parsing from CSV data\")\n", + "print(\"3. Configure your IDP extraction to use the dynamic few-shot Lambda ARN\")\n", + "print(\"4. Test document processing with few-shot examples!\")\n", + "print(\"5. 
Fine-tune the label parsing logic if needed based on your specific use case\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/plugins/dynamic-few-shot-lambda/samconfig.toml b/plugins/dynamic-few-shot-lambda/samconfig.toml new file mode 100644 index 00000000..ce714fd8 --- /dev/null +++ b/plugins/dynamic-few-shot-lambda/samconfig.toml @@ -0,0 +1,10 @@ +version = 0.1 + +[default.deploy.parameters] +stack_name = "IDP-dynamic-few-shot" +resolve_s3 = true +s3_prefix = "IDP-dynamic-few-shot" +region = "us-east-1" +capabilities = "CAPABILITY_IAM" +disable_rollback = true +image_repositories = [] diff --git a/plugins/dynamic-few-shot-lambda/src/IDP-dynamic-few-shot.py b/plugins/dynamic-few-shot-lambda/src/IDP-dynamic-few-shot.py new file mode 100644 index 00000000..1cf760e5 --- /dev/null +++ b/plugins/dynamic-few-shot-lambda/src/IDP-dynamic-few-shot.py @@ -0,0 +1,425 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +""" +Lambda function to provide examples with ground truth data based on S3 Vectors lookup. 
def lambda_handler(event, context):
    """
    Process a document to find similar examples using S3 Vectors similarity search.
    This function will expand {FEW_SHOT_EXAMPLES} in the extraction prompt to examples
    found in S3 Vectors lookup.

    Args:
        event: Invocation payload. The keys read here are "config" (with an
            "extraction" section holding system_prompt / task_prompt) and
            "prompt_placeholders".
        context: Lambda context object (unused).

    Returns:
        Dict with "system_prompt" and "task_prompt_content".

    Raises:
        Exception: Wraps any failure; the original exception is chained as the cause.
    """
    try:
        logger.info("=== DYNAMIC FEW-SHOT LAMBDA INVOKED ===")
        # Serialising the whole event is wasted work unless DEBUG is enabled
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Complete input event: {json.dumps(event, indent=2)}")

        # Extract key information from the payload
        config = event.get("config", {})
        placeholders = event.get("prompt_placeholders", {})

        # Log extraction config details
        extraction_config = config.get("extraction", {})
        logger.info("=== EXTRACTION CONFIG ===")
        logger.info(f"Model: {extraction_config.get('model', 'Not specified')}")
        logger.info(f"Temperature: {extraction_config.get('temperature', 'Not specified')}")
        logger.info(f"Max tokens: {extraction_config.get('max_tokens', 'Not specified')}")
        logger.info(f"Custom Lambda ARN: {extraction_config.get('custom_prompt_lambda_arn', 'Not specified')}")

        # Default prompts from config
        default_system_prompt = extraction_config.get("system_prompt", "")
        logger.info(f"Default system prompt length: {len(default_system_prompt)} characters")
        default_task_prompt = extraction_config.get("task_prompt", "")
        logger.info(f"Default task prompt length: {len(default_task_prompt)} characters")

        logger.info("=== HANDLE INPUT DOCUMENT ===")

        # Handle input document
        result = _handle_input_document(placeholders, default_system_prompt, default_task_prompt)

        # Log complete output structure
        logger.info("=== OUTPUT ANALYSIS ===")
        logger.info(f"Output keys: {list(result.keys())}")
        logger.info(f"System prompt length: {len(result.get('system_prompt', ''))}")
        logger.info(f"System prompt (first 200 chars): {result.get('system_prompt', '')[:200]}...")

        task_content = result.get('task_prompt_content', [])
        logger.info(f"Task prompt content items: {len(task_content)}")
        for i, item in enumerate(task_content[:3]):  # Log first 3 items
            logger.info(f"Content item {i}: keys={list(item.keys())}")
            if 'text' in item:
                logger.info(f"  Text length: {len(item['text'])} characters")
                logger.info(f"  Text sample (first 150 chars): {item['text'][:150]}...")
            if 'image_uri' in item:
                logger.info(f"  Image URI: {item['image_uri']}")

        if len(task_content) > 3:
            logger.info(f"  ... and {len(task_content) - 3} more content items")

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Complete result output: {json.dumps(result, indent=2)}")
        logger.info("=== DYNAMIC FEW-SHOT LAMBDA COMPLETED ===")
        return result

    except Exception as e:
        # logger.exception records the traceback alongside the banner
        logger.exception("=== DYNAMIC FEW-SHOT LAMBDA ERROR ===")
        logger.error(f"Error type: {type(e).__name__}")
        logger.error(f"Error message: {str(e)}")
        # `event` always exists as a parameter; what can vary is whether it is a dict
        event_keys = list(event.keys()) if isinstance(event, dict) else 'Unknown'
        logger.error(f"Input event keys: {event_keys}")
        # In demo, we'll fail with detailed error info and keep the original cause
        raise Exception(f"Dynamic few-shot Lambda failed: {str(e)}") from e


def _handle_input_document(placeholders, default_system_prompt, default_task_prompt):
    """
    Handle input request and return custom system_prompt and task_prompt_content.

    Args:
        placeholders: Prompt placeholder values from the event (DOCUMENT_TEXT,
            DOCUMENT_CLASS, ATTRIBUTE_NAMES_AND_DESCRIPTIONS, DOCUMENT_IMAGE).
        default_system_prompt: System prompt from the extraction config, returned unchanged.
        default_task_prompt: Task prompt template to expand.

    Returns:
        Dict with "system_prompt" and "task_prompt_content".
    """
    substitutions = {
        name: placeholders.get(name)
        for name in (
            "DOCUMENT_TEXT",
            "DOCUMENT_CLASS",
            "ATTRIBUTE_NAMES_AND_DESCRIPTIONS",
        )
    }
    task_prompt_content = _build_prompt_content(
        default_task_prompt, substitutions, placeholders.get("DOCUMENT_IMAGE")
    )

    return {
        "system_prompt": default_system_prompt,
        "task_prompt_content": task_prompt_content,
    }
def _build_prompt_content(
    prompt_template: str,
    substitutions: dict[str, Any],
    image_content: Any = None,
) -> list[dict[str, Any]]:
    """
    Build prompt content array handling FEW_SHOT_EXAMPLES and DOCUMENT_IMAGE placeholders.

    This consolidated method handles all placeholder types and combinations:
    - {FEW_SHOT_EXAMPLES}: Inserts few-shot examples found via similarity search
    - {DOCUMENT_IMAGE}: Inserts images at specific location
    - Regular text placeholders: DOCUMENT_TEXT, DOCUMENT_CLASS, etc.

    Args:
        prompt_template: The prompt template with optional placeholders
        substitutions: Dictionary of placeholder values
        image_content: Optional document image S3 URIs (inserted only where
            {DOCUMENT_IMAGE} appears; also used as the similarity query)

    Returns:
        List of content items with text and image content properly ordered
    """
    parts = prompt_template.split("{FEW_SHOT_EXAMPLES}")

    if len(parts) == 2:
        before_examples, after_examples = parts
        content: list[dict[str, Any]] = []

        # Process before examples
        content.extend(
            _build_text_and_image_content(before_examples, substitutions, image_content)
        )

        # Add few-shot examples
        content.extend(_build_few_shot_examples_content(image_content))

        # Process after examples (only pass images if not already used)
        image_for_after = (
            None if "{DOCUMENT_IMAGE}" in before_examples else image_content
        )
        content.extend(
            _build_text_and_image_content(after_examples, substitutions, image_for_after)
        )

        return content

    # Zero or several FEW_SHOT_EXAMPLES placeholders: just handle text and images
    if len(parts) == 1:
        logger.warning("Missing {FEW_SHOT_EXAMPLES} placeholder in prompt template")
    else:
        logger.warning(
            "Invalid {FEW_SHOT_EXAMPLES} placeholder usage - expected exactly one occurrence"
        )
    return _build_text_and_image_content(prompt_template, substitutions, image_content)


def _load_images_as_base64(image_content: Any) -> list[str]:
    """
    Load the given S3 images and return them as base64-encoded strings.

    Args:
        image_content: A single S3 URI, a list of S3 URIs, or None

    Returns:
        List of base64 strings, one per image (empty when there are no images)

    Raises:
        ValueError: If an entry is not an S3 URI
    """
    if not image_content:
        return []

    # A bare string would otherwise be iterated character by character
    image_uris = [image_content] if isinstance(image_content, str) else image_content

    encoded_images = []
    for image_uri in image_uris:
        if not isinstance(image_uri, str) or not image_uri.startswith("s3://"):
            raise ValueError(f"Invalid file path {image_uri} - expecting S3 path")
        image_bytes = s3.get_binary_content(image_uri)
        encoded_images.append(base64.b64encode(image_bytes).decode("utf-8"))
    return encoded_images


def _build_text_and_image_content(
    prompt_template: str,
    substitutions: dict[str, Any],
    image_content: Any = None,
) -> list[dict[str, Any]]:
    """
    Build content array with text and optionally images based on DOCUMENT_IMAGE placeholder.

    Args:
        prompt_template: Template that may contain {DOCUMENT_IMAGE}
        substitutions: Dictionary of placeholder values
        image_content: Optional image S3 URIs

    Returns:
        List of content items ({"text": ...} and {"image_base64": ...})

    Raises:
        ValueError: If an image entry is not an S3 URI
    """
    parts = prompt_template.split("{DOCUMENT_IMAGE}")

    if len(parts) == 2:
        before_template, after_template = parts

        # Load images first so an invalid URI fails before any formatting work
        encoded_images = _load_images_as_base64(image_content)

        content: list[dict[str, Any]] = []

        # Add text before image
        before_text = _prepare_prompt_from_template(
            before_template, substitutions, required_placeholders=[]
        )
        if before_text.strip():
            content.append({"text": before_text})

        # Add images
        content.extend({"image_base64": image} for image in encoded_images)

        # Add text after image
        after_text = _prepare_prompt_from_template(
            after_template, substitutions, required_placeholders=[]
        )
        if after_text.strip():
            content.append({"text": after_text})

        return content

    if len(parts) > 2:
        logger.warning("Invalid DOCUMENT_IMAGE placeholder usage")

    # No usable image placeholder, just text
    task_prompt = _prepare_prompt_from_template(
        prompt_template, substitutions, required_placeholders=[]
    )
    return [{"text": task_prompt}]
def _build_few_shot_examples_content(image_content: Any = None) -> list[dict[str, Any]]:
    """
    Build content items for few-shot examples found by S3 Vectors similarity search.

    Args:
        image_content: Optional document image S3 URIs (single URI or list) used
            as the similarity query

    Returns:
        List of content items: for each example a {"text": attributesPrompt}
        item followed by one {"image_uri": ...} item per example image

    Raises:
        ValueError: If an image entry is not an S3 URI
    """
    if not image_content:
        image_uris = []
    elif isinstance(image_content, str):
        # A bare string would otherwise be iterated character by character
        image_uris = [image_content]
    else:
        image_uris = image_content

    image_data = []
    for image_uri in image_uris:
        if not isinstance(image_uri, str) or not image_uri.startswith("s3://"):
            raise ValueError(f"Invalid file path {image_uri} - expecting S3 path")
        image_data.append(s3.get_binary_content(image_uri))

    content: list[dict[str, Any]] = []
    for example in _s3vectors_find_similar_items(image_data):
        content.append({"text": example.get("attributesPrompt")})
        content.extend(
            {"image_uri": example_image_uri}
            for example_image_uri in example.get("imageFiles", [])
        )

    return content


def _prepare_prompt_from_template(prompt_template, substitutions, required_placeholders):
    """
    Prepare prompt from template by replacing placeholders with values.

    Thin wrapper around idp_common.bedrock.format_prompt.

    Args:
        prompt_template: The prompt template with placeholders
        substitutions: Dictionary of placeholder values
        required_placeholders: List of placeholder names that must be present in the template

    Returns:
        String with placeholders replaced by values

    Raises:
        ValueError: If a required placeholder is missing from the template
    """
    return format_prompt(prompt_template, substitutions, required_placeholders)


def _s3vectors_find_similar_items(image_data):
    """
    Find similar examples for the given page images.

    Each page image is queried separately; hits are merged by vector key,
    examples without an attributesPrompt are dropped, and the rest are sorted by
    distance (ascending) and filtered by THRESHOLD.

    Args:
        image_data: List of page images as raw bytes

    Returns:
        List of example dicts (as built by _extract_metadata), most similar first
    """
    # find similar items based on image similarity only
    similar_items = {}
    for page_image in image_data:
        _merge_examples(similar_items, _s3vectors_find_similar_items_from_image(page_image))

    # Keep the vector key next to each example so log messages name the right one
    candidates = []
    for key, example in similar_items.items():
        metadata = example.get("metadata") or {}
        attributes_prompt = metadata.get("attributesPrompt")

        # Only process this example if it has a non-empty attributesPrompt
        if not attributes_prompt or not attributes_prompt.strip():
            logger.info(f"Skipping example with empty attributesPrompt: {key}")
            continue

        candidates.append((key, _extract_metadata(metadata, example.get("distance"))))

    # sort results by distance score (lowest to highest - lower is more similar)
    candidates.sort(key=lambda candidate: candidate[1]["distance"])

    # filter result by distance score
    filtered_result = []
    for key, example in candidates:
        if example["distance"] > THRESHOLD:
            logger.info(
                f"Skipping example with distance {example['distance']} above threshold {THRESHOLD}: {key}"
            )
        else:
            filtered_result.append(example)

    return filtered_result


def _s3vectors_find_similar_items_from_image(page_image):
    """
    Search S3 Vectors for items similar to one page image.

    Args:
        page_image: Page image as raw bytes

    Returns:
        The "vectors" list of the query response (up to TOP_K entries with
        distance and metadata)
    """
    embedding = bedrock_client.generate_embedding(
        image_source=page_image,
        model_id=MODEL_ID,
        dimensions=S3VECTOR_DIMENSIONS,
    )
    response = s3vectors.query_vectors(
        vectorBucketName=S3VECTOR_BUCKET,
        indexName=S3VECTOR_INDEX,
        queryVector={"float32": embedding},
        topK=TOP_K,
        returnDistance=True,
        returnMetadata=True,
    )
    vectors = response["vectors"]
    logger.debug("S3 vectors lookup result: %s", vectors)
    return vectors
+ + Args: + examples: Dict of existing examples + new_examples: List of new examples to be merged + """ + for new_example in new_examples: + key = new_example["key"] + new_distance = new_example.get("distance", 1.0) + + # update example + if examples.get(key): + existing_distance = examples[key].get("distance", 1.0) + examples[key]["distance"] = min(new_distance, existing_distance) + examples[key]["metadata"] = new_example.get("metadata") + # insert example + else: + examples[key] = { + "distance": new_distance, + "metadata": new_example.get("metadata"), + } + + +def _extract_metadata(metadata, distance): + """Create result object from S3 vectors metadata""" + # Result object attributes + attributes = { + "attributesPrompt": metadata.get("attributesPrompt"), + "classPrompt": metadata.get("classPrompt"), + "imageFiles": _get_image_files_from_s3_path(metadata.get("imagePath")), + "distance": distance, + } + + return attributes + + +def _get_image_files_from_s3_path(image_path): + """ + Get list of image files from an S3 path. 
+ + Args: + image_path: Path to image file, directory, or S3 prefix + + Returns: + List of image file paths/URIs sorted by filename + """ + # Handle S3 URIs + if not image_path.startswith("s3://"): + raise ValueError(f"Invalid file path {image_path} - expecting S3 URI") + + # Check if it's a direct file or a prefix + if image_path.endswith( + (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif", ".webp") + ): + # Direct S3 file + return [image_path] + else: + # S3 prefix - list all images + return s3.list_images_from_path(image_path) diff --git a/plugins/dynamic-few-shot-lambda/src/requirements.txt b/plugins/dynamic-few-shot-lambda/src/requirements.txt new file mode 100644 index 00000000..77b716ca --- /dev/null +++ b/plugins/dynamic-few-shot-lambda/src/requirements.txt @@ -0,0 +1 @@ +../../lib/idp_common_pkg[extraction,docs_service] # extraction module and document service with dependencies diff --git a/plugins/dynamic-few-shot-lambda/template.yml b/plugins/dynamic-few-shot-lambda/template.yml new file mode 100644 index 00000000..25184437 --- /dev/null +++ b/plugins/dynamic-few-shot-lambda/template.yml @@ -0,0 +1,387 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 +Description: Deploy demo Lambda function for GenAI IDP dynamic few-shot prompting + +Parameters: + + PermissionsBoundaryArn: + Type: String + Default: "" + Description: >- + (Optional) ARN of an existing IAM Permissions Boundary policy to attach to the Lambda execution role. + Leave blank if no Permissions Boundary is required. + AllowedPattern: "^(|arn:aws[a-z-]*::iam::[0-9]{12}:policy/.+)$" + ConstraintDescription: Must be empty or a valid IAM policy ARN + + VectorBucketName: + Type: String + Default: "" + Description: >- + (Optional) Existing S3 vectors bucket used. 
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31
Description: Deploy demo Lambda function for GenAI IDP dynamic few-shot prompting

Parameters:

  PermissionsBoundaryArn:
    Type: String
    Default: ""
    Description: >-
      (Optional) ARN of an existing IAM Permissions Boundary policy to attach to the Lambda execution role.
      Leave blank if no Permissions Boundary is required.
    AllowedPattern: "^(|arn:aws[a-z-]*::iam::[0-9]{12}:policy/.+)$"
    ConstraintDescription: Must be empty or a valid IAM policy ARN

  VectorBucketName:
    Type: String
    Default: ""
    Description: >-
      (Optional) Existing S3 vectors bucket used. Provide the name of an existing S3 vectors
      bucket here or leave blank to automatically create a new S3 vectors bucket.

  VectorIndexName:
    Type: String
    Default: ""
    Description: >-
      (Optional) Existing S3 vectors index used. Provide the name of an existing S3 vectors
      index here or leave blank to automatically create a new S3 vectors index.

  ModelId:
    Type: String
    Default: "amazon.nova-2-multimodal-embeddings-v1:0"
    Description: Vector embedding model to use to create meaningful vector representations of documents

  VectorDimensions:
    Type: Number
    Default: 3072
    Description: Vector embedding length to use, as defined by the embedding model in use

  TopK:
    Type: Number
    Default: 2
    Description: The number of results to return for each S3 vectors query.

  Threshold:
    Type: Number
    Default: 0.5
    Description: Maximum vector distance for a match; results with a larger distance are filtered out (lower is more similar)

  LambdaFunctionName:
    Type: String
    Default: "IDP-dynamic-few-shot"

  DatasetBucketName:
    Type: String
    Default: ""
    Description: >-
      (Optional) Existing bucket used for dynamic few-shot datasets. Provide the name of
      an existing bucket here or leave blank to automatically create a new bucket.

  # Logging configuration
  LogLevel:
    Type: String
    Default: INFO
    AllowedValues:
      - DEBUG
      - INFO
      - WARN
      - ERROR
    Description: Default logging level

  LogRetentionDays:
    Type: Number
    Default: 30
    Description: Number of days to retain CloudWatch logs
    AllowedValues:
      [
        1,
        3,
        5,
        7,
        14,
        30,
        60,
        90,
        120,
        150,
        180,
        365,
        400,
        545,
        731,
        1827,
        3653,
      ]

  # GenAI IDP parameters
  IDPS3LoggingBucketName:
    Type: String
    Description:
      IDP LoggingBucket Name, to store access logs for the dataset bucket

  IDPS3OutputBucketName:
    Type: String
    Description: >-
      IDP S3OutputBucketName, to read the documents being processed

  IDPCustomerManagedEncryptionKeyArn:
    Type: String
    Description: >-
      IDP CustomerManagedEncryptionKey ARN, to decrypt documents being read from the output bucket

Conditions:
  HasPermissionsBoundary: !Not [!Equals [!Ref PermissionsBoundaryArn, ""]]
  ShouldCreateVectorBucket: !Equals [ !Ref VectorBucketName, "" ]
  ShouldCreateVectorIndex: !Equals [ !Ref VectorIndexName, "" ]
  ShouldCreateDatasetBucket: !Equals [ !Ref DatasetBucketName, "" ]

Resources:

  DynamicFewShotFunction:
    Type: AWS::Serverless::Function
    Metadata:
      cfn_nag:
        rules_to_suppress:
          - id: W89
            reason: "Function does not require VPC access as it only interacts with AWS services via APIs"
          - id: W92
            reason: "Function does not require reserved concurrency as it scales based on demand"
          - id: W58
            reason: "Function does not require DLQ as processing and retries are handled by the IDP framework"
          - id: W11
            reason: "Allow * resource on its permissions policy for CloudWatch metrics"
    # checkov:skip=CKV_AWS_116: "Function does not require DLQ as processing and retries are handled by the IDP framework"
    # checkov:skip=CKV_AWS_117: "Function does not require VPC access as it only interacts with AWS services via APIs"
    # checkov:skip=CKV_AWS_115: "Function does not require reserved concurrency as it scales based on demand"
    # checkov:skip=CKV_AWS_173: "Environment variables do not contain sensitive data - only configuration values like feature flags and non-sensitive settings"
    Properties:
      FunctionName: !Ref LambdaFunctionName
      PermissionsBoundary: !If [HasPermissionsBoundary, !Ref PermissionsBoundaryArn, !Ref AWS::NoValue]
      CodeUri: ./src
      Handler: IDP-dynamic-few-shot.lambda_handler
      Runtime: python3.12
      Architectures:
        - arm64
      Timeout: 300
      MemorySize: 512
      Description: Lambda function for GenAI IDP dynamic few-shot prompting using S3 Vectors
      Environment:
        Variables:
          LOG_LEVEL: !Ref LogLevel
          # Ref on the vector bucket yields ".../bucket/<name>"; the name is the segment after the first "/"
          S3VECTOR_BUCKET: !If
            - ShouldCreateVectorBucket
            # Error: Requested attribute VectorBucketName must be a readonly property in schema for AWS::S3Vectors::VectorBucket
            # - !GetAtt VectorBucket.VectorBucketName
            - !Select [1, !Split ["/", !Ref VectorBucket]]
            - !Ref VectorBucketName
          # Ref on the index yields ".../bucket/<bucket>/index/<name>"; the name is the fourth "/"-separated segment
          S3VECTOR_INDEX: !If
            - ShouldCreateVectorIndex
            # Error: Requested attribute IndexName must be a readonly property in schema for AWS::S3Vectors::Index
            # - !GetAtt DocumentsIndex.IndexName
            - !Select [3, !Split ["/", !Ref DocumentsIndex]]
            - !Ref VectorIndexName
          S3VECTOR_DIMENSIONS: !Ref VectorDimensions
          MODEL_ID: !Ref ModelId
          TOP_K: !Ref TopK
          THRESHOLD: !Ref Threshold
      LoggingConfig:
        LogGroup: !Ref DynamicFewShotLogGroup
      # Least-privilege permissions: basic execution/logging, read access to the
      # dataset and IDP output buckets, CloudWatch metrics, Bedrock model
      # invocation, read/query on the vector index, and KMS decrypt
      Policies:
        - AWSLambdaBasicExecutionRole
        - S3ReadPolicy:
            BucketName: !If
              - ShouldCreateDatasetBucket
              - !Ref DatasetBucket
              - !Ref DatasetBucketName
        - S3ReadPolicy:
            BucketName: !Ref IDPS3OutputBucketName
        - Statement:
            - Effect: Allow
              Action: cloudwatch:PutMetricData
              Resource: "*"
            - Effect: Allow
              Action:
                - bedrock:InvokeModel
                - bedrock:InvokeModelWithResponseStream
              Resource:
                - !Sub "arn:${AWS::Partition}:bedrock:*::foundation-model/*"
                - !Sub "arn:${AWS::Partition}:bedrock:${AWS::Region}:${AWS::AccountId}:inference-profile/*"
            - Effect: Allow
              Action:
                - s3vectors:GetVectors
                - s3vectors:QueryVectors
              Resource:
                - !If
                  - ShouldCreateVectorIndex
                  - !Ref DocumentsIndex
                  - !If
                    - ShouldCreateVectorBucket
                    - !Sub "${VectorBucket}/index/${VectorIndexName}"
                    - !Sub "arn:${AWS::Partition}:s3vectors:${AWS::Region}:${AWS::AccountId}:bucket/${VectorBucketName}/index/${VectorIndexName}"
            - Effect: Allow
              Action:
                - kms:Decrypt
              Resource:
                - !GetAtt CustomerManagedEncryptionKey.Arn
                - !Ref IDPCustomerManagedEncryptionKeyArn

  DynamicFewShotLogGroup:
    Type: AWS::Logs::LogGroup
    Properties:
      LogGroupName: !Sub "/aws/lambda/${LambdaFunctionName}"
      RetentionInDays: !Ref LogRetentionDays
      KmsKeyId: !GetAtt CustomerManagedEncryptionKey.Arn

  VectorBucket:
    Type: AWS::S3Vectors::VectorBucket
    Condition: ShouldCreateVectorBucket
    Properties:
      EncryptionConfiguration:
        SseType: "aws:kms"
        KmsKeyArn: !GetAtt CustomerManagedEncryptionKey.Arn

  DocumentsIndex:
    Type: AWS::S3Vectors::Index
    Condition: ShouldCreateVectorIndex
    Properties:
      DataType: "float32"
      Dimension: !Ref VectorDimensions
      DistanceMetric: "cosine"
      MetadataConfiguration:
        NonFilterableMetadataKeys:
          - "classPrompt"
          - "attributesPrompt"
          - "imagePath"
      # Use the bucket created by this stack, otherwise the existing bucket
      # named by the VectorBucketName parameter
      VectorBucketName: !If
        - ShouldCreateVectorBucket
        - !Select [1, !Split ["/", !Ref VectorBucket]]
        - !Ref VectorBucketName

  DatasetBucket:
    Type: AWS::S3::Bucket
    Condition: ShouldCreateDatasetBucket
    DeletionPolicy: RetainExceptOnCreate
    Properties:
      BucketEncryption:
        ServerSideEncryptionConfiguration:
          - ServerSideEncryptionByDefault:
              SSEAlgorithm: aws:kms
              KMSMasterKeyID: !Ref CustomerManagedEncryptionKey
      PublicAccessBlockConfiguration:
        BlockPublicAcls: true
        BlockPublicPolicy: true
        IgnorePublicAcls: true
        RestrictPublicBuckets: true
      VersioningConfiguration:
        Status: Enabled
      LoggingConfiguration:
        DestinationBucketName: !Ref IDPS3LoggingBucketName
        LogFilePrefix: fewshot-dataset-bucket-logs/

  DatasetBucketPolicy:
    Type: AWS::S3::BucketPolicy
    Condition: ShouldCreateDatasetBucket
    Properties:
      Bucket: !Ref DatasetBucket
      PolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Sid: EnforceSSLOnly
            Effect: Deny
            Principal: "*"
            Action: "s3:*"
            Resource:
              - !Sub "${DatasetBucket.Arn}/*"
              - !Sub "${DatasetBucket.Arn}"
            Condition:
              Bool:
                "aws:SecureTransport": false

  CustomerManagedEncryptionKey:
    Type: AWS::KMS::Key
    Metadata:
      security-matrix:
        rules_to_suppress:
          - id: IAM-005
            reason: "No cross-account access - only same account root and AWS services"
          - id: KMS-007
            reason: "KMS monitoring not required for this IDP solution - comprehensive CloudWatch monitoring already in place"
          - id: KMS-002
            reason: "kms:* permission for account root is standard pattern for administrative access to KMS keys"
    Properties:
      Description: KMS key for encryption of dynamic few-shot resources
      EnableKeyRotation: true
      KeyPolicy:
        Version: "2012-10-17"
        Statement:
          - Sid: Enable IAM User Permissions
            Effect: Allow
            Principal:
              AWS: !Sub "arn:${AWS::Partition}:iam::${AWS::AccountId}:root"
            Action: kms:*
            Resource: "*"
          - Sid: Allow lambda to access the Keys
            Effect: Allow
            Principal:
              AWS: !Sub "arn:${AWS::Partition}:iam::${AWS::AccountId}:root"
            Action:
              - kms:Encrypt
              - kms:Decrypt
              - kms:ReEncrypt*
              - kms:GenerateDataKey*
              - kms:DescribeKey
            Resource: "*"
          - Sid: Allow CloudWatch Logs to use the key
            Effect: Allow
            Principal:
              Service: !Sub "logs.${AWS::URLSuffix}"
            Action:
              - kms:Encrypt
              - kms:Decrypt
              - kms:ReEncrypt*
              - kms:GenerateDataKey*
              - kms:DescribeKey
            Resource: "*"
          - Sid: Allow S3 Vectors indexing service to use the key
            Effect: Allow
            Principal:
              Service: !Sub "indexing.s3vectors.${AWS::URLSuffix}"
            Action:
              - kms:Encrypt
              - kms:Decrypt
              - kms:ReEncrypt*
              - kms:GenerateDataKey*
              - kms:DescribeKey
            Resource: "*"

Outputs:

  DynamicFewShotFunctionName:
    Description: Name of the demo Lambda function
    Value: !Ref DynamicFewShotFunction

  DynamicFewShotFunctionArn:
    Description: ARN of the demo Lambda function (use this in your GenAIIDP configuration)
    Value: !GetAtt DynamicFewShotFunction.Arn

  DynamicFewShotLogGroup:
    Description: CloudWatch Log Group for monitoring demo Lambda execution
    Value: !Ref DynamicFewShotLogGroup

  VectorBucketName:
    Description: S3 Vectors bucket for dynamic few-shot examples
    Value: !If
      - ShouldCreateVectorBucket
      - !Select [1, !Split ["/", !Ref VectorBucket]]
      - !Ref VectorBucketName

  VectorIndexName:
    Description: S3 Vectors index for dynamic few-shot examples
    Value: !If
      - ShouldCreateVectorIndex
      - !Select [3, !Split ["/", !Ref DocumentsIndex]]
      - !Ref VectorIndexName

  DatasetBucket:
    Description: S3 bucket for example data sets
    Value: !If
      - ShouldCreateDatasetBucket
      - !Ref DatasetBucket
      - !Ref DatasetBucketName

  UsageInstructions:
    Description: How to use this Lambda in your IDP configuration
    Value: !Sub |
      Add this ARN to your extraction config:
      extraction:
        custom_prompt_lambda_arn: "${DynamicFewShotFunction.Arn}"

  MonitoringLink:
    Description: Direct link to CloudWatch logs for this function
    Value: !Sub |
      https://console.aws.amazon.com/cloudwatch/home?region=${AWS::Region}#logsV2:log-groups/log-group/$252Faws$252Flambda$252F${LambdaFunctionName}