Skip to content

Commit 8d283c6

Browse files
committed
python: fill out the rest of the bindings and support all models
1 parent c587178 commit 8d283c6

File tree

5 files changed

+251
-79
lines changed

5 files changed

+251
-79
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .vision import * # noqa
1+
from ._lib import Error # noqa
2+
from .vision import * # noqa

bindings/python/visioncpp/_lib.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import platform
33
from pathlib import Path
44
from ctypes import c_byte, c_char_p, c_void_p, c_int32, POINTER
5+
from PIL import Image
56

67

78
class Error(Exception):
@@ -23,7 +24,7 @@ def _image_format_to_string(format: int):
2324
def _image_mode_from_string(mode: str):
2425
match mode:
2526
case "RGBA":
26-
return 0, 4 # visp::image_format, bytes per pixel
27+
return 0, 4 # visp::image_format, bytes per pixel
2728
case "RGB":
2829
return 3, 3
2930
case "L":
@@ -48,17 +49,13 @@ def from_bytes(width: int, height: int, stride: int, format: int, data: bytes):
4849

4950
@staticmethod
def from_pil_image(image):
    """Build an ImageView from a PIL image.

    The PIL mode string is translated by `_image_mode_from_string` into the
    native format id and the number of bytes per pixel. The pixel buffer
    comes from `image.tobytes()`, and the stride passed to `from_bytes` is
    `width * bytes_per_pixel`.
    """
    assert isinstance(image, Image.Image), "Expected a PIL Image"
    pixels = image.tobytes()
    width, height = image.size
    pixel_format, bytes_per_pixel = _image_mode_from_string(image.mode)
    row_stride = width * bytes_per_pixel
    return ImageView.from_bytes(width, height, row_stride, pixel_format, pixels)
5857

5958
def to_pil_image(self):
60-
from PIL import Image
61-
6259
mode = _image_format_to_string(self.format)
6360
size = self.height * self.stride
6461
data = memoryview((c_byte * size).from_address(self.data))
@@ -145,8 +142,17 @@ def init():
145142
lib.visp_model_destroy.argtypes = [Model, c_int32]
146143
lib.visp_model_destroy.restype = None
147144

148-
lib.visp_esrgan_compute.argtypes = [Model, ImageView, POINTER(ImageView), POINTER(ImageData)]
149-
lib.visp_esrgan_compute.restype = c_int32
145+
lib.visp_model_compute.argtypes = [
146+
Model,
147+
c_int32,
148+
POINTER(ImageView),
149+
c_int32,
150+
POINTER(c_int32),
151+
c_int32,
152+
POINTER(ImageView),
153+
POINTER(ImageData),
154+
]
155+
lib.visp_model_compute.restype = c_int32
150156

151157
return lib
152158

bindings/python/visioncpp/vision.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from ctypes import CDLL, byref
1+
from ctypes import CDLL, byref, c_int32
22
from enum import Enum
33
from pathlib import Path
4-
from typing import NamedTuple
4+
from typing import NamedTuple, Sequence
55
import PIL.Image
66

77
from . import _lib as lib
@@ -84,41 +84,58 @@ class Arch(Enum):
8484
unknown = 5
8585

8686

87-
class Model:
    """A loaded vision model of any supported architecture.

    Create instances with `Model.load`. The object keeps the native model
    handle together with the architecture it was loaded as, because the
    native `visp_model_compute` and `visp_model_destroy` calls both take the
    architecture id alongside the handle.
    """

    @classmethod
    def load(cls, path: str | Path, device: Device, arch: Arch = Arch.unknown):
        """Load the model file at `path` for use on `device`.

        If `arch` is `Arch.unknown`, the architecture is first detected from
        the file with `visp_model_detect_family`; otherwise the given
        architecture is used as-is. Native failures are reported through
        `check`.
        """
        api = get_lib()
        handle = lib.Model()
        path_str = lib.path_to_char_p(path)
        if arch is Arch.unknown:
            detected = c_int32()
            check(api.visp_model_detect_family(path_str, byref(detected)))
            arch = Arch(detected.value)

        check(api.visp_model_load(path_str, device._handle, arch.value, byref(handle)))
        return cls(api, handle, arch)

    def compute(self, *images: Image, args: Sequence[int] | None = None):
        """Run the model on `images` and return the result as a PIL image.

        `images` are converted with `_img_view` and passed as one array;
        `args` is an optional sequence of integers forwarded unchanged to the
        native call (its meaning depends on the model architecture).
        """
        if args is None:
            args = []

        # Keep `in_views` alive for the duration of the native call: the
        # ctypes array below only holds copies of the view structs.
        in_views = [_img_view(i) for i in images]
        in_views_array = (lib.ImageView * len(in_views))(*in_views)
        args_array = (c_int32 * len(args))(*args)
        out_view = lib.ImageView()
        out_data = lib.ImageData()
        check(
            self._api.visp_model_compute(
                self._handle,
                self.arch.value,
                in_views_array,
                len(in_views_array),
                args_array,
                len(args_array),
                byref(out_view),
                byref(out_data),
            )
        )
        try:
            result = lib.ImageView.to_pil_image(out_view)
        finally:
            # The native result buffer is released whether or not the
            # conversion to a PIL image succeeded.
            self._api.visp_image_destroy(out_data)
        return result

    def __init__(self, api: CDLL, handle: lib.Handle, arch: Arch):
        self.arch = arch
        self._api = api
        self._handle = handle

    def __del__(self):
        # __del__ also runs for objects whose __init__ never executed (for
        # example when the constructor was called with wrong arguments), so
        # the attributes may be missing. Nothing to release in that case.
        api = getattr(self, "_api", None)
        handle = getattr(self, "_handle", None)
        if api is None or handle is None:
            return
        # Drop the handle first so the native model is destroyed only once.
        self._handle = None
        api.visp_model_destroy(handle, self.arch.value)
119137

120138

121-
122139
def _img_view(i: Image) -> lib.ImageView:
123140
if isinstance(i, PIL.Image.Image):
124141
return lib.ImageView.from_pil_image(i)

src/visp/c-api.cpp

Lines changed: 133 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include "util/string.h"
22
#include "visp/vision.h"
33

4-
54
using namespace visp;
65

76
thread_local fixed_string<512> _error_string{};
@@ -14,20 +13,110 @@ template <typename F>
1413
int32_t handle_errors(F&& f) {
1514
try {
1615
f();
17-
} catch (std::exception const& e) {
16+
} catch (std::exception const& e) {
1817
set_error(e);
1918
return 0;
2019
}
2120
return 1;
2221
}
2322

24-
extern "C" {
25-
26-
VISP_API char const* visp_get_last_error() {
27-
return _error_string.c_str();
23+
// Checks that exactly `count` input images were passed.
// Throws an `except(...)` error naming the expected and actual counts otherwise.
void expect_images(span<image_view> images, size_t count) {
    auto const given = images.size();
    if (given == count) {
        return;
    }
    throw except("Expected {} input images, but got {}.", count, given);
}
2928

30-
// image
29+
// Per-family function table. The primary template is empty; every supported
// model family provides a specialization with a `model_t` alias plus static
// `load` and `compute` functions.
template <model_family>
struct model_funcs {};
31+
32+
template <>
33+
struct model_funcs<model_family::sam> {
34+
using model_t = sam_model;
35+
36+
static sam_model load(char const* filepath, backend_device const& dev) {
37+
return sam_load_model(filepath, dev);
38+
}
39+
static image_data compute(sam_model& m, span<image_view> inputs, span<int> prompt) {
40+
expect_images(inputs, 1);
41+
sam_encode(m, inputs[0]);
42+
if (prompt.size() == 2) {
43+
return sam_compute(m, i32x2{prompt[0], prompt[1]});
44+
} else if (prompt.size() == 4) {
45+
return sam_compute(m, box_2d{i32x2{prompt[0], prompt[1]}, i32x2{prompt[2], prompt[3]}});
46+
} else {
47+
throw except("sam: bad number of arguments ({}), must be 2 or 4", prompt.size());
48+
}
49+
}
50+
};
51+
52+
template <>
53+
struct model_funcs<model_family::birefnet> {
54+
using model_t = birefnet_model;
55+
56+
static birefnet_model load(char const* filepath, backend_device const& dev) {
57+
return birefnet_load_model(filepath, dev);
58+
}
59+
static image_data compute(birefnet_model& m, span<image_view> inputs, span<int>) {
60+
expect_images(inputs, 1);
61+
return birefnet_compute(m, inputs[0]);
62+
}
63+
};
64+
65+
template <>
66+
struct model_funcs<model_family::depth_anything> {
67+
using model_t = depthany_model;
68+
69+
static depthany_model load(char const* filepath, backend_device const& dev) {
70+
return depthany_load_model(filepath, dev);
71+
}
72+
static image_data compute(depthany_model& m, span<image_view> inputs, span<int>) {
73+
expect_images(inputs, 1);
74+
image_data result_f32 = depthany_compute(m, inputs[0]);
75+
image_data normalized = image_normalize(result_f32);
76+
return image_f32_to_u8(normalized, image_format::alpha_u8);
77+
}
78+
};
79+
80+
template <>
81+
struct model_funcs<model_family::migan> {
82+
using model_t = migan_model;
83+
84+
static migan_model load(char const* filepath, backend_device const& dev) {
85+
return migan_load_model(filepath, dev);
86+
}
87+
static image_data compute(migan_model& m, span<image_view> inputs, span<int>) {
88+
expect_images(inputs, 2);
89+
if (inputs[1].format != image_format::alpha_u8) {
90+
throw except("migan: second input image (mask) must be alpha_u8 format");
91+
}
92+
return migan_compute(m, inputs[0], inputs[1]);
93+
}
94+
};
95+
96+
template <>
97+
struct model_funcs<model_family::esrgan> {
98+
using model_t = esrgan_model;
99+
100+
static esrgan_model load(char const* filepath, backend_device const& dev) {
101+
return esrgan_load_model(filepath, dev);
102+
}
103+
static image_data compute(esrgan_model& m, span<image_view> inputs, span<int>) {
104+
expect_images(inputs, 1);
105+
return esrgan_compute(m, inputs[0]);
106+
}
107+
};
108+
109+
template <typename F>
110+
void dispatch_model(model_family family, F&& f) {
111+
switch (family) {
112+
case model_family::sam: f(model_funcs<model_family::sam>{}); break;
113+
case model_family::birefnet: f(model_funcs<model_family::birefnet>{}); break;
114+
case model_family::depth_anything: f(model_funcs<model_family::depth_anything>{}); break;
115+
case model_family::migan: f(model_funcs<model_family::migan>{}); break;
116+
case model_family::esrgan: f(model_funcs<model_family::esrgan>{}); break;
117+
default: throw visp::exception("Unsupported model family");
118+
}
119+
}
31120

32121
struct visp_image_view {
33122
int32_t width;
@@ -50,6 +139,17 @@ void return_image(image_data** out_data, visp_image_view* out_image, image_data&
50139
put_image(out_image, **out_data);
51140
}
52141

142+
//
143+
// public C interface
144+
145+
extern "C" {
146+
147+
// Returns the contents of the thread-local `_error_string` buffer as a
// C string. The pointer refers to that buffer, not to a copy.
VISP_API char const* visp_get_last_error() {
    char const* const message = _error_string.c_str();
    return message;
}
150+
151+
// image
152+
53153
// Deletes the given image_data object. Passing a null pointer is a no-op.
VISP_API void visp_image_destroy(image_data* img) {
    if (img == nullptr) {
        return;
    }
    delete img;
}
@@ -107,55 +207,41 @@ VISP_API int32_t visp_model_load(
107207
model_file file = model_load(filepath);
108208
family = model_detect_family(file);
109209
}
110-
switch (family) {
111-
case model_family::sam: {
112-
sam_model model = sam_load_model(filepath, *dev);
113-
*out = reinterpret_cast<any_model*>(new sam_model(std::move(model)));
114-
break;
115-
}
116-
case model_family::birefnet: {
117-
birefnet_model model = birefnet_load_model(filepath, *dev);
118-
*out = reinterpret_cast<any_model*>(new birefnet_model(std::move(model)));
119-
break;
120-
}
121-
case model_family::depth_anything: {
122-
depthany_model model = depthany_load_model(filepath, *dev);
123-
*out = reinterpret_cast<any_model*>(new depthany_model(std::move(model)));
124-
break;
125-
}
126-
case model_family::migan: {
127-
migan_model model = migan_load_model(filepath, *dev);
128-
*out = reinterpret_cast<any_model*>(new migan_model(std::move(model)));
129-
break;
130-
}
131-
case model_family::esrgan: {
132-
esrgan_model model = esrgan_load_model(filepath, *dev);
133-
*out = reinterpret_cast<any_model*>(new esrgan_model(std::move(model)));
134-
break;
135-
}
136-
default: throw visp::exception("Invalid model family");
137-
}
210+
dispatch_model(family, [&](auto funcs) {
211+
using model_t = typename decltype(funcs)::model_t;
212+
*out = reinterpret_cast<any_model*>(new model_t(funcs.load(filepath, *dev)));
213+
});
138214
});
139215
}
140216

141217
// Destroys a model object.
//
// `arch` selects the concrete model type the opaque `any_model*` is cast back
// to before deletion, so it must be the family the model was created as.
//
// This function has no error return value, and `dispatch_model` throws for an
// unsupported family. Exceptions must not propagate out of an `extern "C"`
// function into a C (or ctypes) caller, so they are caught here and reported
// on stderr instead.
VISP_API void visp_model_destroy(any_model* model, int32_t arch) {
    model_family const family = static_cast<model_family>(arch);
    try {
        dispatch_model(family, [&](auto funcs) {
            using model_t = typename decltype(funcs)::model_t;
            delete reinterpret_cast<model_t*>(model);
        });
    } catch (std::exception const& e) {
        fprintf(
            stderr, "visp_model_destroy: %s (model family: %d)\n", e.what(),
            static_cast<int>(family));
    } catch (...) {
        fprintf(
            stderr, "visp_model_destroy: unknown error (model family: %d)\n",
            static_cast<int>(family));
    }
}
152224

153-
// Runs a model of the given family on an array of input images.
//
//   model      opaque model pointer; cast to the concrete type selected by `family`
//   family     model family id, must match the type `model` was created as
//   inputs     array of `n_inputs` image views (how many are required depends on the family)
//   args       array of `n_args` integer arguments (interpretation depends on the family)
//   out_image  receives a view of the result image
//   out_data   receives the object owning the result image
//
// Returns the result of `handle_errors`: 1 on success, 0 if an exception was
// caught (the error is recorded through `set_error`).
VISP_API int32_t visp_model_compute(
    any_model* model,
    int32_t family,
    image_view* inputs,
    int32_t n_inputs,
    int32_t* args,
    int32_t n_args,
    visp_image_view* out_image,
    image_data** out_data) {

    return handle_errors([&]() {
        // Arguments come from a C caller: reject values that would otherwise be
        // dereferenced or turned into a huge span length further down.
        if (model == nullptr || out_image == nullptr || out_data == nullptr) {
            throw except("visp_model_compute: model and output pointers must not be null");
        }
        if (n_inputs < 0 || (n_inputs > 0 && inputs == nullptr)) {
            throw except("visp_model_compute: invalid input image array (count: {})", n_inputs);
        }
        if (n_args < 0 || (n_args > 0 && args == nullptr)) {
            throw except("visp_model_compute: invalid argument array (count: {})", n_args);
        }

        span<image_view> input_views(inputs, n_inputs);
        span<int32_t> input_args(args, n_args);

        dispatch_model(static_cast<model_family>(family), [&](auto funcs) {
            using model_t = typename decltype(funcs)::model_t;
            model_t& m = *reinterpret_cast<model_t*>(model);
            image_data result = funcs.compute(m, input_views, input_args);
            return_image(out_data, out_image, std::move(result));
        });
    });
}
161247

0 commit comments

Comments
 (0)