Commit f985f4b
feat: support mmap for model loading
1 parent bfbb929

7 files changed: +182 -7 lines changed

examples/cli/main.cpp

Lines changed: 7 additions & 0 deletions

@@ -504,6 +504,7 @@ struct SDContextParams {
     rng_type_t rng_type = CUDA_RNG;
     rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
     bool offload_params_to_cpu = false;
+    bool use_mmap = false;
     bool control_net_cpu = false;
     bool clip_on_cpu = false;
     bool vae_on_cpu = false;
@@ -639,6 +640,10 @@ struct SDContextParams {
         "--offload-to-cpu",
         "place the weights in RAM to save VRAM, and automatically load them into VRAM when needed",
         true, &offload_params_to_cpu},
+        {"",
+         "--use-mmap",
+         "use mmap to load weights",
+         true, &use_mmap},
        {"",
         "--control-net-cpu",
         "keep controlnet in cpu (for low vram)",
@@ -874,6 +879,7 @@ struct SDContextParams {
            << " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n"
            << " flow_shift: " << (std::isinf(flow_shift) ? "INF" : std::to_string(flow_shift)) << "\n"
            << " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
+           << " use_mmap: " << (use_mmap ? "true" : "false") << ",\n"
            << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
            << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
            << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
@@ -924,6 +930,7 @@ struct SDContextParams {
        prediction,
        lora_apply_mode,
        offload_params_to_cpu,
+       use_mmap,
        clip_on_cpu,
        control_net_cpu,
        vae_on_cpu,
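
With this change, memory-mapped loading can be switched on from the command line. A hypothetical invocation (binary and model names are illustrative, not from this commit):

    ./sd -m model.safetensors -p "a lovely cat" --use-mmap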

model.cpp

Lines changed: 19 additions & 4 deletions

@@ -1337,7 +1337,7 @@ std::string ModelLoader::load_umt5_tokenizer_json() {
     return json_str;
 }
 
-bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p) {
+bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool use_mmap) {
     int64_t process_time_ms = 0;
     std::atomic<int64_t> read_time_ms(0);
     std::atomic<int64_t> memcpy_time_ms(0);
@@ -1387,6 +1387,15 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool use_mmap) {
         }
     }
 
+    std::shared_ptr<MmapWrapper> mmapped;
+    if (use_mmap && !is_zip) {
+        LOG_DEBUG("using mmap for I/O");
+        mmapped = MmapWrapper::create(file_path);
+        if (!mmapped) {
+            LOG_WARN("failed to memory-map '%s'", file_path.c_str());
+        }
+    }
+
     int n_threads = is_zip ? 1 : std::min(num_threads_to_use, (int)file_tensors.size());
     if (n_threads < 1) {
         n_threads = 1;
@@ -1408,7 +1417,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool use_mmap) {
             failed = true;
             return;
         }
-    } else {
+    } else if (!mmapped) {
        file.open(file_path, std::ios::binary);
        if (!file.is_open()) {
            LOG_ERROR("failed to open '%s'", file_path.c_str());
@@ -1461,6 +1470,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool use_mmap) {
                zip_entry_noallocread(zip, (void*)buf, n);
            }
            zip_entry_close(zip);
+       } else if (mmapped) {
+           if (!mmapped->copy_data(buf, n, tensor_storage.offset)) {
+               LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
+               failed = true;
+           }
        } else {
            file.seekg(tensor_storage.offset);
            file.read(buf, n);
@@ -1580,7 +1594,8 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool use_mmap) {
 
 bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
                                std::set<std::string> ignore_tensors,
-                               int n_threads) {
+                               int n_threads,
+                               bool use_mmap) {
     std::set<std::string> tensor_names_in_file;
     std::mutex tensor_names_mutex;
     auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
@@ -1623,7 +1638,7 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
         return true;
     };
 
-    bool success = load_tensors(on_new_tensor_cb, n_threads);
+    bool success = load_tensors(on_new_tensor_cb, n_threads, use_mmap);
     if (!success) {
         LOG_ERROR("load tensors from file failed");
         return false;
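
The per-tensor read path now has three branches. A minimal self-contained sketch of the dispatch order, with the types reduced and the zip reader stubbed out; only the ordering is taken from this commit: zip first, then the mapping, then the pre-existing std::ifstream fallback, which is also taken when MmapWrapper::create returned null.

#include <cstddef>
#include <fstream>
#include <memory>
#include "util.h"  // MmapWrapper from this commit

// Sketch: read n bytes at `offset` into buf from whichever source is active.
static bool read_range(bool is_zip,
                       const std::shared_ptr<MmapWrapper>& mmapped,
                       std::ifstream& file,
                       char* buf, size_t n, size_t offset) {
    if (is_zip) {
        return false;  // zip_entry_* path elided in this sketch
    } else if (mmapped) {
        // bounds-checked memcpy out of the mapping
        return mmapped->copy_data(buf, n, offset);
    } else {
        // original buffered path, unchanged by this commit
        file.seekg(offset);
        file.read(buf, n);
        return file.good();
    }
}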

model.h

Lines changed: 3 additions & 2 deletions

@@ -308,10 +308,11 @@ class ModelLoader {
     std::map<ggml_type, uint32_t> get_vae_wtype_stat();
     String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
     void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
-    bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
+    bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false);
     bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
                       std::set<std::string> ignore_tensors = {},
-                      int n_threads = 0);
+                      int n_threads = 0,
+                      bool use_mmap = false);
 
     std::vector<std::string> get_tensor_names() const {
         std::vector<std::string> names;

stable-diffusion.cpp

Lines changed: 2 additions & 1 deletion

@@ -693,7 +693,7 @@ class StableDiffusionGGML {
         if (version == VERSION_SVD) {
             ignore_tensors.insert("conditioner.embedders.3");
         }
-        bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads);
+        bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->use_mmap);
         if (!success) {
             LOG_ERROR("load tensors from model loader failed");
             ggml_free(ctx);
@@ -2478,6 +2478,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
     sd_ctx_params->prediction = PREDICTION_COUNT;
     sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
     sd_ctx_params->offload_params_to_cpu = false;
+    sd_ctx_params->use_mmap = false;
     sd_ctx_params->keep_clip_on_cpu = false;
     sd_ctx_params->keep_control_net_on_cpu = false;
     sd_ctx_params->keep_vae_on_cpu = false;

stable-diffusion.h

Lines changed: 1 addition & 0 deletions

@@ -176,6 +176,7 @@ typedef struct {
     enum prediction_t prediction;
     enum lora_apply_mode_t lora_apply_mode;
     bool offload_params_to_cpu;
+    bool use_mmap;
     bool keep_clip_on_cpu;
     bool keep_control_net_on_cpu;
     bool keep_vae_on_cpu;
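
For library users the flag is just another field of sd_ctx_params_t. A hedged sketch of enabling it through the public API; sd_ctx_params_init and the use_mmap field are from this commit, while new_sd_ctx is assumed to be the library's existing context constructor:

#include <stddef.h>
#include "stable-diffusion.h"

int main(void) {
    sd_ctx_params_t params;
    sd_ctx_params_init(&params);  /* use_mmap defaults to false */
    params.use_mmap = true;       /* opt in to memory-mapped loading */
    /* model path and the remaining fields would be filled in here */
    sd_ctx_t* ctx = new_sd_ctx(&params);
    return ctx == NULL;
}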

util.cpp

Lines changed: 127 additions & 0 deletions

@@ -109,9 +109,78 @@ std::string get_full_path(const std::string& dir, const std::string& filename) {
     }
 }
 
+class MmapWrapperImpl : public MmapWrapper {
+public:
+    MmapWrapperImpl(void* data, size_t size, HANDLE hfile, HANDLE hmapping)
+        : MmapWrapper(data, size), hfile_(hfile), hmapping_(hmapping) {}
+
+    ~MmapWrapperImpl() override {
+        if (data_) {
+            UnmapViewOfFile(data_);
+        }
+        if (hmapping_ != NULL) {
+            CloseHandle(hmapping_);
+        }
+        if (hfile_ != INVALID_HANDLE_VALUE) {
+            CloseHandle(hfile_);
+        }
+    }
+
+private:
+    HANDLE hfile_;
+    HANDLE hmapping_;
+};
+
+std::shared_ptr<MmapWrapper> MmapWrapper::create(const std::string& filename) {
+    void* mapped_data = nullptr;
+    size_t file_size = 0;
+
+    HANDLE file_handle = CreateFileA(
+        filename.c_str(),
+        GENERIC_READ,
+        FILE_SHARE_READ,
+        NULL,
+        OPEN_EXISTING,
+        FILE_ATTRIBUTE_NORMAL,
+        NULL);
+
+    if (file_handle == INVALID_HANDLE_VALUE) {
+        return nullptr;
+    }
+
+    LARGE_INTEGER size;
+    if (!GetFileSizeEx(file_handle, &size)) {
+        CloseHandle(file_handle);
+        return nullptr;
+    }
+
+    file_size = static_cast<size_t>(size.QuadPart);
+
+    HANDLE mapping_handle = CreateFileMapping(file_handle, NULL, PAGE_READONLY, 0, 0, NULL);
+
+    if (mapping_handle == NULL) {
+        CloseHandle(file_handle);
+        return nullptr;
+    }
+
+    mapped_data = MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, file_size);
+
+    if (mapped_data == NULL) {
+        CloseHandle(mapping_handle);
+        CloseHandle(file_handle);
+        return nullptr;
+    }
+
+    return std::make_shared<MmapWrapperImpl>(mapped_data, file_size, file_handle, mapping_handle);
+}
+
 #else  // Unix
 #include <dirent.h>
+#include <fcntl.h>
+#include <sys/mman.h>
 #include <sys/stat.h>
+#include <unistd.h>
 
 bool file_exists(const std::string& filename) {
     struct stat buffer;
@@ -143,8 +212,66 @@ std::string get_full_path(const std::string& dir, const std::string& filename) {
     return "";
 }
 
+class MmapWrapperImpl : public MmapWrapper {
+public:
+    MmapWrapperImpl(void* data, size_t size) : MmapWrapper(data, size) {}
+
+    ~MmapWrapperImpl() override {
+        if (data_) {
+            munmap(data_, size_);
+        }
+    }
+};
+
+std::shared_ptr<MmapWrapper> MmapWrapper::create(const std::string& filename) {
+    int file_descriptor = open(filename.c_str(), O_RDONLY);
+    if (file_descriptor == -1) {
+        return nullptr;
+    }
+
+    int mmap_flags = MAP_PRIVATE;
+
+#ifdef __linux__
+    // performance flags used by llama.cpp
+    // posix_fadvise(file_descriptor, 0, 0, POSIX_FADV_SEQUENTIAL);
+    // mmap_flags |= MAP_POPULATE;
+#endif
+
+    struct stat sb;
+    if (fstat(file_descriptor, &sb) == -1) {
+        close(file_descriptor);
+        return nullptr;
+    }
+
+    size_t file_size = sb.st_size;
+
+    void* mapped_data = mmap(NULL, file_size, PROT_READ, mmap_flags, file_descriptor, 0);
+
+    close(file_descriptor);
+
+    if (mapped_data == MAP_FAILED) {
+        return nullptr;
+    }
+
+#ifdef __linux__
+    // performance flags used by llama.cpp
+    // posix_madvise(mapped_data, file_size, POSIX_MADV_WILLNEED);
+#endif
+
+    return std::make_shared<MmapWrapperImpl>(mapped_data, file_size);
}
+
 #endif
 
+bool MmapWrapper::copy_data(void* buf, size_t n, size_t offset) const {
+    if (offset >= size_ || n > (size_ - offset)) {
+        return false;
+    }
+    std::memcpy(buf, data() + offset, n);
+    return true;
+}
+
 // get_num_physical_cores is copy from
 // https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
 // LICENSE: https://github.com/ggerganov/llama.cpp/blob/master/LICENSE
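
Two details here are worth noting. On the Unix path the file descriptor is closed right after mmap, which is fine: POSIX keeps the mapping valid until munmap. And copy_data's bounds check is deliberately written as n > (size_ - offset) rather than the more obvious offset + n > size_, because the unsigned addition can wrap. A standalone demonstration of the hazard (values are contrived):

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    size_t size = 100, offset = 1;
    size_t n = SIZE_MAX;  // hostile read length
    // naive check: offset + n wraps to 0, so 0 > 100 is false
    // and the oversized read would be allowed
    printf("naive rejects: %d\n", (int)(offset + n > size));                    // prints 0
    // overflow-safe form used in this commit
    printf("safe  rejects: %d\n", (int)(offset >= size || n > size - offset));  // prints 1
    return 0;
}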

util.h

Lines changed: 23 additions & 0 deletions

@@ -2,6 +2,7 @@
 #define __UTIL_H__
 
 #include <cstdint>
+#include <memory>
 #include <string>
 #include <vector>
 
@@ -44,6 +45,28 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int target_height);
 
 sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height);
 
+class MmapWrapper {
+public:
+    static std::shared_ptr<MmapWrapper> create(const std::string& filename);
+
+    virtual ~MmapWrapper() = default;
+
+    MmapWrapper(const MmapWrapper&) = delete;
+    MmapWrapper& operator=(const MmapWrapper&) = delete;
+    MmapWrapper(MmapWrapper&&) = delete;
+    MmapWrapper& operator=(MmapWrapper&&) = delete;
+
+    const uint8_t* data() const { return static_cast<uint8_t*>(data_); }
+    size_t size() const { return size_; }
+    bool copy_data(void* buf, size_t n, size_t offset) const;
+
+protected:
+    MmapWrapper(void* data, size_t size)
+        : data_(data), size_(size) {}
+
+    void* data_ = nullptr;
+    size_t size_ = 0;
+};
+
 std::string path_join(const std::string& p1, const std::string& p2);
 std::vector<std::string> split_string(const std::string& str, char delimiter);
 void pretty_progress(int step, int steps, float time);
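
A hypothetical usage sketch of the new interface; the surrounding main() and the file name are illustrative only. create() returns null on any mapping failure, which callers are expected to treat as "fall back to buffered reads", as the loader above does:

#include <cstdint>
#include <cstdio>
#include <vector>
#include "util.h"

int main() {
    auto mapping = MmapWrapper::create("model.safetensors");  // illustrative path
    if (!mapping) {
        fprintf(stderr, "mmap failed, use buffered I/O instead\n");
        return 1;
    }
    // Copy the first 8 bytes out of the mapping; copy_data bounds-checks
    // the (offset, n) pair against the mapped size.
    std::vector<uint8_t> header(8);
    if (!mapping->copy_data(header.data(), header.size(), 0)) {
        fprintf(stderr, "out-of-range read\n");
        return 1;
    }
    printf("mapped %zu bytes\n", mapping->size());
    return 0;
}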
