diff --git a/common/common.cpp b/common/common.cpp
index d4e8c7405eb..acf2ec841d7 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1078,6 +1078,8 @@ struct common_init_result::impl {
     impl() = default;
     ~impl() = default;
 
+    // note: the order in which model, context, etc. are declared matters because their destructors will be called bottom-to-top
+
     llama_model_ptr model;
     llama_context_ptr context;
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 8786d4ee3e0..015ebae71d6 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -459,23 +459,22 @@ llama_context::llama_context(
 }
 
 llama_context::~llama_context() {
-    // FIXME this currently results in a use-after-free bug if the model is freed before the context
-    // if (!model.hparams.no_alloc) {
-    //     for (size_t i = 0; i < backend_ptrs.size(); ++i) {
-    //         ggml_backend_t backend = backend_ptrs[i];
-    //         ggml_backend_buffer_type_t buft = backend_buft[i];
-
-    //         const size_t size_exp = backend_buf_exp_size[i];
-    //         const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
-    //         if (size_exp == size_act) {
-    //             LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
-    //                 __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
-    //         } else {
-    //             LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
-    //                 __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
-    //         }
-    //     }
-    // }
+    if (!model.hparams.no_alloc) {
+        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+            ggml_backend_t backend = backend_ptrs[i];
+            ggml_backend_buffer_type_t buft = backend_buft[i];
+
+            const size_t size_exp = backend_buf_exp_size[i];
+            const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+            if (size_exp == size_act) {
+                LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
+                    __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
+            } else {
+                LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
+                    __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
+            }
+        }
+    }
 
     ggml_opt_free(opt_ctx);
 }
diff --git a/tests/test-grammar-llguidance.cpp b/tests/test-grammar-llguidance.cpp
index 566b039a070..34746c200ca 100644
--- a/tests/test-grammar-llguidance.cpp
+++ b/tests/test-grammar-llguidance.cpp
@@ -1196,6 +1196,9 @@ int main(int argc, const char ** argv) {
 
     test_sampler_chain();
 
+    llama_free(ctx);
+    llama_model_free(model);
+
     fprintf(stdout, "All tests passed.\n");
     return 0;
 }
diff --git a/tests/test-tokenizer-1-bpe.cpp b/tests/test-tokenizer-1-bpe.cpp
index b183da47f3c..505dbfdb93d 100644
--- a/tests/test-tokenizer-1-bpe.cpp
+++ b/tests/test-tokenizer-1-bpe.cpp
@@ -146,8 +146,8 @@ int main(int argc, char **argv) {
         }
     }
 
-    llama_model_free(model);
     llama_free(ctx);
+    llama_model_free(model);
 
     llama_backend_free();
 
diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp
index 2032a386bb4..0f627c5ff65 100644
--- a/tools/batched-bench/batched-bench.cpp
+++ b/tools/batched-bench/batched-bench.cpp
@@ -55,6 +55,7 @@ int main(int argc, char ** argv) {
 
     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+        llama_model_free(model);
         return 1;
     }
 
@@ -108,6 +109,8 @@ int main(int argc, char ** argv) {
 
         if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
+            llama_free(ctx);
+            llama_model_free(model);
             return 1;
         }
     }
@@ -147,6 +150,8 @@ int main(int argc, char ** argv) {
 
         if (!decode_helper(ctx, batch, ctx_params.n_batch, false)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
+            llama_free(ctx);
+            llama_model_free(model);
             return 1;
         }
 
@@ -165,6 +170,8 @@ int main(int argc, char ** argv) {
         common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
         if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
+            llama_free(ctx);
+            llama_model_free(model);
             return 1;
         }
         llama_memory_seq_rm(mem, 0, pp, -1);
@@ -184,6 +191,8 @@ int main(int argc, char ** argv) {
 
             if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                 LOG_ERR("%s: llama_decode() failed\n", __func__);
+                llama_free(ctx);
+                llama_model_free(model);
                 return 1;
             }
         }
@@ -200,6 +209,8 @@ int main(int argc, char ** argv) {
 
            if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                 LOG_ERR("%s: llama_decode() failed\n", __func__);
+                llama_free(ctx);
+                llama_model_free(model);
                 return 1;
             }
         }
diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index 0be6ed69483..b431c7f31bf 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -2102,6 +2102,8 @@ int main(int argc, char ** argv) {
         struct ggml_threadpool_params tpp = ggml_threadpool_params_default(t.n_threads);
         if (!parse_cpu_mask(t.cpu_mask, tpp.cpumask)) {
             fprintf(stderr, "%s: failed to parse cpu-mask: %s\n", __func__, t.cpu_mask.c_str());
+            llama_free(ctx);
+            llama_model_free(lmodel);
             exit(1);
         }
         tpp.strict_cpu = t.cpu_strict;
@@ -2111,6 +2113,8 @@ int main(int argc, char ** argv) {
         struct ggml_threadpool * threadpool = ggml_threadpool_new_fn(&tpp);
         if (!threadpool) {
             fprintf(stderr, "%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
+            llama_free(ctx);
+            llama_model_free(lmodel);
             exit(1);
         }
 
@@ -2126,6 +2130,8 @@ int main(int argc, char ** argv) {
             bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
             if (!res) {
                 fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__);
+                llama_free(ctx);
+                llama_model_free(lmodel);
                 exit(1);
             }
         }
@@ -2136,6 +2142,8 @@ int main(int argc, char ** argv) {
             bool res = test_gen(ctx, 1, t.n_threads);
             if (!res) {
                 fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__);
+                llama_free(ctx);
+                llama_model_free(lmodel);
                 exit(1);
             }
         }
@@ -2164,6 +2172,8 @@ int main(int argc, char ** argv) {
                 bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
                 if (!res) {
                     fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                    llama_free(ctx);
+                    llama_model_free(lmodel);
                     exit(1);
                 }
 
@@ -2189,6 +2199,8 @@ int main(int argc, char ** argv) {
                 bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
                 if (!res) {
                     fprintf(stderr, "%s: error: failed to run prompt\n", __func__);
+                    llama_free(ctx);
+                    llama_model_free(lmodel);
                     exit(1);
                 }
             }
@@ -2200,6 +2212,8 @@ int main(int argc, char ** argv) {
                 bool res = test_gen(ctx, t.n_gen, t.n_threads);
                 if (!res) {
                     fprintf(stderr, "%s: error: failed to run gen\n", __func__);
+                    llama_free(ctx);
+                    llama_model_free(lmodel);
                     exit(1);
                 }
             }
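
For context on the ordering note added to common.cpp: C++ destroys non-static data members in reverse declaration order, so declaring model before context guarantees the context (whose destructor now reads model.hparams) is torn down while the model is still alive. A minimal sketch of that rule, using hypothetical Model/Context stand-ins rather than the real llama.cpp types:

    #include <cstdio>

    struct Model   { ~Model()   { std::puts("model destroyed");   } };
    struct Context { ~Context() { std::puts("context destroyed"); } }; // in llama.cpp, ~llama_context() now reads model.hparams

    struct InitResult {
        // declared first => destroyed last: members are destroyed in reverse
        // declaration order, so the context goes away before the model it uses
        Model   model;
        Context context;
    };

    int main() {
        InitResult res;
        // on scope exit this prints "context destroyed", then "model destroyed"
        return 0;
    }

The manual teardown in the tests and tools follows the same rule by hand: call llama_free(ctx) before llama_model_free(model), since the context destructor now touches the model's hparams.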