Skip to content

Commit 653248a

Browse files
committed
ml: add model_file::float_type() which reads type from GGUF metadata
1 parent d381eaf commit 653248a

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

include/visp/ml.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ struct model_file {
8383

8484
VISP_API int64_t n_tensors() const;
8585
VISP_API std::string_view arch() const;
86+
VISP_API ggml_type float_type() const;
8687
VISP_API tensor_data_layout tensor_layout() const;
8788

8889
VISP_API int64_t key(char const* name) const;

src/cli/cli.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,13 @@ std::tuple<model_file, model_weights> load_model_weights(
262262
preferred_layout = file.tensor_layout();
263263
}
264264
model_transfer(file, weights, dev, dev.preferred_float_type(), preferred_layout);
265-
266265
printf("done (%s)\n", t.elapsed_str());
267-
printf("- float type: %s\n", ggml_type_name(weights.float_type()));
266+
267+
ggml_type ftype = file.float_type();
268+
if (ftype == GGML_TYPE_COUNT) {
269+
ftype = weights.float_type();
270+
}
271+
printf("- float type: %s\n", ggml_type_name(ftype));
268272
if (preferred_layout != tensor_data_layout::unknown) {
269273
printf("- tensor layout: %s\n", to_string(preferred_layout));
270274
}

src/visp/ml.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,15 @@ std::string_view model_file::arch() const {
227227
return get_string("general.architecture");
228228
}
229229

230+
ggml_type model_file::float_type() const {
231+
if (int64_t key_id = gguf_find_key(gguf.get(), "general.file_type"); key_id != -1) {
232+
if (gguf_get_kv_type(gguf.get(), key_id) == GGUF_TYPE_UINT32) {
233+
return (ggml_type)gguf_get_val_u32(gguf.get(), key_id);
234+
}
235+
}
236+
return GGML_TYPE_COUNT;
237+
}
238+
230239
tensor_data_layout model_file::tensor_layout() const {
231240
fixed_string<64> str;
232241
int64_t key = gguf_find_key(gguf.get(), format(str, "{}.tensor_data_layout", arch()));

0 commit comments

Comments
 (0)