11#include " util/string.h"
22#include " visp/vision.h"
33
4-
54using namespace visp ;
65
76thread_local fixed_string<512 > _error_string{};
@@ -14,20 +13,110 @@ template <typename F>
1413int32_t handle_errors (F&& f) {
1514 try {
1615 f ();
17- } catch (std::exception const & e) {
16+ } catch (std::exception const & e) {
1817 set_error (e);
1918 return 0 ;
2019 }
2120 return 1 ;
2221}
2322
24- extern " C " {
25-
26- VISP_API char const * visp_get_last_error () {
27- return _error_string. c_str ();
23+ void expect_images (span<image_view> images, size_t count) {
24+ if (images. size () != count) {
25+ throw except ( " Expected {} input images, but got {}. " , count, images. size ());
26+ }
2827}
2928
30- // image
29+ template <model_family f>
30+ struct model_funcs {};
31+
32+ template <>
33+ struct model_funcs <model_family::sam> {
34+ using model_t = sam_model;
35+
36+ static sam_model load (char const * filepath, backend_device const & dev) {
37+ return sam_load_model (filepath, dev);
38+ }
39+ static image_data compute (sam_model& m, span<image_view> inputs, span<int > prompt) {
40+ expect_images (inputs, 1 );
41+ sam_encode (m, inputs[0 ]);
42+ if (prompt.size () == 2 ) {
43+ return sam_compute (m, i32x2{prompt[0 ], prompt[1 ]});
44+ } else if (prompt.size () == 4 ) {
45+ return sam_compute (m, box_2d{i32x2{prompt[0 ], prompt[1 ]}, i32x2{prompt[2 ], prompt[3 ]}});
46+ } else {
47+ throw except (" sam: bad number of arguments ({}), must be 2 or 4" , prompt.size ());
48+ }
49+ }
50+ };
51+
52+ template <>
53+ struct model_funcs <model_family::birefnet> {
54+ using model_t = birefnet_model;
55+
56+ static birefnet_model load (char const * filepath, backend_device const & dev) {
57+ return birefnet_load_model (filepath, dev);
58+ }
59+ static image_data compute (birefnet_model& m, span<image_view> inputs, span<int >) {
60+ expect_images (inputs, 1 );
61+ return birefnet_compute (m, inputs[0 ]);
62+ }
63+ };
64+
65+ template <>
66+ struct model_funcs <model_family::depth_anything> {
67+ using model_t = depthany_model;
68+
69+ static depthany_model load (char const * filepath, backend_device const & dev) {
70+ return depthany_load_model (filepath, dev);
71+ }
72+ static image_data compute (depthany_model& m, span<image_view> inputs, span<int >) {
73+ expect_images (inputs, 1 );
74+ image_data result_f32 = depthany_compute (m, inputs[0 ]);
75+ image_data normalized = image_normalize (result_f32);
76+ return image_f32_to_u8 (normalized, image_format::alpha_u8);
77+ }
78+ };
79+
80+ template <>
81+ struct model_funcs <model_family::migan> {
82+ using model_t = migan_model;
83+
84+ static migan_model load (char const * filepath, backend_device const & dev) {
85+ return migan_load_model (filepath, dev);
86+ }
87+ static image_data compute (migan_model& m, span<image_view> inputs, span<int >) {
88+ expect_images (inputs, 2 );
89+ if (inputs[1 ].format != image_format::alpha_u8) {
90+ throw except (" migan: second input image (mask) must be alpha_u8 format" );
91+ }
92+ return migan_compute (m, inputs[0 ], inputs[1 ]);
93+ }
94+ };
95+
96+ template <>
97+ struct model_funcs <model_family::esrgan> {
98+ using model_t = esrgan_model;
99+
100+ static esrgan_model load (char const * filepath, backend_device const & dev) {
101+ return esrgan_load_model (filepath, dev);
102+ }
103+ static image_data compute (esrgan_model& m, span<image_view> inputs, span<int >) {
104+ expect_images (inputs, 1 );
105+ return esrgan_compute (m, inputs[0 ]);
106+ }
107+ };
108+
109+ template <typename F>
110+ void dispatch_model (model_family family, F&& f) {
111+ switch (family) {
112+ case model_family::sam: f (model_funcs<model_family::sam>{}); break ;
113+ case model_family::birefnet: f (model_funcs<model_family::birefnet>{}); break ;
114+ case model_family::depth_anything: f (model_funcs<model_family::depth_anything>{}); break ;
115+ case model_family::migan: f (model_funcs<model_family::migan>{}); break ;
116+ case model_family::esrgan: f (model_funcs<model_family::esrgan>{}); break ;
117+ default : throw visp::exception (" Unsupported model family" );
118+ }
119+ }
31120
32121struct visp_image_view {
33122 int32_t width;
@@ -50,6 +139,17 @@ void return_image(image_data** out_data, visp_image_view* out_image, image_data&
50139 put_image (out_image, **out_data);
51140}
52141
142+ //
143+ // public C interface
144+
145+ extern " C" {
146+
147+ VISP_API char const * visp_get_last_error () {
148+ return _error_string.c_str ();
149+ }
150+
151+ // image
152+
53153VISP_API void visp_image_destroy (image_data* img) {
54154 delete img;
55155}
@@ -107,55 +207,41 @@ VISP_API int32_t visp_model_load(
107207 model_file file = model_load (filepath);
108208 family = model_detect_family (file);
109209 }
110- switch (family) {
111- case model_family::sam: {
112- sam_model model = sam_load_model (filepath, *dev);
113- *out = reinterpret_cast <any_model*>(new sam_model (std::move (model)));
114- break ;
115- }
116- case model_family::birefnet: {
117- birefnet_model model = birefnet_load_model (filepath, *dev);
118- *out = reinterpret_cast <any_model*>(new birefnet_model (std::move (model)));
119- break ;
120- }
121- case model_family::depth_anything: {
122- depthany_model model = depthany_load_model (filepath, *dev);
123- *out = reinterpret_cast <any_model*>(new depthany_model (std::move (model)));
124- break ;
125- }
126- case model_family::migan: {
127- migan_model model = migan_load_model (filepath, *dev);
128- *out = reinterpret_cast <any_model*>(new migan_model (std::move (model)));
129- break ;
130- }
131- case model_family::esrgan: {
132- esrgan_model model = esrgan_load_model (filepath, *dev);
133- *out = reinterpret_cast <any_model*>(new esrgan_model (std::move (model)));
134- break ;
135- }
136- default : throw visp::exception (" Invalid model family" );
137- }
210+ dispatch_model (family, [&](auto funcs) {
211+ using model_t = typename decltype (funcs)::model_t ;
212+ *out = reinterpret_cast <any_model*>(new model_t (funcs.load (filepath, *dev)));
213+ });
138214 });
139215}
140216
141217VISP_API void visp_model_destroy (any_model* model, int32_t arch) {
142218 model_family family = model_family (arch);
143- switch (family) {
144- case model_family::sam: delete reinterpret_cast <sam_model*>(model); break ;
145- case model_family::birefnet: delete reinterpret_cast <birefnet_model*>(model); break ;
146- case model_family::depth_anything: delete reinterpret_cast <depthany_model*>(model); break ;
147- case model_family::migan: delete reinterpret_cast <migan_model*>(model); break ;
148- case model_family::esrgan: delete reinterpret_cast <esrgan_model*>(model); break ;
149- default : fprintf (stderr, " Invalid model family: %d\n " , int (family)); break ;
150- }
219+ dispatch_model (family, [&](auto funcs) {
220+ using model_t = typename decltype (funcs)::model_t ;
221+ delete reinterpret_cast <model_t *>(model);
222+ });
151223}
152224
153- VISP_API int32_t visp_esrgan_compute (
154- esrgan_model* model, image_view in_image, visp_image_view* out_image, image_data** out_data) {
225+ VISP_API int32_t visp_model_compute (
226+ any_model* model,
227+ int32_t family,
228+ image_view* inputs,
229+ int32_t n_inputs,
230+ int32_t * args,
231+ int32_t n_args,
232+ visp_image_view* out_image,
233+ image_data** out_data) {
155234
156235 return handle_errors ([&]() {
157- image_data result = esrgan_compute (*model, in_image);
158- return_image (out_data, out_image, std::move (result));
236+ span<image_view> input_views (inputs, n_inputs);
237+ span<int32_t > input_args (args, n_args);
238+
239+ dispatch_model (model_family (family), [&](auto funcs) {
240+ using model_t = typename decltype (funcs)::model_t ;
241+ model_t & m = *reinterpret_cast <model_t *>(model);
242+ image_data result = funcs.compute (m, input_views, input_args);
243+ return_image (out_data, out_image, std::move (result));
244+ });
159245 });
160246}
161247
0 commit comments