Skip to content

Commit 83469f9

Browse files
committed
Add Vision Transformer demo for image classification (Fixes #13372)
1 parent 3cea941 commit 83469f9

File tree

3 files changed

+223
-0
lines changed

3 files changed

+223
-0
lines changed

DIRECTORY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
* [Mean Threshold](computer_vision/mean_threshold.py)
140140
* [Mosaic Augmentation](computer_vision/mosaic_augmentation.py)
141141
* [Pooling Functions](computer_vision/pooling_functions.py)
142+
* [Vision Transformer Demo](computer_vision/vision_transformer_demo.py)
142143

143144
## Conversions
144145
* [Astronomical Length Scale Conversion](conversions/astronomical_length_scale_conversion.py)
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
"""
2+
Vision Transformer (ViT) Image Classification Demo
3+
4+
This module demonstrates how to use a pre-trained Vision Transformer (ViT) model
5+
from Hugging Face for image classification tasks.
6+
7+
Vision Transformers apply the transformer architecture (originally designed for NLP)
8+
to computer vision by splitting images into patches and processing them with
9+
self-attention mechanisms.
10+
11+
Requirements:
12+
- torch
13+
- transformers
14+
- Pillow (PIL)
15+
- requests
16+
17+
Resources:
18+
- Paper: https://arxiv.org/abs/2010.11929
19+
- Hugging Face: https://huggingface.co/docs/transformers/model_doc/vit
20+
21+
Example Usage:
22+
from computer_vision.vision_transformer_demo import classify_image
23+
24+
# Classify an image from URL
25+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
26+
result = classify_image(url)
27+
print(f"Predicted: {result['label']} (confidence: {result['score']:.2%})")
28+
29+
# Classify a local image
30+
result = classify_image("path/to/image.jpg", top_k=3)
31+
for pred in result['top_k_predictions']:
32+
print(f"{pred['label']}: {pred['score']:.2%}")
33+
"""
34+
35+
from __future__ import annotations
36+
37+
import sys
38+
from io import BytesIO
39+
from pathlib import Path
40+
from typing import Any
41+
42+
try:
43+
import requests
44+
import torch
45+
from PIL import Image
46+
from transformers import ViTForImageClassification, ViTImageProcessor
47+
except ImportError as e:
48+
print(f"Error: Missing required dependency: {e.name}")
49+
print("Install dependencies: pip install torch transformers pillow requests")
50+
sys.exit(1)
51+
52+
53+
def load_image(image_source: str | Path, timeout: int = 10) -> Image.Image:
    """
    Load an image from a URL or local file path.

    Args:
        image_source: URL string or Path object to the image
        timeout: Network timeout in seconds (default: 10)

    Returns:
        PIL Image object converted to RGB

    Raises:
        TimeoutError: If the download times out
        ConnectionError: If the URL is unreachable or returns an error status
        FileNotFoundError: If the local file doesn't exist
        OSError: If the image data cannot be opened or decoded

    Examples:
        >>> # Test with non-existent file
        >>> try:
        ...     load_image("nonexistent_file.jpg")
        ... except FileNotFoundError:
        ...     print("File not found")
        File not found
    """
    source = str(image_source)
    if source.startswith(("http://", "https://")):
        try:
            response = requests.get(source, timeout=timeout)
            response.raise_for_status()
        except requests.exceptions.Timeout as e:
            msg = (
                f"Request timed out after {timeout} seconds. "
                "Try increasing the timeout parameter."
            )
            # Chain the original exception so tracebacks show the root cause.
            raise TimeoutError(msg) from e
        except requests.exceptions.RequestException as e:
            msg = f"Failed to download image from URL: {e}"
            raise ConnectionError(msg) from e
        # Decode outside the try: keep the network-error handlers targeted.
        return Image.open(BytesIO(response.content)).convert("RGB")

    # Load from local file
    file_path = Path(image_source)
    if not file_path.exists():
        msg = f"Image file not found: {file_path}"
        raise FileNotFoundError(msg)
    return Image.open(file_path).convert("RGB")
101+
102+
103+
# Cache of loaded (processor, model) pairs so repeated classifications with the
# same model name do not re-download / re-load the weights on every call.
_MODEL_CACHE: dict[str, tuple[Any, Any]] = {}


def _get_processor_and_model(model_name: str) -> tuple[Any, Any]:
    """Return a cached (processor, model) pair, loading it on first use."""
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = (
            ViTImageProcessor.from_pretrained(model_name),
            ViTForImageClassification.from_pretrained(model_name),
        )
    return _MODEL_CACHE[model_name]


def classify_image(
    image_source: str | Path,
    model_name: str = "google/vit-base-patch16-224",
    top_k: int = 1,
) -> dict[str, Any]:
    """
    Classify an image using a Vision Transformer model.

    Args:
        image_source: URL or local path to the image
        model_name: Hugging Face model identifier (default: google/vit-base-patch16-224)
        top_k: Number of top predictions to return (default: 1)

    Returns:
        Dictionary containing:
            - label: Predicted class label
            - score: Confidence score (0-1)
            - top_k_predictions: List of top-k predictions (if top_k > 1, else None)

    Raises:
        ValueError: If top_k is less than 1
        FileNotFoundError: If image file doesn't exist
        ConnectionError: If unable to download from URL

    Examples:
        >>> # Test parameter validation
        >>> try:
        ...     classify_image("test.jpg", top_k=0)
        ... except ValueError as e:
        ...     print("Invalid top_k")
        Invalid top_k
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    image = load_image(image_source)

    # Processor/model are cached per model name — loading weights is expensive.
    processor, model = _get_processor_and_model(model_name)

    # Preprocess the image into the model's expected tensor format.
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Convert logits to probabilities and take the k most likely classes.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probabilities, k=top_k, dim=-1)

    predictions = [
        {"label": model.config.id2label[idx.item()], "score": prob.item()}
        for prob, idx in zip(top_k_probs[0], top_k_indices[0])
    ]

    return {
        "label": predictions[0]["label"],
        "score": predictions[0]["score"],
        "top_k_predictions": predictions if top_k > 1 else None,
    }
171+
172+
173+
def main() -> None:
    """
    Run the Vision Transformer classification demo.

    Downloads a sample image from the COCO dataset and prints the
    top-3 predictions, plus a usage hint for local images.
    """
    banner = "=" * 60
    divider = "-" * 60

    print("Vision Transformer (ViT) Image Classification Demo")
    print(banner)

    # Sample image URL (two cats on a couch from COCO dataset)
    sample_url = "http://images.cocodataset.org/val2017/000000039769.jpg"

    print(f"\nLoading image from: {sample_url}")

    try:
        # Request the three most likely classes
        outcome = classify_image(sample_url, top_k=3)

        print(f"\n{'Prediction Results':^60}")
        print(divider)
        print(f"Top Prediction: {outcome['label']}")
        print(f"Confidence: {outcome['score']:.2%}")

        ranked = outcome["top_k_predictions"]
        if ranked:
            print(f"\n{'Top 3 Predictions':^60}")
            print(divider)
            for rank, entry in enumerate(ranked, 1):
                print(f"{rank}. {entry['label']:<40} {entry['score']:>6.2%}")

        # Usage hint for classifying a local file
        print("\n" + banner)
        print("To classify a local image, use:")
        print(' result = classify_image("path/to/your/image.jpg")')
        print(" print(f\"Predicted: {result['label']}\")")

    except TimeoutError as err:
        print(f"\nError: {err}")
        print("Please check your internet connection and try again.")
    except ConnectionError as err:
        print(f"\nError: {err}")
    except Exception as err:
        print(f"\nUnexpected error: {err}")
        raise


if __name__ == "__main__":
    main()

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@ dependencies = [
2020
"opencv-python>=4.10.0.84",
2121
"pandas>=2.2.3",
2222
"pillow>=11.3",
23+
"requests>=2.31.0",
2324
"rich>=13.9.4",
2425
"scikit-learn>=1.5.2",
2526
"scipy>=1.16.2",
2627
"sphinx-pyproject>=0.3",
2728
"statsmodels>=0.14.4",
2829
"sympy>=1.13.3",
30+
"torch>=2.0.0",
31+
"transformers>=4.30.0",
2932
"tweepy>=4.14",
3033
"typing-extensions>=4.12.2",
3134
"xgboost>=2.1.3",

0 commit comments

Comments
 (0)