@@ -1118,8 +1118,27 @@ ggml_tensor *GLMBlock::forward(ModelContext *mctx, ggml_tensor *hidden_states, g
11181118 return output;
11191119}
11201120
1121+ static void alloc_weight_context (ModelContext *mctx, const ggml_backend_buffer_t sd_buf) {
1122+ void *sd_buf_base = ggml_backend_buffer_get_base (sd_buf);
1123+ const size_t sd_buf_size = ggml_backend_buffer_get_size (sd_buf);
1124+ if (ggml_backend_is_cpu (mctx->backend .get ())) {
1125+ mctx->buf_w = unique_ggml_backend_buffer_t (ggml_backend_cpu_buffer_from_ptr (sd_buf_base, sd_buf_size));
1126+ }
1127+ #ifdef GGML_USE_METAL
1128+ else if (ggml_backend_is_metal (mctx->backend .get ())) {
1129+ const size_t max_size = ggml_get_max_tensor_size (mctx->ctx_w .get ());
1130+ mctx->buf_w =
1131+ unique_ggml_backend_buffer_t (ggml_backend_metal_buffer_from_ptr (sd_buf_base, sd_buf_size, max_size));
1132+ }
1133+ #endif
1134+ else {
1135+ mctx->buf_w =
1136+ unique_ggml_backend_buffer_t (ggml_backend_alloc_ctx_tensors (mctx->ctx_w .get (), mctx->backend .get ()));
1137+ }
1138+ }
1139+
11211140void ChatGLMForCausalLM::load_state_dict (const StateDict &sd) {
1122- alloc_weight_context (sd.buf .get ());
1141+ alloc_weight_context (mctx_. get (), sd.buf .get ());
11231142
11241143 StateDict self_sd = state_dict ();
11251144 for (auto &item : self_sd.kv ) {
@@ -1259,7 +1278,7 @@ bool ChatGLM2Tokenizer::is_special_id(int id) const {
12591278}
12601279
12611280void ChatGLM2ForCausalLM::load_state_dict (const StateDict &sd) {
1262- alloc_weight_context (sd.buf .get ());
1281+ alloc_weight_context (mctx_. get (), sd.buf .get ());
12631282
12641283 if (config.num_virtual_tokens > 0 ) {
12651284 ggml_tensor *past_key_values = sd.kv .at (" past_key_values" );
@@ -1959,7 +1978,7 @@ int ChatGLM4VForCausalLM::count_tokens(const std::vector<int> &input_ids, const
19591978}
19601979
19611980void ChatGLM4VForCausalLM::load_state_dict (const StateDict &sd) {
1962- alloc_weight_context (sd.buf .get ());
1981+ alloc_weight_context (mctx_. get (), sd.buf .get ());
19631982
19641983 auto self_sd = state_dict ();
19651984 ChatGLM2ForCausalLM::load_state_dict (mctx_.get (), self_sd, sd);
0 commit comments