
Commit daad20e

tests: make benchmarks run for a fixed amount of time instead of a fixed number of iterations
1 parent 25dc418 commit daad20e
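
The gist of the change: run_benchmark previously took a per-model iteration count (with a 4x multiplier on GPU backends); it now takes a bench_args with a time budget plus minimum and maximum iteration bounds, and loops until the budget is spent. Below is a minimal, self-contained sketch of that loop shape, simplified from the diff that follows; run_once stands in for the input transfer and graph compute and is not part of the patch.

#include <chrono>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <thread>
#include <vector>

using clock_type = std::chrono::high_resolution_clock;
using duration_ms = std::chrono::duration<double, std::milli>;

// Stand-in for "transfer inputs to the backend + compute the graph".
static void run_once() { std::this_thread::sleep_for(std::chrono::milliseconds(5)); }

int main() {
    duration_ms timeout(200);  // --timeout (shortened here so the early exit triggers)
    int min_iterations = 4;    // --min-iterations
    int max_iterations = 100;  // --max-iterations

    std::vector<double> timings;
    timings.reserve(max_iterations);

    auto start = clock_type::now();
    int runs = 0;
    for (int i = 0; i < max_iterations; ++i) {
        auto iteration_start = clock_type::now();
        run_once();
        auto end = clock_type::now();
        timings.push_back(duration_ms(end - iteration_start).count());
        runs = i + 1;
        // Stop once the overall budget is spent, but never before the minimum run count.
        if (runs >= min_iterations && (end - start) >= timeout) {
            break;
        }
    }

    double mean = std::accumulate(timings.begin(), timings.end(), 0.0) / timings.size();
    double sq_sum = std::inner_product(timings.begin(), timings.end(), timings.begin(), 0.0);
    double stdev = std::sqrt(sq_sum / timings.size() - mean * mean);
    printf("%d runs, mean %.1f ms, stdev %.1f ms\n", runs, mean, stdev);
    return 0;
}

Because fast backends simply complete more runs within the same budget, the old iterations *= 4 special case for GPUs goes away, and the reported mean and deviation are derived from however many runs fit in the window.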


tests/benchmark.cpp

Lines changed: 95 additions & 40 deletions
@@ -12,11 +12,36 @@
 #include <thread>
 #include <vector>
 
-using namespace visp;
+namespace visp {
+using clock = std::chrono::high_resolution_clock;
+using duration_ms = std::chrono::duration<double, std::milli>;
+
+char const* usage = R"(
+Usage: vision-bench [-m <model1>] [-m <model2> ...] [options]
+
+Run benchmarks on one or more vision models and print results as table.
+If no model is specified, benchmarks all supported models.
+
+Options:
+  -m, --model <arch>       Model architecture (sam, birefnet, depthany, ...)
+  -m, --model <arch:file>  Specific model file, eg. "birefnet:BiRefNet-F16.gguf"
+  -b, --backend <cpu|gpu>  Backend type (default: all backends)
+  --timeout <seconds>      Benchmark timeout in seconds (default: 10)
+  --min-iterations <n>     Minimum benchmark iterations (default: 4)
+  --max-iterations <n>     Maximum benchmark iterations (default: 100)
+)";
+
+struct bench_args {
+    duration_ms timeout = duration_ms(10000);
+    int min_iterations = 4;
+    int max_iterations = 100;
+};
 
 struct bench_timings {
-    double mean = 0.0;
-    double stdev = 0.0;
+    duration_ms total;
+    duration_ms mean;
+    duration_ms stdev;
+    int iterations = 0;
 };
 
 struct input_transfer {
@@ -30,38 +55,42 @@ struct input_transfer {
 bench_timings run_benchmark(
     compute_graph& graph,
     backend_device& backend,
-    int iterations,
+    bench_args const& args,
     std::vector<input_transfer> const& transfers = {}) {
 
-    if (backend.type() & backend_type::gpu) {
-        iterations *= 4;
-    }
-
     std::vector<double> timings;
-    timings.reserve(iterations);
+    timings.reserve(args.max_iterations);
 
     compute(graph, backend); // Warm-up
 
-    for (int i = 0; i < iterations; ++i) {
-        auto start = std::chrono::high_resolution_clock::now();
+    auto start = clock::now();
+    int i = 0;
+    for (i = 0; i < args.max_iterations; ++i) {
+        auto start_iteration = clock::now();
 
         for (const auto& transfer : transfers) {
            transfer_to_backend(transfer.x, transfer.data);
         }
         compute(graph, backend);
 
-        auto end = std::chrono::high_resolution_clock::now();
-        std::chrono::duration<double, std::milli> elapsed = end - start;
+        auto end = clock::now();
+        duration_ms elapsed = end - start_iteration;
         timings.push_back(elapsed.count());
+
+        if (i >= args.min_iterations && (end - start) >= args.timeout) {
+            i++; // loop counter -> total runs
+            break;
+        }
     }
 
+    duration_ms total = clock::now() - start;
     double mean = std::accumulate(timings.begin(), timings.end(), 0.0) / timings.size();
     double sq_sum = std::inner_product(timings.begin(), timings.end(), timings.begin(), 0.0);
     double stdev = std::sqrt(sq_sum / timings.size() - mean * mean);
-    return {mean, stdev};
+    return {total, duration_ms(mean), duration_ms(stdev), i};
 }
 
-bench_timings benchmark_sam(path model_path, backend_device& backend) {
+bench_timings benchmark_sam(path model_path, backend_device& backend, bench_args const& args) {
     path input_path = test_dir().input / "cat-and-hat.jpg";
 
     sam_model model = sam_load_model(model_path.string().c_str(), backend);
@@ -70,41 +99,44 @@ bench_timings benchmark_sam(path model_path, backend_device& backend) {
 
     sam_encode(model, image_view(input));
     bench_timings encoder_timings = run_benchmark(
-        model.encoder, backend, 16, {{model.input_image, input_data}});
+        model.encoder, backend, args, {{model.input_image, input_data}});
 
     sam_compute(model, i32x2{200, 300});
-    bench_timings decoder_timings = run_benchmark(model.decoder, backend, 50);
+    bench_timings decoder_timings = run_benchmark(model.decoder, backend, args);
 
     return {
-        encoder_timings.mean + decoder_timings.mean,
-        std::sqrt(
-            encoder_timings.stdev * encoder_timings.stdev +
-            decoder_timings.stdev * decoder_timings.stdev)};
+        encoder_timings.total + decoder_timings.total, encoder_timings.mean + decoder_timings.mean,
+        duration_ms(
+            std::sqrt(
+                encoder_timings.stdev.count() * encoder_timings.stdev.count() +
+                decoder_timings.stdev.count() * decoder_timings.stdev.count())),
+        encoder_timings.iterations};
 }
 
-bench_timings benchmark_birefnet(path model_path, backend_device& backend) {
+bench_timings benchmark_birefnet(path model_path, backend_device& backend, bench_args const& args) {
     path input_path = test_dir().input / "wardrobe.jpg";
 
     birefnet_model model = birefnet_load_model(model_path.string().c_str(), backend);
     image_data input = image_load(input_path.string().c_str());
     image_data input_data = birefnet_process_input(input, model.params);
 
     birefnet_compute(model, input);
-    return run_benchmark(model.graph, backend, 8, {{model.input, input_data}});
+    return run_benchmark(model.graph, backend, args, {{model.input, input_data}});
 }
 
-bench_timings benchmark_depth_anything(path model_path, backend_device& backend) {
+bench_timings benchmark_depth_anything(
+    path model_path, backend_device& backend, bench_args const& args) {
     path input_path = test_dir().input / "wardrobe.jpg";
 
     depthany_model model = depthany_load_model(model_path.string().c_str(), backend);
     image_data input = image_load(input_path.string().c_str());
     depthany_compute(model, input);
 
     image_data input_data = depthany_process_input(input, model.params);
-    return run_benchmark(model.graph, backend, 12, {{model.input, input_data}});
+    return run_benchmark(model.graph, backend, args, {{model.input, input_data}});
 }
 
-bench_timings benchmark_migan(path model_path, backend_device& backend) {
+bench_timings benchmark_migan(path model_path, backend_device& backend, bench_args const& args) {
     path image_path = test_dir().input / "bench-image.jpg";
     path mask_path = test_dir().input / "bench-mask.png";
 
@@ -114,10 +146,10 @@ bench_timings benchmark_migan(path model_path, backend_device& backend) {
     image_data input_data = migan_process_input(image, mask, model.params);
 
     migan_compute(model, image, mask);
-    return run_benchmark(model.graph, backend, 32, {{model.input, input_data}});
+    return run_benchmark(model.graph, backend, args, {{model.input, input_data}});
 }
 
-bench_timings benchmark_esrgan(path model_path, backend_device& backend) {
+bench_timings benchmark_esrgan(path model_path, backend_device& backend, bench_args const& args) {
     path input_path = test_dir().input / "vase-and-bowl.jpg";
 
     esrgan_model model = esrgan_load_model(model_path.string().c_str(), backend);
@@ -131,7 +163,7 @@ bench_timings benchmark_esrgan(path model_path, backend_device& backend) {
     model.output = esrgan_generate(m, model.input, model.params);
 
     compute_graph_allocate(graph, backend);
-    return run_benchmark(graph, backend, 8, {{model.input, input_data}});
+    return run_benchmark(graph, backend, args, {{model.input, input_data}});
 }
 
 backend_device initialize_backend(std::string_view backend_type) {
@@ -156,7 +188,10 @@ struct bench_result {
 };
 
 bench_result benchmark_model(
-    std::string_view arch, std::string_view model, backend_device& backend) {
+    std::string_view arch,
+    std::string_view model,
+    backend_device& backend,
+    bench_args const& args) {
 
     bench_result result;
     result.arch = arch;
@@ -179,23 +214,23 @@ bench_result benchmark_model(
 
     if (arch == "sam") {
         path model_path = select_model(model, "MobileSAM-F16.gguf");
-        result.time = benchmark_sam(model_path, backend);
+        result.time = benchmark_sam(model_path, backend, args);
 
     } else if (arch == "birefnet") {
         path model_path = select_model(model, "BiRefNet-lite-F16.gguf");
-        result.time = benchmark_birefnet(model_path, backend);
+        result.time = benchmark_birefnet(model_path, backend, args);
 
     } else if (arch == "depthany") {
         path model_path = select_model(model, "Depth-Anything-V2-Small-F16.gguf");
-        result.time = benchmark_depth_anything(model_path, backend);
+        result.time = benchmark_depth_anything(model_path, backend, args);
 
     } else if (arch == "migan") {
         path model_path = select_model(model, "MIGAN-512-places2-F16.gguf");
-        result.time = benchmark_migan(model_path, backend);
+        result.time = benchmark_migan(model_path, backend, args);
 
     } else if (arch == "esrgan") {
         path model_path = select_model(model, "RealESRGAN-x4plus_anime-6B-F16.gguf");
-        result.time = benchmark_esrgan(model_path, backend);
+        result.time = benchmark_esrgan(model_path, backend, args);
 
     } else {
         fprintf(stderr, "Unknown model architecture: %s\n", arch.data());
@@ -215,15 +250,22 @@ void print(fixed_string<128> const& str) {
     printf("%s", str.c_str());
 }
 
+} // namespace visp
+
 int main(int argc, char** argv) {
+    using namespace visp;
     std::vector<std::pair<std::string_view, std::string_view>> models;
     std::vector<std::string_view> backends;
+    bench_args args;
 
     try {
 
         for (int i = 1; i < argc; ++i) {
             std::string_view arg(argv[i]);
-            if (arg == "-m" || arg == "--model") {
+            if (arg == "-h" || arg == "--help") {
+                printf("%s", usage);
+                return 0;
+            } else if (arg == "-m" || arg == "--model") {
                 std::string_view text = next_arg(argc, argv, i);
                 auto p = text.find(':');
                 if (p == std::string_view::npos) {
@@ -235,6 +277,12 @@ int main(int argc, char** argv) {
                 }
             } else if (arg == "-b" || arg == "--backend") {
                 backends.push_back(next_arg(argc, argv, i));
+            } else if (arg == "--timeout") {
+                args.timeout = duration_ms(std::stod(next_arg(argc, argv, i)) * 1000);
+            } else if (arg == "--min-iterations") {
+                args.min_iterations = std::stoi(next_arg(argc, argv, i));
+            } else if (arg == "--max-iterations") {
+                args.max_iterations = std::stoi(next_arg(argc, argv, i));
             } else {
                 throw std::invalid_argument("Unknown argument: " + std::string(arg));
             }
@@ -264,22 +312,29 @@ int main(int argc, char** argv) {
             backend_device backend_device = initialize_backend(backend);
             for (auto&& model : models) {
                 print(format(
-                    line, "[{: <2}/{: <2}] Running {} on {}...\n", ++i, n_tests, model.first,
+                    line, "[{: <2}/{: <2}] Running {} on {}...", ++i, n_tests, model.first,
                     backend));
 
-                results.push_back(benchmark_model(model.first, model.second, backend_device));
+                bench_result result = benchmark_model(
+                    model.first, model.second, backend_device, args);
+
+                print(format(
+                    line, " finished {} runs in {:.1f} s\n", result.time.iterations,
+                    result.time.total.count() / 1000.0));
+                results.push_back(result);
             }
         }
 
         printf("\n");
         print(format(
-            line, "| {: <10} | {: <30} | {: <6} | {: >11} | {: >6} |\n", "Arch", "Model", "Device", "Avg", "Dev"));
+            line, "| {: <10} | {: <30} | {: <6} | {: >11} | {: >6} |\n", "Arch", "Model", "Device",
+            "Avg", "Dev"));
         printf("|:-----------|:-------------------------------|:-------|------------:|-------:|\n");
         for (const auto& result : results) {
             auto model = result.model.substr(std::max(int(result.model.length()) - 30, 0));
             print(format(
                 line, "| {: <10} | {: <30} | {: <6} | {:8.1f} ms | {:6.1f} |\n", result.arch, model,
-                result.backend, result.time.mean, result.time.stdev));
+                result.backend, result.time.mean.count(), result.time.stdev.count()));
         }
         printf("\n");
     } catch (const std::exception& e) {

0 commit comments