|
| 1 | +""" |
| 2 | +Vision Transformer (ViT) Image Classification Demo |
| 3 | +
|
| 4 | +This module demonstrates how to use a pre-trained Vision Transformer (ViT) model |
| 5 | +from Hugging Face for image classification tasks. |
| 6 | +
|
| 7 | +Vision Transformers apply the transformer architecture (originally designed for NLP) |
| 8 | +to computer vision by splitting images into patches and processing them with |
| 9 | +self-attention mechanisms. |
| 10 | +
|
| 11 | +Requirements: |
| 12 | + - torch |
| 13 | + - transformers |
| 14 | + - Pillow (PIL) |
| 15 | + - requests |
| 16 | +
|
| 17 | +Resources: |
| 18 | + - Paper: https://arxiv.org/abs/2010.11929 |
| 19 | + - Hugging Face: https://huggingface.co/docs/transformers/model_doc/vit |
| 20 | +
|
| 21 | +Example Usage: |
| 22 | + from computer_vision.vision_transformer_demo import classify_image |
| 23 | +
|
| 24 | + # Classify an image from URL |
| 25 | + url = "http://images.cocodataset.org/val2017/000000039769.jpg" |
| 26 | + result = classify_image(url) |
| 27 | + print(f"Predicted: {result['label']} (confidence: {result['score']:.2%})") |
| 28 | +
|
| 29 | + # Classify a local image |
| 30 | + result = classify_image("path/to/image.jpg", top_k=3) |
| 31 | + for pred in result['top_k_predictions']: |
| 32 | + print(f"{pred['label']}: {pred['score']:.2%}") |
| 33 | +""" |
| 34 | + |
| 35 | +from __future__ import annotations |
| 36 | + |
| 37 | +import sys |
| 38 | +from io import BytesIO |
| 39 | +from pathlib import Path |
| 40 | +from typing import Any |
| 41 | + |
| 42 | +try: |
| 43 | + import requests |
| 44 | + import torch |
| 45 | + from PIL import Image |
| 46 | + from transformers import ViTForImageClassification, ViTImageProcessor |
| 47 | +except ImportError as e: |
| 48 | + print(f"Error: Missing required dependency: {e.name}") |
| 49 | + print("Install dependencies: pip install torch transformers pillow requests") |
| 50 | + sys.exit(1) |
| 51 | + |
| 52 | + |
def load_image(image_source: str | Path, timeout: int = 10) -> Image.Image:
    """
    Load an image from a URL or local file path.

    Args:
        image_source: URL string or Path object to the image
        timeout: Network timeout in seconds (default: 10)

    Returns:
        PIL Image object converted to RGB

    Raises:
        TimeoutError: If the download times out
        ConnectionError: If the URL is unreachable or returns an error status
        FileNotFoundError: If the local file doesn't exist
        OSError: If the image cannot be opened or decoded

    Examples:
        >>> # Test with non-existent file
        >>> try:
        ...     load_image("nonexistent_file.jpg")
        ... except FileNotFoundError:
        ...     print("File not found")
        File not found
    """
    # Both str and Path stringify cleanly, so a single startswith check
    # is enough to distinguish remote URLs from local paths.
    if str(image_source).startswith(("http://", "https://")):
        try:
            response = requests.get(str(image_source), timeout=timeout)
            response.raise_for_status()
            # Convert to RGB so downstream processors always see 3 channels
            return Image.open(BytesIO(response.content)).convert("RGB")
        except requests.exceptions.Timeout as e:
            msg = (
                f"Request timed out after {timeout} seconds. "
                "Try increasing the timeout parameter."
            )
            # Chain the cause (PEP 3134) so the original traceback survives,
            # matching the ConnectionError branch below.
            raise TimeoutError(msg) from e
        except requests.exceptions.RequestException as e:
            msg = f"Failed to download image from URL: {e}"
            raise ConnectionError(msg) from e

    # Load from local file
    file_path = Path(image_source)
    if not file_path.exists():
        msg = f"Image file not found: {file_path}"
        raise FileNotFoundError(msg)
    return Image.open(file_path).convert("RGB")
| 101 | + |
| 102 | + |
def classify_image(
    image_source: str | Path,
    model_name: str = "google/vit-base-patch16-224",
    top_k: int = 1,
) -> dict[str, Any]:
    """
    Classify an image using a Vision Transformer model.

    Args:
        image_source: URL or local path to the image
        model_name: Hugging Face model identifier (default: google/vit-base-patch16-224)
        top_k: Number of top predictions to return (default: 1). Values larger
            than the model's number of classes are clamped to that number.

    Returns:
        Dictionary containing:
            - label: Predicted class label
            - score: Confidence score (0-1)
            - top_k_predictions: List of top-k predictions (None if top_k == 1)

    Raises:
        ValueError: If top_k is less than 1
        FileNotFoundError: If image file doesn't exist
        ConnectionError: If unable to download from URL

    Examples:
        >>> # Test parameter validation
        >>> try:
        ...     classify_image("test.jpg", top_k=0)
        ... except ValueError as e:
        ...     print("Invalid top_k")
        Invalid top_k
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    # Load image (URL or local path)
    image = load_image(image_source)

    # Load pre-trained model and processor. The weights are downloaded on
    # first use and reused from the Hugging Face on-disk cache afterwards.
    processor = ViTImageProcessor.from_pretrained(model_name)
    model = ViTForImageClassification.from_pretrained(model_name)

    # Preprocess: resize/normalize the image into a batched tensor
    inputs = processor(images=image, return_tensors="pt")

    # Inference only -- disable gradient tracking
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Convert logits to probabilities. Clamp k to the number of classes so
    # torch.topk cannot fail when top_k exceeds the label set.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    k = min(top_k, probabilities.shape[-1])
    top_k_probs, top_k_indices = torch.topk(probabilities, k=k, dim=-1)

    # Map class indices back to human-readable labels
    predictions = [
        {"label": model.config.id2label[idx.item()], "score": prob.item()}
        for prob, idx in zip(top_k_probs[0], top_k_indices[0])
    ]

    return {
        "label": predictions[0]["label"],
        "score": predictions[0]["score"],
        "top_k_predictions": predictions if top_k > 1 else None,
    }
| 171 | + |
| 172 | + |
def main() -> None:
    """
    Run the Vision Transformer classification demo end to end.

    Downloads a sample image from the COCO dataset, classifies it, and
    prints the top predictions along with usage hints for local files.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    print("Vision Transformer (ViT) Image Classification Demo")
    print(heavy_rule)

    # Sample image URL (two cats on a couch from COCO dataset)
    image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    print(f"\nLoading image from: {image_url}")

    try:
        # Request the three most likely classes
        classification = classify_image(image_url, top_k=3)

        print(f"\n{'Prediction Results':^60}")
        print(light_rule)
        print(f"Top Prediction: {classification['label']}")
        print(f"Confidence: {classification['score']:.2%}")

        if classification["top_k_predictions"]:
            print(f"\n{'Top 3 Predictions':^60}")
            print(light_rule)
            for rank, prediction in enumerate(classification["top_k_predictions"], 1):
                print(f"{rank}. {prediction['label']:<40} {prediction['score']:>6.2%}")

        # Show how to run the same classification against a local file
        print("\n" + heavy_rule)
        print("To classify a local image, use:")
        print(' result = classify_image("path/to/your/image.jpg")')
        print(" print(f\"Predicted: {result['label']}\")")

    except TimeoutError as err:
        print(f"\nError: {err}")
        print("Please check your internet connection and try again.")
    except ConnectionError as err:
        print(f"\nError: {err}")
    except Exception as err:
        # Unexpected failure: report it, then re-raise for a full traceback
        print(f"\nUnexpected error: {err}")
        raise
| 216 | + |
| 217 | + |
| 218 | +if __name__ == "__main__": |
| 219 | + main() |
0 commit comments