diff --git a/.github/workflows/cd.yaml b/.github/workflows/cd.yaml index c531318..085eb09 100644 --- a/.github/workflows/cd.yaml +++ b/.github/workflows/cd.yaml @@ -1,53 +1,53 @@ name: CD on: - push: - tags: - - '*' + release: + types: [published] jobs: - build: - name: Build distribution + release-build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.10" - - name: Install pypa/build - run: >- - python3 -m - pip install - build - --user - - name: Build a binary wheel and a source tarball - run: python3 -m build - - name: Store the distribution packages - uses: actions/upload-artifact@v3 - with: - name: python-package-distributions - path: dist/ - - publish-to-pypi: - name: >- - Publish Python distribution 📦 to PyPI - if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes - needs: - - build + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Build release distributions + run: | + # NOTE: put your own distribution build steps here. + python -m pip install build + python -m build + + - name: Upload distributions + uses: actions/upload-artifact@v4 + with: + name: release-dists + path: dist/ + + pypi-publish: runs-on: ubuntu-latest + + needs: + - release-build + + permissions: + # IMPORTANT: this permission is mandatory for trusted publishing + id-token: write + + # Dedicated environments with protections for publishing are strongly recommended. 
environment: name: pypi url: https://pypi.org/p/wpodnet-pytorch - permissions: - id-token: write # IMPORTANT: mandatory for trusted publishing steps: - - name: Download all the dists - uses: actions/download-artifact@v3 - with: - name: python-package-distributions - path: dist/ - - name: Publish distribution to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 + - name: Retrieve release distributions + uses: actions/download-artifact@v4 + with: + name: release-dists + path: dist/ + + - name: Publish release distributions to PyPI + uses: pypa/gh-action-pypi-publish@6f7e8d9c0b1a2c3d4e5f6a7b8c9d0e1f2a3b4c5d diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 37925aa..efd065c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,38 +1,28 @@ name: CI -on: [push] +on: [push, pull_request] jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.10" - - name: Install dependencies - run: | - pip install --upgrade pip - pip install -r requirements.txt - - name: Lint with Ruff - run: | - pip install ruff - ruff -- --format=github . 
- continue-on-error: true + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v3 test: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install --upgrade pip - pip install -r requirements.txt + python -m pip install --upgrade pip + pip install -e .[test] - name: Test with pytest run: | - pip install pytest pytest-cov pytest-mock - python -m pytest --junitxml=junit/test-results.xml --cov --cov-report=xml --cov-report=html \ No newline at end of file + pytest --junitxml=junit/test-results.xml --cov --cov-report=xml --cov-report=html diff --git a/predict.py b/predict.py index 6436f1a..c109c5b 100644 --- a/predict.py +++ b/predict.py @@ -1,89 +1,116 @@ import errno from argparse import ArgumentParser, ArgumentTypeError from pathlib import Path +from typing import List, Union import torch +from PIL import Image, UnidentifiedImageError -from wpodnet.backend import Predictor -from wpodnet.model import WPODNet -from wpodnet.stream import ImageStreamer +from wpodnet import Predictor, load_wpodnet_from_checkpoint -if __name__ == '__main__': + +def list_image_paths(p: Union[str, Path]) -> List[Path]: + """ + List all images in a directory. + + Args: + p (Union[str, Path]): The path to the directory containing images. + + Returns: + List[Path]: The paths found under the directory that Pillow identifies as images. 
+ """ + p = Path(p) + if not p.is_dir(): + raise FileNotFoundError(errno.ENOTDIR, "No such directory", str(p)) + + image_paths: List[Path] = [] + for f in filter(Path.is_file, p.glob("**/*")): # skip sub-directories; Image.open() cannot read them + try: + with Image.open(f) as image: + image.verify() + image_paths.append(f) + except UnidentifiedImageError: + pass + return image_paths + + +if __name__ == "__main__": parser = ArgumentParser() + parser.add_argument("source", type=str, help="the path to the image") parser.add_argument( - 'source', - type=str, - help='the path to the image' + "-w", "--weight", type=str, required=True, help="the path to the model weight" ) parser.add_argument( - '-w', '--weight', - type=str, - required=True, - help='the path to the model weight' - ) - parser.add_argument( - '--scale', + "--scale", type=float, default=1.0, - help='adjust the scaling ratio. default to 1.0.' + help="adjust the scaling ratio. default to 1.0.", ) parser.add_argument( - '--save-annotated', + "--save-annotated", type=str, - help='save the annotated image at the given folder' + help="save the annotated image at the given folder", ) parser.add_argument( - '--save-warped', - type=str, - help='save the warped image at the given folder' + "--save-warped", type=str, help="save the warped image at the given folder" ) args = parser.parse_args() if args.scale <= 0.0: - raise ArgumentTypeError(message='scale must be greater than 0.0') + raise ArgumentTypeError("scale must be greater than 0.0") if args.save_annotated is not None: save_annotated = Path(args.save_annotated) if not save_annotated.is_dir(): - raise FileNotFoundError(errno.ENOTDIR, 'No such directory', args.save_annotated) + raise FileNotFoundError( + errno.ENOTDIR, "No such directory", args.save_annotated + ) else: save_annotated = None if args.save_warped is not None: save_warped = Path(args.save_warped) if not save_warped.is_dir(): - raise FileNotFoundError(errno.ENOTDIR, 'No such directory', args.save_warped) + raise FileNotFoundError( + errno.ENOTDIR, "No 
such directory", args.save_warped + ) else: save_warped = None # Prepare for the model - device = 'cuda' if torch.cuda.is_available() else 'cpu' - model = WPODNet() - model.to(device) - - checkpoint = torch.load(args.weight) - model.load_state_dict(checkpoint) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = load_wpodnet_from_checkpoint(args.weight).to(device) predictor = Predictor(model) - streamer = ImageStreamer(args.source) - for i, image in enumerate(streamer): + source = Path(args.source) + if source.is_file(): + image_paths = [source] + elif source.is_dir(): + image_paths = list_image_paths(source) + else: + raise FileNotFoundError(errno.ENOENT, "No such file or directory", args.source) + + for i, image_path in enumerate(image_paths): + image = Image.open(image_path) prediction = predictor.predict(image, scaling_ratio=args.scale) - print(f'Prediction #{i}') - print(' bounds', prediction.bounds.tolist()) - print(' confidence', prediction.confidence) + print(f"Prediction #{i}") + print(" bounds", prediction.bounds) + print(" confidence", prediction.confidence) if save_annotated: annotated_path = save_annotated / Path(image.filename).name - annotated = prediction.annotate() - annotated.save(annotated_path) - print(f'Saved the annotated image at {annotated_path}') + + canvas = image.copy() + prediction.annotate(canvas, outline="red") + canvas.save(annotated_path) + print(f"Saved the annotated image at {annotated_path}") if save_warped: warped_path = save_warped / Path(image.filename).name - warped = prediction.warp() + warped = prediction.warp(image) warped.save(warped_path) - print(f'Saved the warped image at {warped_path}') + print(f"Saved the warped image at {warped_path}") print() diff --git a/pyproject.toml b/pyproject.toml index a9b9755..98d924a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,13 @@ [build-system] -requires = ["setuptools"] +requires = ["setuptools>=64", "setuptools-scm[toml]>=8"] build-backend = 
"setuptools.build_meta" [project] name = "wpodnet-pytorch" -dynamic = ["dependencies", "version"] +dynamic = ["version"] description = "The implementation of ECCV 2018 paper \"License Plate Detection and Recognition in Unconstrained Scenarios\" in PyTorch" readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.8" keywords = [ "python", "ai", @@ -41,18 +41,66 @@ classifiers = [ "Operating System :: MacOS", "Operating System :: Microsoft :: Windows", ] +dependencies = ["numpy", "Pillow", "torch", "torchvision"] [tool.setuptools] -packages = { find = { where = ["."], include = ["wpodnet", "wpodnet.*"] } } +packages.find.include = ["wpodnet"] -[tool.setuptools.dynamic] -dependencies = { file = "requirements.txt" } -version = { attr = "wpodnet.__version__" } +[tool.setuptools_scm] +fallback_version = "0.1.0" [project.optional-dependencies] -dev = [ - "pytest" -] +test = ["pytest", "pytest-cov", "pooch"] [project.urls] -"Source" = "https://github.com/Pandede/WPODNet-Pytorch" +Source = "https://github.com/Pandede/WPODNet-Pytorch" + +[tool.ruff] +target-version = "py38" + +[tool.ruff.lint] +select = [ + "E", # pycodestyle (error) + "F", # pyflakes + "B", # bugbear + "B9", + "C4", # flake8-comprehensions + "SIM", # flake8-simplify + "I", # isort + "UP", # pyupgrade + "PIE", # flake8-pie + "PGH", # pygrep-hooks + "PYI", # flake8-pyi + "RUF", +] +ignore = [ + # only relevant if you run a script with `python -0`, + # which seems unlikely for any of the scripts in this repo + "B011", + # Leave it to the formatter to split long lines and + # the judgement of all of us. 
+ "E501", +] + +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["D"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + + +[tool.pytest.ini_options] +testpaths = ["tests"] + +[tool.coverage.run] +source = ["wpodnet"] + +[tool.coverage.report] +show_missing = true +exclude_lines = [ + "return NotImplemented", + "pragma: no cover", + "pragma: deprecated", + "pragma: develop", + "pass", +] diff --git a/tests/integration/wpodnet/backend/test_predictor.py b/tests/integration/wpodnet/backend/test_predictor.py new file mode 100644 index 0000000..b113c9d --- /dev/null +++ b/tests/integration/wpodnet/backend/test_predictor.py @@ -0,0 +1,38 @@ +from typing import List, Tuple + +import pytest +from PIL import Image + +from wpodnet import Predictor, WPODNet + + +@pytest.fixture +def predictor(wpodnet: WPODNet) -> Predictor: + return Predictor(wpodnet) + + +class TestPredictor: + @pytest.mark.parametrize( + "image_path, bounds", + [ + ( + "docs/sample/original/03009.jpg", + [(1384, 682), (1517, 614), (1502, 685), (1369, 753)], + ), + ( + "docs/sample/original/03016.jpg", + [(567, 375), (643, 368), (643, 393), (567, 400)], + ), + ( + "docs/sample/original/03025.jpg", + [(94, 162), (162, 171), (160, 186), (92, 177)], + ), + ], + ) + def test_predict( + self, predictor: Predictor, image_path: str, bounds: List[Tuple[int, int]] + ): + image = Image.open(image_path) + prediction = predictor.predict(image) + assert prediction.bounds == bounds + assert prediction.confidence >= 0.9 diff --git a/tests/integration/wpodnet/conftest.py b/tests/integration/wpodnet/conftest.py new file mode 100644 index 0000000..9b6a893 --- /dev/null +++ b/tests/integration/wpodnet/conftest.py @@ -0,0 +1,13 @@ +import pooch +import pytest + +from wpodnet import WPODNet, load_wpodnet_from_checkpoint + + +@pytest.fixture(scope="session") +def wpodnet() -> WPODNet: + checkpoint = pooch.retrieve( + "https://github.com/Pandede/WPODNet-Pytorch/releases/download/1.0.0/wpodnet.pth", + 
known_hash="ac9fded54614d01b3082dd4e3917a65d4720b77a3f468fa934dbd85c814d3d77", + ) + return load_wpodnet_from_checkpoint(checkpoint) diff --git a/tests/unit/wpodnet/backend/test_prediction.py b/tests/unit/wpodnet/backend/test_prediction.py new file mode 100644 index 0000000..a149af5 --- /dev/null +++ b/tests/unit/wpodnet/backend/test_prediction.py @@ -0,0 +1,17 @@ +import pytest + +from wpodnet import Prediction + + +class TestPrediction: + def test_validation(self): + # Valid + Prediction(bounds=[(1, 2), (3, 4), (5, 6), (7, 8)], confidence=0.95) + + # Invalid bounds length + with pytest.raises(ValueError): + Prediction(bounds=[(1, 2), (3, 4), (5, 6)], confidence=0.95) + + # Invalid confidence value + with pytest.raises(ValueError): + Prediction(bounds=[(1, 2), (3, 4), (5, 6), (7, 8)], confidence=1.01) diff --git a/tests/wpodnet/stream/test_image_streamer.py b/tests/wpodnet/stream/test_image_streamer.py deleted file mode 100644 index 6582066..0000000 --- a/tests/wpodnet/stream/test_image_streamer.py +++ /dev/null @@ -1,43 +0,0 @@ -from pathlib import Path - -import pytest -from PIL import Image - -from wpodnet.stream import ImageStreamer - - -@pytest.fixture -def image_folder(tmp_path: Path) -> Path: - exts = {'.jpeg', '.png', '.bmp'} - for ext in exts: - image = Image.new('RGB', (10, 10)) - image_path = (tmp_path / 'image').with_suffix(ext) - image.save(image_path) - return tmp_path - - -@pytest.mark.usefixtures('image_folder') -class TestImageStreamer: - def test_load_image_file(self, image_folder: Path): - for image_path in image_folder.glob('**/*'): - streamer = ImageStreamer(image_path) - images = list(streamer) - assert len(images) == 1 - assert f'.{images[0].format.lower()}' == image_path.suffix - - def test_load_image_folder(self, image_folder: Path): - streamer = ImageStreamer(image_folder) - images = list(streamer) - assert len(images) == 3 - - # Add a non-image file - (image_folder / 'text.doc').touch() - streamer = ImageStreamer(image_folder) - images = 
list(streamer) - assert len(images) == 3 - - def test_load_invalid_image(self, tmp_path: Path): - doc_file = tmp_path / 'image.doc' - - with pytest.raises(TypeError, match='Invalid path to images'): - list(ImageStreamer(doc_file)) diff --git a/wpodnet/__init__.py b/wpodnet/__init__.py index d4f346b..62df6e0 100644 --- a/wpodnet/__init__.py +++ b/wpodnet/__init__.py @@ -1,7 +1,29 @@ -__version__ = '1.0.3' +from pathlib import Path +from typing import Union + +import torch from .backend import Prediction, Predictor +from .model import WPODNet + + +def load_wpodnet_from_checkpoint(ckpt_path: Union[str, Path]) -> WPODNet: + """ + Load a pre-trained WPOD-NET model from a checkpoint file. + + Args: + ckpt_path (Union[str, Path]): The path to the checkpoint file. + + Returns: + WPODNet: The WPOD-NET model with pretrained weights loaded from the checkpoint. + """ + model = WPODNet() + + # Load the state dictionary onto the CPU so a checkpoint saved from a GPU also loads on CPU-only hosts + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) + model.load_state_dict(checkpoint) + + return model + -__all__ = [ - 'Prediction', 'Predictor' -] +__all__ = ["Prediction", "Predictor", "WPODNet", "load_wpodnet_from_checkpoint"] diff --git a/wpodnet/backend.py b/wpodnet/backend.py index c251944..6cea2cd 100644 --- a/wpodnet/backend.py +++ b/wpodnet/backend.py @@ -1,74 +1,124 @@ -from typing import List, Tuple +from dataclasses import dataclass +from typing import List, Optional, Tuple import numpy as np import torch from PIL import Image, ImageDraw -from torchvision.transforms.functional import (_get_perspective_coeffs, - to_tensor) +from torchvision.transforms.functional import _get_perspective_coeffs, to_tensor from .model import WPODNet +@dataclass(frozen=True) class Prediction: - def __init__(self, image: Image.Image, bounds: np.ndarray, confidence: float): - self.image = image - self.bounds = bounds - self.confidence = confidence - - def _get_perspective_coeffs(self, width: int, height: int) -> List[float]: - # Get the perspective 
matrix - src_points = self.bounds.tolist() - dst_points = [[0, 0], [width, 0], [width, height], [0, height]] - return _get_perspective_coeffs(src_points, dst_points) - - def annotate(self, outline: str = 'red', width: int = 3) -> Image.Image: - canvas = self.image.copy() + """ + The prediction result from WPODNet. + + Attributes: + bounds (List[Tuple[int, int]]): The bounding coordinates of the detected license plate. Must be a list of 4 points (x, y). + confidence (float): The confidence score of the detection. Must be between 0.0 and 1.0. + """ + + bounds: List[Tuple[int, int]] + confidence: float + + def __post_init__(self): + if len(self.bounds) != 4: + raise ValueError( + f"expected bounds to have 4 points, got {len(self.bounds)} points" + ) + if self.confidence < 0 or self.confidence > 1: + raise ValueError( + f"confidence must be between 0.0 and 1.0, got {self.confidence}" + ) + + def annotate( + self, + canvas: Image.Image, + fill: Optional[str] = None, + outline: Optional[str] = None, + width: int = 1, + ) -> None: # pragma: no cover + """ + Annotates the image with the bounding polygon. + + Args: + canvas (PIL.Image.Image): The image to be annotated. + fill (Optional[str]): The fill color for the polygon. Defaults to None. + outline (Optional[str]): The outline color for the polygon. Defaults to None. + width (int): The width of the outline. Defaults to 1. + + Note: + The arguments `fill`, `outline`, and `width` are passed to the `ImageDraw.Draw.polygon` method. + See https://pillow.readthedocs.io/en/stable/reference/ImageDraw.html#PIL.ImageDraw.ImageDraw.polygon. + """ drawer = ImageDraw.Draw(canvas) - drawer.polygon( - [(x, y) for x, y in self.bounds], - outline=outline, - width=width + drawer.polygon(self.bounds, fill=fill, outline=outline, width=width) + + def warp(self, canvas: Image.Image) -> Image.Image: # pragma: no cover + """ + Warps the image with perspective based on the bounding polygon. 
+ + Args: + canvas (PIL.Image.Image): The image to be warped. + + Returns: + PIL.Image.Image: The warped image. + """ + coeffs = _get_perspective_coeffs( + startpoints=self.bounds, + endpoints=[ + (0, 0), + (canvas.width, 0), + (canvas.width, canvas.height), + (0, canvas.height), + ], ) - return canvas + return canvas.transform( + (canvas.width, canvas.height), Image.Transform.PERSPECTIVE, coeffs + ) + - def warp(self, width: int = 208, height: int = 60) -> Image.Image: - # Get the perspective matrix - coeffs = self._get_perspective_coeffs(width, height) - warped = self.image.transform((width, height), Image.PERSPECTIVE, coeffs) - return warped +Q = np.array( + [ + [-0.5, 0.5, 0.5, -0.5], + [-0.5, -0.5, 0.5, 0.5], + [1.0, 1.0, 1.0, 1.0], + ] +) class Predictor: - _q = np.array([ - [-.5, .5, .5, -.5], - [-.5, -.5, .5, .5], - [1., 1., 1., 1.] - ]) - _scaling_const = 7.75 - _stride = 16 - - def __init__(self, wpodnet: WPODNet): + """A wrapper class for WPODNet to make predictions.""" + + def __init__(self, wpodnet: WPODNet) -> None: + """ + Args: + wpodnet (WPODNet): The WPODNet model to use for prediction. 
+ """ self.wpodnet = wpodnet self.wpodnet.eval() - def _resize_to_fixed_ratio(self, image: Image.Image, dim_min: int, dim_max: int) -> Image.Image: + def _resize_to_fixed_ratio( + self, image: Image.Image, dim_min: int, dim_max: int + ) -> Image.Image: h, w = image.height, image.width wh_ratio = max(h, w) / min(h, w) side = int(wh_ratio * dim_min) - bound_dim = min(side + side % self._stride, dim_max) + bound_dim = min(side + side % self.wpodnet.stride, dim_max) factor = bound_dim / max(h, w) reg_w, reg_h = int(w * factor), int(h * factor) - # Ensure the both width and height are the multiply of `self._stride` - reg_w_mod = reg_w % self._stride + # Ensure the both width and height are the multiply of `self.wpodnet.stride` + reg_w_mod = reg_w % self.wpodnet.stride if reg_w_mod > 0: - reg_w += self._stride - reg_w_mod + reg_w += self.wpodnet.stride - reg_w_mod - reg_h_mod = reg_h % self._stride + reg_h_mod = reg_h % self.wpodnet.stride if reg_h_mod > 0: - reg_h += self._stride - reg_h % self._stride + reg_h += self.wpodnet.stride - reg_h_mod return image.resize((reg_w, reg_h)) @@ -83,7 +133,7 @@ def _inference(self, image: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]: # Convert to squeezed numpy array # grid_w: The number of anchors in row # grid_h: The number of anchors in column - probs = np.squeeze(probs.cpu().numpy())[0] # (grid_h, grid_w) + probs = np.squeeze(probs.cpu().numpy())[0] # (grid_h, grid_w) affines = np.squeeze(affines.cpu().numpy()) # (6, grid_h, grid_w) return probs, affines @@ -91,7 +141,13 @@ def _inference(self, image: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]: def _get_max_anchor(self, probs: np.ndarray) -> Tuple[int, int]: return np.unravel_index(probs.argmax(), probs.shape) - def _get_bounds(self, affines: np.ndarray, anchor_y: int, anchor_x: int, scaling_ratio: float = 1.0) -> np.ndarray: + def _get_bounds( + self, + affines: np.ndarray, + anchor_y: int, + anchor_x: int, + scaling_ratio: float = 1.0, + ) -> np.ndarray: # Compute theta 
theta = affines[:, anchor_y, anchor_x] theta = theta.reshape((2, 3)) @@ -99,16 +155,34 @@ def _get_bounds(self, affines: np.ndarray, anchor_y: int, anchor_x: int, scaling theta[1, 1] = max(theta[1, 1], 0.0) # Convert theta into the bounding polygon - bounds = np.matmul(theta, self._q) * self._scaling_const * scaling_ratio + bounds = np.matmul(theta, Q) * self.wpodnet.scale_factor * scaling_ratio # Normalize the bounds _, grid_h, grid_w = affines.shape - bounds[0] = (bounds[0] + anchor_x + .5) / grid_w - bounds[1] = (bounds[1] + anchor_y + .5) / grid_h + bounds[0] = (bounds[0] + anchor_x + 0.5) / grid_w + bounds[1] = (bounds[1] + anchor_y + 0.5) / grid_h return np.transpose(bounds) - def predict(self, image: Image.Image, scaling_ratio: float = 1.0, dim_min: int = 288, dim_max: int = 608) -> Prediction: + def predict( + self, + image: Image.Image, + scaling_ratio: float = 1.0, + dim_min: int = 512, + dim_max: int = 768, + ) -> Prediction: + """ + Detect license plate in the image. + + Args: + image (Image.Image): The image to be detected. + scaling_ratio (float): The scaling ratio of the resulting bounding polygon. Default to 1.0. + dim_min (int): The minimum dimension of the resized image. Default to 512 + dim_max (int): The maximum dimension of the resized image. Default to 768 + + Returns: + Prediction: The prediction result with highest confidence. 
+ """ orig_h, orig_w = image.height, image.width # Resize the image to fixed ratio @@ -131,7 +205,6 @@ def predict(self, image: Image.Image, scaling_ratio: float = 1.0, dim_min: int = bounds[:, 1] *= orig_h return Prediction( - image=image, - bounds=bounds.astype(np.int32), - confidence=max_prob.item() + bounds=[(x, y) for x, y in np.int32(bounds).tolist()], + confidence=max_prob.item(), ) diff --git a/wpodnet/model.py b/wpodnet/model.py index a7b8af2..05f6e61 100644 --- a/wpodnet/model.py +++ b/wpodnet/model.py @@ -1,13 +1,14 @@ import torch -import torch.nn as nn -class BasicConvBlock(nn.Module): +class BasicConvBlock(torch.nn.Module): def __init__(self, in_channels: int, out_channels: int): - super(BasicConvBlock, self).__init__() - self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) - self.bn_layer = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.001) - self.act_layer = nn.ReLU(inplace=True) + super().__init__() + self.conv_layer = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, padding=1 + ) + self.bn_layer = torch.nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.001) + self.act_layer = torch.nn.ReLU(inplace=True) def forward(self, x): x = self.conv_layer(x) @@ -15,13 +16,13 @@ def forward(self, x): return self.act_layer(x) -class ResBlock(nn.Module): +class ResBlock(torch.nn.Module): def __init__(self, channels: int): - super(ResBlock, self).__init__() + super().__init__() self.conv_block = BasicConvBlock(channels, channels) - self.sec_layer = nn.Conv2d(channels, channels, kernel_size=3, padding=1) - self.bn_layer = nn.BatchNorm2d(channels, momentum=0.99, eps=0.001) - self.act_layer = nn.ReLU(inplace=True) + self.sec_layer = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.bn_layer = torch.nn.BatchNorm2d(channels, momentum=0.99, eps=0.001) + self.act_layer = torch.nn.ReLU(inplace=True) def forward(self, x): h = self.conv_block(x) @@ -30,35 +31,45 @@ def forward(self, x): return self.act_layer(x 
+ h) -class WPODNet(nn.Module): +class WPODNet(torch.nn.Module): + """ + WPODNet in PyTorch. + + The original architecture is built in Keras: https://github.com/sergiomsilva/alpr-unconstrained/blob/master/create-model.py + """ + + # https://github.com/sergiomsilva/alpr-unconstrained/blob/master/src/keras_utils.py#L43-L44 + stride = 16 # net_stride + scale_factor = 7.75 # side + def __init__(self): - super(WPODNet, self).__init__() - self.backbone = nn.Sequential( + super().__init__() + self.backbone = torch.nn.Sequential( BasicConvBlock(3, 16), BasicConvBlock(16, 16), - nn.MaxPool2d(2), + torch.nn.MaxPool2d(2), BasicConvBlock(16, 32), ResBlock(32), - nn.MaxPool2d(2), + torch.nn.MaxPool2d(2), BasicConvBlock(32, 64), ResBlock(64), ResBlock(64), - nn.MaxPool2d(2), + torch.nn.MaxPool2d(2), BasicConvBlock(64, 64), ResBlock(64), ResBlock(64), - nn.MaxPool2d(2), + torch.nn.MaxPool2d(2), BasicConvBlock(64, 128), ResBlock(128), ResBlock(128), ResBlock(128), - ResBlock(128) + ResBlock(128), ) - self.prob_layer = nn.Conv2d(128, 2, kernel_size=3, padding=1) - self.bbox_layer = nn.Conv2d(128, 6, kernel_size=3, padding=1) + self.prob_layer = torch.nn.Conv2d(128, 2, kernel_size=3, padding=1) + self.bbox_layer = torch.nn.Conv2d(128, 6, kernel_size=3, padding=1) # Registry a dummy tensor for retrieve the attached device - self.register_buffer('dummy', torch.Tensor(), persistent=False) + self.register_buffer("dummy", torch.Tensor(), persistent=False) @property def device(self) -> torch.device: diff --git a/wpodnet/stream.py b/wpodnet/stream.py deleted file mode 100644 index dd69bfc..0000000 --- a/wpodnet/stream.py +++ /dev/null @@ -1,36 +0,0 @@ -from pathlib import Path -from typing import Generator, Union - -from PIL import Image - - -class ImageStreamer: - def __init__(self, image_or_folder: Union[str, Path]): - path = Path(image_or_folder) - self.generator = self._get_image_generator(path) - - def _get_image_generator(self, path: Path) -> Generator[Image.Image, None, None]: - if 
path.is_file(): - image_paths = [path] if self._is_image_file(path) else [] - elif path.is_dir(): - image_paths = [ - p - for p in path.rglob('**/*') - if self._is_image_file(p) - ] - else: - raise TypeError(f'Invalid path to images {path}') - - for p in image_paths: - yield Image.open(p) - - def _is_image_file(self, path: Path) -> bool: - try: - image = Image.open(path) - image.verify() - return True - except Exception: - return False - - def __iter__(self): - return self.generator