diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d622fec..0bbc0bf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,6 +13,7 @@ jobs: strategy: matrix: os: [ubuntu-22.04, windows-latest, macos-14] + fail-fast: false runs-on: ${{ matrix.os }} timeout-minutes: 15 @@ -44,6 +45,14 @@ jobs: version: 1.4.309.0 cache: true + - name: Install uv + uses: astral-sh/setup-uv@v7 + with: + version: "0.9.15" + + - name: Set up Python + run: uv sync --dev --python 3.12 + - name: Configure (Linux) if: matrix.os == 'ubuntu-22.04' run: > diff --git a/CMakeLists.txt b/CMakeLists.txt index 66a5860..f579677 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.28) -project(vision.cpp VERSION 0.2.0 LANGUAGES CXX) +project(vision.cpp VERSION 0.3.0 LANGUAGES CXX) option(BUILD_SHARED_LIBS "Build shared libraries instead of static libraries" ON) option(VISP_VULKAN "Enable Vulkan support" OFF) @@ -145,6 +145,8 @@ if(VISP_CI OR VISP_DEV) set_target_properties(vision-cli PROPERTIES INSTALL_RPATH "\$ORIGIN/../${VISP_LIB_INSTALL_DIR}") endif() +install(DIRECTORY bindings/python DESTINATION . PATTERN "__pycache__" EXCLUDE) + include(CMakePackageConfigHelpers) configure_package_config_file( diff --git a/README.md b/README.md index c5f1a5a..6424408 100644 --- a/README.md +++ b/README.md @@ -174,7 +174,7 @@ ctest -C Release Some tests require a Python environment. 
class ImageView(ctypes.Structure):
    """Description of a pixel buffer that can be passed to the native library.

    The field order and types mirror the ``visp_image_view`` struct declared
    in ``src/visp/c-api.cpp``. ``stride`` is the number of bytes per row and
    ``format`` is the integer value of the native image format.
    """

    _fields_ = [
        ("width", c_int32),
        ("height", c_int32),
        ("stride", c_int32),
        ("format", c_int32),
        ("data", c_void_p),
    ]

    @staticmethod
    def from_bytes(width: int, height: int, stride: int, format: int, data: bytes):
        """Create a view over a private copy of ``data``.

        Args:
            width: Image width in pixels.
            height: Image height in pixels.
            stride: Number of bytes per image row.
            format: Integer value of the native image format.
            data: Raw pixel data; it is copied, so the caller's object is
                not referenced afterwards.

        Raises:
            ValueError: If a dimension is negative or ``data`` holds fewer
                than ``height * stride`` bytes.
        """
        if width < 0 or height < 0 or stride < 0:
            raise ValueError(
                f"Invalid image dimensions: width={width}, height={height}, stride={stride}"
            )
        # The view is handed to native code as a bare pointer, so a short
        # buffer must be rejected here rather than read out of bounds there.
        # NOTE(review): assumes the native side reads height * stride bytes,
        # the same size to_pil_image uses - confirm.
        required = height * stride
        if len(data) < required:
            raise ValueError(
                f"Image data too small: got {len(data)} bytes, "
                f"expected at least {required} (height * stride)"
            )
        buffer = (c_byte * len(data)).from_buffer_copy(data)
        view = ImageView(width, height, stride, format, ctypes.cast(buffer, c_void_p))
        # Tie the lifetime of the copied pixels to the view explicitly.
        view._buffer = buffer
        return view

    @staticmethod
    def from_pil_image(image):
        """Create a view holding a copy of the pixels of a PIL image.

        Raises:
            TypeError: If ``image`` is not a ``PIL.Image.Image``.
            ValueError: If the image mode is not supported.
        """
        # An explicit raise (instead of assert) still validates under -O.
        if not isinstance(image, Image.Image):
            raise TypeError(f"Expected a PIL Image, got {type(image).__name__}")
        # Resolve the mode first so unsupported images fail before copying.
        fmt, bytes_per_pixel = _image_mode_from_string(image.mode)
        w, h = image.size
        return ImageView.from_bytes(w, h, w * bytes_per_pixel, fmt, image.tobytes())

    def to_pil_image(self):
        """Convert the viewed pixels into a new PIL image.

        ``Image.frombytes`` copies the pixel data, so the result stays valid
        after the memory behind this view has been released.

        Raises:
            ValueError: If the format is not supported or ``data`` is NULL.
        """
        mode = _image_format_to_string(self.format)
        # A NULL c_void_p field reads back as None.
        if not self.data:
            raise ValueError("Image view has no pixel data (null pointer)")
        size = self.height * self.stride
        data = memoryview((c_byte * size).from_address(self.data))
        return Image.frombytes(mode, (self.width, self.height), data, "raw", mode, self.stride)
def check(return_value: int):
    """Raise ``Error`` if a native call reported failure.

    The C API returns 0 on failure and a non-zero value on success; the
    failure message is fetched with ``visp_get_last_error``.

    Args:
        return_value: Status code returned by a ``visp_*`` function.

    Raises:
        Error: If ``return_value`` is 0. The message comes from the native
            library, or is a fallback when none is available.
    """
    if return_value != 0:
        return
    # An explicit raise (instead of assert) still works under -O.
    if _lib is None:
        raise Error("Library not initialized")
    # restype is c_char_p, so a NULL result arrives as None.
    message = _lib.visp_get_last_error()
    if not message:
        raise Error("Unknown error")
    # Never let a decoding problem hide the actual native error.
    raise Error(message.decode("utf-8", errors="replace"))
class Backend(Enum):
    """Compute backend selector passed to the native library.

    The low byte holds the device class (``cpu`` or ``gpu``); higher bits
    select a specific GPU API, e.g. ``vulkan`` is ``gpu`` with bit 8 set.
    ``auto`` lets the library pick a device.
    """

    auto = 0
    cpu = 1
    gpu = 2

    vulkan = gpu | (1 << 8)

    @property
    def is_cpu(self) -> bool:
        """True if the device class (low byte) is ``cpu``."""
        # The class lives in the low byte; masking with 0xFF00 selected the
        # API bits instead, so this could never be true.
        return (self.value & 0xFF) == Backend.cpu.value

    @property
    def is_gpu(self) -> bool:
        """True if the device class (low byte) is ``gpu``, for any GPU API."""
        return (self.value & 0xFF) == Backend.gpu.value
i in images] + in_views_array = (lib.ImageView * len(in_views))(*in_views) + args_array = (lib.c_int32 * len(args))(*args) + out_view = lib.ImageView() + out_data = lib.ImageData() + check( + self._api.visp_model_compute( + self._handle, + self.arch.value, + in_views_array, + len(in_views_array), + args_array, + len(args_array), + byref(out_view), + byref(out_data), + ) + ) + try: + result = lib.ImageView.to_pil_image(out_view) + finally: + self._api.visp_image_destroy(out_data) + return result + + def __init__(self, api: CDLL, handle: lib.Handle, arch: Arch): + self.arch = arch + self._api = api + self._handle = handle + + def __del__(self): + self._api.visp_model_destroy(self._handle, self.arch.value) + + +def _img_view(i: Image) -> lib.ImageView: + if isinstance(i, PIL.Image.Image): + return lib.ImageView.from_pil_image(i) + elif isinstance(i, ImageRef): + return lib.ImageView.from_bytes(i.width, i.height, i.stride, i.format.value, i.data) + else: + raise TypeError("Expected a PIL Image or ImageRef") diff --git a/include/visp/vision.h b/include/visp/vision.h index 1e22096..7328231 100644 --- a/include/visp/vision.h +++ b/include/visp/vision.h @@ -81,6 +81,21 @@ namespace visp { +// Supported model architectures + +enum class model_family { + sam = 0, + birefnet, + depth_anything, + migan, + esrgan, + + count +}; + +VISP_API model_family model_detect_family(model_file const&); + +// // SWIN v1 - vision transformer for feature extraction constexpr int swin_n_layers = 4; @@ -104,6 +119,7 @@ VISP_API swin_params swin_detect_params(model_file const&); VISP_API swin_buffers swin_precompute(model_ref, i32x2 image_extent, swin_params const&); VISP_API swin_result swin_encode(model_ref, tensor image, swin_params const&); +// // DINO v2 - vision transformer for feature extraction struct dino_params { diff --git a/pyproject.toml b/pyproject.toml index 55f1326..d58814a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,20 +2,39 @@ dynamic = ["version"] name = 
"vision.cpp" requires-python = ">=3.12" -dependencies = [ +dependencies = ["pillow"] + +[dependency-groups] +dev = [ "torch", "torchvision", "timm", "pytest", + "numpy", "opencv-python", "ruff", "einops>=0.8.1", "spandrel>=0.4.1", "gguf>=0.17.1", + "boto3~=1.39.0", ] [tool.uv] package = false +required-environments = [ + "sys_platform == 'win32' and platform_machine == 'AMD64'", + "sys_platform == 'linux' and platform_machine == 'x86_64'", + "sys_platform == 'linux' and platform_machine == 'aarch64'", +] + +[[tool.uv.index]] +name = "pytorch-cpu" +url = "https://download.pytorch.org/whl/cpu" +explicit = true + +[tool.uv.sources] +torch = [{ index = "pytorch-cpu" }] +torchvision = [{ index = "pytorch-cpu" }] [tool.ruff] target-version = "py312" diff --git a/scripts/convert.py b/scripts/convert.py index cc91d63..e2d6ddc 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -75,9 +75,9 @@ def convert_tensor_2d(self, tensor: Tensor): self.conv2d_weights.append(self._index) return tensor - def add_int32(self, name: str, value: int): - print("*", name, "=", value) - super().add_int32(name, value) + def add_int32(self, key: str, val: int): + print("*", key, "=", val) + super().add_int32(key, val) def set_tensor_layout(self, layout: TensorLayout): print("*", f"{self.arch}.tensor_data_layout", "=", layout.value) @@ -201,7 +201,7 @@ def convert_sam(input_filepath: Path, writer: Writer): writer.add_license("apache-2.0") writer.set_tensor_layout_default(TensorLayout.nchw) - model: dict[str, Tensor] = torch.load(input_filepath, map_location="cpu", weights_only=True) + model = load_model(input_filepath) for key, tensor in model.items(): name = key @@ -286,8 +286,7 @@ def convert_birefnet(input_filepath: Path, writer: Writer): writer.add_license("mit") writer.set_tensor_layout_default(TensorLayout.nchw) - weights = safetensors.safe_open(input_filepath, "pt") - model: dict[str, Tensor] = {k: weights.get_tensor(k) for k in weights.keys()} + model = load_model(input_filepath) 
x = model["bb.layers.0.blocks.0.attn.proj.bias"] if x.shape[0] == 96: @@ -360,7 +359,7 @@ def convert_depth_anything(input_filepath: Path, writer: Writer): writer.add_license("cc-by-nc-4.0") writer.set_tensor_layout_default(TensorLayout.nchw) - model: dict[str, Tensor] = load_model(input_filepath) + model = load_model(input_filepath) if "pretrained.cls_token" in model: print("The converter is written for the transformers (.safetensors) version of the model.") @@ -411,7 +410,7 @@ def convert_migan(input_filepath: Path, writer: Writer): writer.add_license("mit") writer.set_tensor_layout_default(TensorLayout.nchw) - model: dict[str, Tensor] = torch.load(input_filepath, weights_only=True) + model = load_model(input_filepath) if "encoder.b512.fromrgb.weight" in model: writer.add_int32("migan.image_size", 512) diff --git a/src/visp/CMakeLists.txt b/src/visp/CMakeLists.txt index dd176df..398fa8f 100644 --- a/src/visp/CMakeLists.txt +++ b/src/visp/CMakeLists.txt @@ -8,6 +8,7 @@ target_sources(visioncpp PRIVATE arch/migan.cpp arch/mobile-sam.cpp arch/swin.cpp + c-api.cpp image.cpp ml.cpp nn.cpp diff --git a/src/visp/c-api.cpp b/src/visp/c-api.cpp new file mode 100644 index 0000000..6846c2d --- /dev/null +++ b/src/visp/c-api.cpp @@ -0,0 +1,253 @@ +#include "util/string.h" +#include "visp/vision.h" + +using namespace visp; + +thread_local fixed_string<512> _error_string{}; + +void set_error(std::exception const& e) { + _error_string = e.what(); +} + +template +int32_t handle_errors(F&& f) { + try { + f(); + } catch (std::exception const& e) { + set_error(e); + return 0; + } + return 1; +} + +void expect_images(span images, size_t count) { + if (images.size() != count) { + throw except("Expected {} input images, but got {}.", count, images.size()); + } +} + +template +struct model_funcs {}; + +template <> +struct model_funcs { + using model_t = sam_model; + + static sam_model load(char const* filepath, backend_device const& dev) { + return sam_load_model(filepath, dev); + } + 
static image_data compute(sam_model& m, span inputs, span prompt) { + expect_images(inputs, 1); + sam_encode(m, inputs[0]); + if (prompt.size() == 2) { + return sam_compute(m, i32x2{prompt[0], prompt[1]}); + } else if (prompt.size() == 4) { + return sam_compute(m, box_2d{i32x2{prompt[0], prompt[1]}, i32x2{prompt[2], prompt[3]}}); + } else { + throw except("sam: bad number of arguments ({}), must be 2 or 4", prompt.size()); + } + } +}; + +template <> +struct model_funcs { + using model_t = birefnet_model; + + static birefnet_model load(char const* filepath, backend_device const& dev) { + return birefnet_load_model(filepath, dev); + } + static image_data compute(birefnet_model& m, span inputs, span) { + expect_images(inputs, 1); + return birefnet_compute(m, inputs[0]); + } +}; + +template <> +struct model_funcs { + using model_t = depthany_model; + + static depthany_model load(char const* filepath, backend_device const& dev) { + return depthany_load_model(filepath, dev); + } + static image_data compute(depthany_model& m, span inputs, span) { + expect_images(inputs, 1); + image_data result_f32 = depthany_compute(m, inputs[0]); + image_data normalized = image_normalize(result_f32); + return image_f32_to_u8(normalized, image_format::alpha_u8); + } +}; + +template <> +struct model_funcs { + using model_t = migan_model; + + static migan_model load(char const* filepath, backend_device const& dev) { + return migan_load_model(filepath, dev); + } + static image_data compute(migan_model& m, span inputs, span) { + expect_images(inputs, 2); + if (inputs[1].format != image_format::alpha_u8) { + throw except("migan: second input image (mask) must be alpha_u8 format"); + } + return migan_compute(m, inputs[0], inputs[1]); + } +}; + +template <> +struct model_funcs { + using model_t = esrgan_model; + + static esrgan_model load(char const* filepath, backend_device const& dev) { + return esrgan_load_model(filepath, dev); + } + static image_data compute(esrgan_model& m, span inputs, 
span) { + expect_images(inputs, 1); + return esrgan_compute(m, inputs[0]); + } +}; + +template +void dispatch_model(model_family family, F&& f) { + switch (family) { + case model_family::sam: f(model_funcs{}); break; + case model_family::birefnet: f(model_funcs{}); break; + case model_family::depth_anything: f(model_funcs{}); break; + case model_family::migan: f(model_funcs{}); break; + case model_family::esrgan: f(model_funcs{}); break; + default: throw visp::exception("Unsupported model family"); + } +} + +struct visp_image_view { + int32_t width; + int32_t height; + int32_t stride; + int32_t format; + void* data; +}; + +void put_image(visp_image_view* out, image_view const& img) { + out->width = img.extent[0]; + out->height = img.extent[1]; + out->stride = img.stride; + out->format = int32_t(img.format); + out->data = (void*)img.data; +} + +void return_image(image_data** out_data, visp_image_view* out_image, image_data&& img) { + *out_data = new image_data(std::move(img)); + put_image(out_image, **out_data); +} + +// +// public C interface + +extern "C" { + +VISP_API char const* visp_get_last_error() { + return _error_string.c_str(); +} + +// image + +VISP_API void visp_image_destroy(image_data* img) { + delete img; +} + +// device + +VISP_API int32_t visp_backend_load_all(char const* dir) { + ggml_backend_load_all_from_path(dir); + return (int32_t)ggml_backend_reg_count(); +} + +VISP_API int32_t visp_device_init(int32_t type, backend_device** out_device) { + return handle_errors([&]() { + if (type == 0) { + *out_device = new backend_device(backend_init()); + } else { + *out_device = new backend_device(backend_init(backend_type(type))); + } + }); +} + +VISP_API void visp_device_destroy(backend_device* d) { + delete d; +} + +VISP_API int32_t visp_device_type(backend_device const* d) { + return int32_t(d->type()); +} + +VISP_API char const* visp_device_name(backend_device const* d) { + ggml_backend_dev_props props; + ggml_backend_dev_get_props(d->device, &props); 
+ return props.name; +} + +VISP_API char const* visp_device_description(backend_device const* d) { + ggml_backend_dev_props props; + ggml_backend_dev_get_props(d->device, &props); + return props.description; +} + +// models + +struct any_model; + +VISP_API int32_t visp_model_detect_family(char const* filepath, int32_t* out_family) { + return handle_errors([&]() { + model_file file = model_load(filepath); + model_family family = model_detect_family(file); + *out_family = int32_t(family); + }); +} + +VISP_API int32_t visp_model_load( + char const* filepath, backend_device const* dev, int32_t arch, any_model** out) { + + return handle_errors([&]() { + model_family family = model_family(arch); + if (family == model_family::count) { + model_file file = model_load(filepath); + family = model_detect_family(file); + } + dispatch_model(family, [&](auto funcs) { + using model_t = typename decltype(funcs)::model_t; + *out = reinterpret_cast(new model_t(funcs.load(filepath, *dev))); + }); + }); +} + +VISP_API void visp_model_destroy(any_model* model, int32_t arch) { + model_family family = model_family(arch); + dispatch_model(family, [&](auto funcs) { + using model_t = typename decltype(funcs)::model_t; + delete reinterpret_cast(model); + }); +} + +VISP_API int32_t visp_model_compute( + any_model* model, + int32_t family, + image_view* inputs, + int32_t n_inputs, + int32_t* args, + int32_t n_args, + visp_image_view* out_image, + image_data** out_data) { + + return handle_errors([&]() { + span input_views(inputs, n_inputs); + span input_args(args, n_args); + + dispatch_model(model_family(family), [&](auto funcs) { + using model_t = typename decltype(funcs)::model_t; + model_t& m = *reinterpret_cast(model); + image_data result = funcs.compute(m, input_views, input_args); + return_image(out_data, out_image, std::move(result)); + }); + }); +} + +} // extern "C" \ No newline at end of file diff --git a/src/visp/ml.cpp b/src/visp/ml.cpp index 107cbc6..0336b48 100644 --- 
def pytest_collection_modifyitems(config, items):
    """Restrict the run to the Python binding tests when ``--ci`` is given.

    Without ``--ci`` the collected items are left untouched. With it, only
    items whose node id contains ``test_python_bindings.py`` are kept; the
    rest are reported to pytest as deselected so they appear in the run
    summary instead of vanishing silently.
    """
    if not config.getoption("--ci"):
        return

    selected = []
    deselected = []
    for item in items:
        if "test_python_bindings.py" in item.nodeid:
            selected.append(item)
        else:
            deselected.append(item)

    if deselected:
        config.hook.pytest_deselected(items=deselected)
    # pytest expects the list to be modified in place.
    items[:] = selected
workbench.invoke_test("sam_window_partition", x, {}, backend=backend) - assert torch.allclose(result, expected) + assert tensors_match(result, expected) @pytest.mark.parametrize("shift", [(0, 2, -1, 0), (0, -2, 0, 3)]) @@ -148,7 +147,7 @@ def test_roll(shift: tuple[int, int, int, int], backend: str): params = dict(s0=shift[3], s1=shift[2], s2=shift[1], s3=shift[0]) result = workbench.invoke_test("roll", x, {}, params, backend) - assert torch.allclose(result, expected) + assert tensors_match(result, expected) @pytest.mark.parametrize("mode", ["bilinear", "bicubic"]) @@ -170,4 +169,4 @@ def test_interpolate(mode: str, align_corners: bool, size: str, scale: float, ba params = dict(mode=mode, h=target[0], w=target[1], align_corners=1 if align_corners else 0) result = workbench.invoke_test("interpolate", x, {}, params, backend) - assert torch.allclose(result, expected) + assert tensors_match(result, expected) diff --git a/tests/test_python_bindings.py b/tests/test_python_bindings.py new file mode 100644 index 0000000..b47bd66 --- /dev/null +++ b/tests/test_python_bindings.py @@ -0,0 +1,84 @@ +import sys +import numpy as np +import pytest +from pathlib import Path +from PIL import Image + +root_dir = Path(__file__).parent.parent +sys.path.insert(0, str(root_dir / "bindings" / "python")) +from visioncpp import Arch, Backend, Device, Model # noqa + +model_dir = root_dir / "models" +image_dir = root_dir / "tests" / "input" +result_dir = root_dir / "tests" / "results" / "python" +result_dir.mkdir(parents=True, exist_ok=True) +ref_dir = root_dir / "tests" / "reference" + +@pytest.fixture +def device(pytestconfig): + if pytestconfig.getoption("ci"): + return Device.init(Backend.cpu) + return Device.init() + + +def compare_images(name: str, result: Image.Image, tolerance: float = 0.015): + name = f"{name}-gpu.png" + result.save(str(result_dir / name)) + result = result.convert("RGB") + result_array = np.array(result).astype(np.float32) / 255.0 + + ref_image = 
Image.open(str(ref_dir / name)).convert("RGB") + ref_array = np.array(ref_image).astype(np.float32) / 255.0 + + if ref_array.shape != result_array.shape: + raise AssertionError( + f"Image shapes do not match: {ref_array.shape} vs {result_array.shape}" + ) + rmse = np.sqrt(np.mean((ref_array - result_array) ** 2)) + if rmse > tolerance: + raise AssertionError(f"Images differ: RMSE={rmse} exceeds tolerance={tolerance}") + + +def test_sam(device: Device): + model = Model.load(model_dir / "MobileSAM-F16.gguf", device) + assert model.arch is Arch.sam + + img = Image.open(str(image_dir / "cat-and-hat.jpg")) + result_box = model.compute(img, args=[180, 110, 505, 330]) + result_point = model.compute(img, args=[200, 300]) + compare_images("mobile_sam-box", result_box) + compare_images("mobile_sam-point", result_point) + +def test_birefnet(device: Device): + model = Model.load(model_dir / "BiRefNet-lite-F16.gguf", device) + assert model.arch is Arch.birefnet + + img = Image.open(str(image_dir / "wardrobe.jpg")) + result = model.compute(img) + compare_images("birefnet", result) + +def test_depth_anything(device: Device): + model = Model.load(model_dir / "Depth-Anything-V2-Small-F16.gguf", device) + assert model.arch is Arch.depth_anything + + img = Image.open(str(image_dir / "wardrobe.jpg")) + result = model.compute(img) + compare_images("depth-anything", result) + +def test_migan(device: Device): + model = Model.load(model_dir / "MIGAN-512-places2-F16.gguf", device) + assert model.arch is Arch.migan + + img = Image.open(str(image_dir / "bench-image.jpg")).convert("RGBA") + mask = Image.open(str(image_dir / "bench-mask.png")) + result = model.compute(img, mask) + result = Image.alpha_composite(img, result) + compare_images("migan", result) + +def test_esrgan(device: Device): + model = Model.load(str(model_dir / "RealESRGAN-x4plus_anime-6B-F16.gguf"), device) + assert model.arch is Arch.esrgan + + img = Image.open(str(image_dir / "vase-and-bowl.jpg")) + result = 
model.compute(img) + compare_images("esrgan", result) diff --git a/tests/workbench.py b/tests/workbench.py index 1c3950a..d5ab748 100644 --- a/tests/workbench.py +++ b/tests/workbench.py @@ -1,10 +1,11 @@ import ctypes -from functools import reduce -from typing import Mapping import torch import os - +import platform +from functools import reduce +from typing import Mapping from pathlib import Path +from torch import Tensor float_ptr = ctypes.POINTER(ctypes.c_float) @@ -90,20 +91,40 @@ def encode_params(params: Mapping[str, str | int | float]): os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" root_dir = Path(__file__).parent.parent -bin_dir = root_dir / "build" / "bin" - -lib = ctypes.CDLL(str(bin_dir / "vision-workbench.dll")) -lib.visp_workbench.argtypes = [ - ctypes.c_char_p, - ctypes.POINTER(RawTensor), - ctypes.c_int32, - ctypes.POINTER(RawParam), - ctypes.c_int32, - ctypes.POINTER(ctypes.POINTER(RawTensor)), - ctypes.POINTER(ctypes.c_int32), - ctypes.c_int32, -] -lib.visp_workbench.restype = ctypes.c_int32 + +def _load_library(): + system = platform.system().lower() + if system == "windows": + prefix = "" + suffix = ".dll" + libdir = "bin" + elif system == "darwin": + prefix = "lib" + suffix = ".dylib" + libdir = "lib" + else: # assume Linux / Unix + prefix = "lib" + suffix = ".so" + libdir = "lib" + lib_path = root_dir / "build" / libdir / f"{prefix}vision-workbench{suffix}" + return ctypes.CDLL(str(lib_path)) + +try: + lib = _load_library() + + lib.visp_workbench.argtypes = [ + ctypes.c_char_p, + ctypes.POINTER(RawTensor), + ctypes.c_int32, + ctypes.POINTER(RawParam), + ctypes.c_int32, + ctypes.POINTER(ctypes.POINTER(RawTensor)), + ctypes.POINTER(ctypes.c_int32), + ctypes.c_int32, + ] + lib.visp_workbench.restype = ctypes.c_int32 +except OSError as e: + print(f"Error loading vision-workbench library: {e}") def invoke_test( @@ -168,11 +189,14 @@ def to_nhwc(tensor: torch.Tensor): return tensor.permute(0, 2, 3, 1).contiguous() -def to_nchw(tensor: torch.Tensor): 
+def to_nchw(tensor: Tensor|list[Tensor]|None): + assert tensor is not None + if isinstance(tensor, list): + return [t.permute(0, 3, 1, 2).contiguous() for t in tensor] return tensor.permute(0, 3, 1, 2).contiguous() -def convert_to_nhwc(state: dict[str, torch.Tensor], key=""): +def convert_to_nhwc(state: dict[str, Tensor], key=""): for k, v in state.items(): is_conv = ( v.ndim == 4 @@ -250,7 +274,25 @@ def fuse_conv_2d_batch_norm( return False # no match -def print_results(result: torch.Tensor, expected: torch.Tensor): +def print_results(result: Tensor, expected: Tensor): print("\ntorch seed:", torch.initial_seed()) print("\nresult -----", result, sep="\n") print("\nexpected ---", expected, sep="\n") + + +def tensors_match( + result: Tensor | list[Tensor] | None, + expected: Tensor | list[Tensor], + rtol=1e-3, + atol=1e-5, + show=False, +): + assert result is not None, "No result returned" + if isinstance(expected, list): + assert isinstance(result, list), "Result is not a list" + assert len(result) == len(expected), f"Expected {len(expected)} tensors, got {len(result)}" + return all(tensors_match(r, e, rtol, atol, show) for r, e in zip(result, expected)) + assert isinstance(result, Tensor), "Result is not a tensor" + if show: + print_results(result, expected) + return torch.allclose(result, expected, rtol=rtol, atol=atol)