
Commit 25dc418

ml: extend backend_type to allow selecting specific backends in the future
1 parent 653248a commit 25dc418

File tree

4 files changed: +53 -9 lines changed


include/visp/ml.h

Lines changed: 12 additions & 1 deletion
@@ -29,7 +29,14 @@ enum tensor_data_layout { unknown, whcn, cwhn };
 //
 // Backend device - represents the compute hardware
 
-enum class backend_type { cpu = 1, gpu = 2 };
+enum class backend_type {
+    cpu = 1,
+    gpu = 2,
+    vulkan = gpu | 1 << 8,
+};
+
+constexpr bool operator&(backend_type a, backend_type b);
+VISP_API std::string_view to_string(backend_type);
 
 // True if the backend library is loaded and has at least one supported device.
 VISP_API bool backend_is_available(backend_type);
@@ -283,6 +290,10 @@ VISP_API tensor interpolate(model_ref const&, tensor x, i64x2 target, int32_t mo
 //
 // implementation
 
+constexpr bool operator&(backend_type a, backend_type b) {
+    return (int(a) & int(b)) != 0;
+}
+
 constexpr model_build_flags operator|(model_build_flag lhs, model_build_flag rhs) {
     return model_build_flags(uint32_t(lhs) | uint32_t(rhs));
 }
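
The new enumerator packs the generic device class into the low byte and a backend-specific tag into the high bits, so the added operator& acts as a mask test and vulkan still matches gpu. Below is a minimal standalone sketch of that encoding, not taken from the commit: the enum and operator& mirror the header above, and the static_asserts are added purely for illustration.

enum class backend_type {
    cpu = 1,
    gpu = 2,
    vulkan = gpu | 1 << 8, // low byte = device class, high bits = specific backend
};

constexpr bool operator&(backend_type a, backend_type b) {
    return (int(a) & int(b)) != 0; // true if the underlying bits overlap
}

// vulkan is treated as a kind of gpu, but never as a cpu
static_assert(backend_type::vulkan & backend_type::gpu);
static_assert(!(backend_type::vulkan & backend_type::cpu));

int main() { return 0; }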

src/visp/ml.cpp

Lines changed: 29 additions & 5 deletions
@@ -12,6 +12,15 @@ namespace visp {
 //
 // backend
 
+std::string_view to_string(backend_type type) {
+    switch (type) {
+        case backend_type::cpu: return "cpu";
+        case backend_type::gpu: return "gpu";
+        case backend_type::vulkan: return "vulkan";
+        default: return "unknown";
+    }
+}
+
 bool load_ggml_backends() {
     static const bool loaded = []() {
         if (ggml_backend_reg_count() > 0) {
@@ -37,6 +46,10 @@ bool backend_is_available(backend_type type) {
         case backend_type::gpu:
            return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr ||
                   ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU) != nullptr;
+        case backend_type::vulkan: {
+            ggml_backend_reg_t reg = ggml_backend_reg_by_name("Vulkan");
+            return reg && ggml_backend_reg_dev_count(reg) > 0;
+        }
         default: ASSERT(false, "Invalid backend type");
     }
     return false;
@@ -60,6 +73,7 @@ backend_device backend_init(backend_type type) {
            b.handle.reset(ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr));
            break;
        case backend_type::gpu:
+       case backend_type::vulkan:
            b.handle.reset(ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr));
            if (!b.handle) {
                b.handle.reset(ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU, nullptr));
@@ -82,15 +96,21 @@ backend_type backend_device::type() const {
     switch (ggml_backend_dev_type(dev)) {
         case GGML_BACKEND_DEVICE_TYPE_CPU: return backend_type::cpu;
         case GGML_BACKEND_DEVICE_TYPE_GPU:
-        case GGML_BACKEND_DEVICE_TYPE_IGPU: return backend_type::gpu;
+        case GGML_BACKEND_DEVICE_TYPE_IGPU: {
+            std::string_view dev_name = ggml_backend_dev_name(dev);
+            if (dev_name.find("Vulkan") != std::string_view::npos) {
+                return backend_type::vulkan;
+            }
+            return backend_type::gpu;
+        }
         default: ASSERT(false, "Unsupported backend device type"); return backend_type::cpu;
     }
 }
 
 typedef bool (*ggml_backend_dev_supports_f16_t)(ggml_backend_dev_t);
 
 ggml_type backend_device::preferred_float_type() const {
-    if (type() == backend_type::cpu) {
+    if (type() & backend_type::cpu) {
         return GGML_TYPE_F32; // not all operations support F16
     } else {
         ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(device);
@@ -105,7 +125,7 @@ ggml_type backend_device::preferred_float_type() const {
 }
 
 tensor_data_layout backend_device::preferred_layout() const {
-    if (type() == backend_type::cpu) {
+    if (type() & backend_type::cpu) {
         return tensor_data_layout::cwhn;
     }
     return tensor_data_layout::unknown; // no preference, keep model weight layout
@@ -120,7 +140,10 @@ size_t backend_device::total_memory() const {
 
 size_t backend_device::max_alloc() const {
     const size_t vulkan_max = 4 * 1024 * 1024 * 1024ULL; // TODO: query from backend
-    return type() == backend_type::cpu ? SIZE_MAX : vulkan_max;
+    switch (type()) {
+        case backend_type::vulkan: return vulkan_max;
+        default: return SIZE_MAX;
+    }
 }
 
 void backend_set_n_threads(backend_device& b, int n_threads) {
@@ -154,7 +177,8 @@ model_build_flags backend_default_flags(backend_type type) {
     case backend_type::cpu:
         return conv_2d_direct_cwhn | concat_n | f16_conv_transpose | window_partition |
                flash_attn_flag(false);
-    case backend_type::gpu: return flash_attn_flag(true);
+    case backend_type::gpu:
+    case backend_type::vulkan: return flash_attn_flag(true);
     }
     return {};
 }
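
The equality checks on type() turn into mask tests (type() & backend_type::cpu), so paths keyed on the generic device class keep working if more specific backends are added later. A hedged usage sketch of the extended API follows; it assumes the public header is reachable as visp/ml.h, and the fallback order is illustrative rather than part of the commit.

#include <visp/ml.h>

#include <cstdio>
#include <string_view>

using namespace visp;

int main() {
    // Prefer the specific Vulkan backend, fall back to any GPU, then to CPU.
    backend_type requested = backend_type::vulkan;
    if (!backend_is_available(requested)) {
        requested = backend_is_available(backend_type::gpu) ? backend_type::gpu
                                                            : backend_type::cpu;
    }
    backend_device dev = backend_init(requested);

    std::string_view name = to_string(dev.type());
    std::printf("selected backend: %.*s\n", int(name.size()), name.data());
    return 0;
}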

tests/benchmark.cpp

Lines changed: 5 additions & 3 deletions
@@ -33,7 +33,7 @@ bench_timings run_benchmark(
     int iterations,
     std::vector<input_transfer> const& transfers = {}) {
 
-    if (backend.type() == backend_type::gpu) {
+    if (backend.type() & backend_type::gpu) {
         iterations *= 4;
     }
 
@@ -139,10 +139,12 @@ backend_device initialize_backend(std::string_view backend_type) {
         backend_device cpu = backend_init(backend_type::cpu);
         backend_set_n_threads(cpu, (int)std::thread::hardware_concurrency());
         return cpu;
+    } else if (backend_type == "vulkan") {
+        return backend_init(backend_type::vulkan);
     } else if (backend_type == "gpu") {
         return backend_init(backend_type::gpu);
     } else {
-        throw std::invalid_argument("Invalid backend type. Use 'cpu' or 'gpu'.");
+        throw std::invalid_argument("Invalid backend type. Use 'cpu', 'gpu' or 'vulkan'.");
     }
 }
 
@@ -159,7 +161,7 @@ bench_result benchmark_model(
     bench_result result;
     result.arch = arch;
     result.model = model;
-    result.backend = backend.type() == backend_type::cpu ? "cpu" : "gpu";
+    result.backend = to_string(backend.type());
 
     auto select_model = [&](std::string_view model, std::string_view fallback) {
         if (model.empty()) {
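
With to_string in place, the reported backend name follows whatever type() detects instead of a hard-coded cpu/gpu pair. A small illustrative check, assuming a build where Vulkan is the only GPU backend compiled in (the second assertion would not hold if a different GPU backend gets picked):

#include <visp/ml.h>

#include <cassert>

using namespace visp;

int main() {
    backend_device cpu = backend_init(backend_type::cpu); // CPU is always available
    assert(to_string(cpu.type()) == "cpu");

    if (backend_is_available(backend_type::vulkan)) {
        backend_device vk = backend_init(backend_type::vulkan);
        // type() inspects the ggml device name, so a Vulkan device reports "vulkan"
        assert(to_string(vk.type()) == "vulkan");
    }
    return 0;
}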

tests/test-ml.cpp

Lines changed: 7 additions & 0 deletions
@@ -5,6 +5,13 @@
 
 namespace visp {
 
+VISP_TEST(backend_available) {
+    CHECK(backend_is_available(backend_type::cpu));
+    if (backend_is_available(backend_type::gpu)) {
+        CHECK(backend_is_available(backend_type::vulkan));
+    }
+}
+
 VISP_TEST(model_transfer_type_conversion) {
     model_weights src = model_init(2);
 
