
Commit 2a5781b

server: move msg diffs tracking to HTTP thread
1 parent: a87d9a4


3 files changed: +60 −32 lines


tools/server/server-context.cpp

Lines changed: 8 additions & 29 deletions
@@ -101,8 +101,6 @@ struct server_slot {
     std::string generated_text;
     llama_tokens generated_tokens;

-    common_chat_msg chat_msg;
-
     std::vector<completion_token_output> generated_token_probs;

     bool has_next_token = true;
@@ -153,9 +151,6 @@ struct server_slot {

     llama_token sampled;

-    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    std::vector<std::string> generated_tool_call_ids;
-
     // stats
     size_t n_sent_text = 0; // number of sent text character

@@ -183,13 +178,10 @@ struct server_slot {
         stop = STOP_TYPE_NONE;
         stopping_word = "";
         n_sent_text = 0;
-        chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;

         generated_tokens.clear();
         generated_token_probs.clear();
-        chat_msg = {};
         json_schema = json();
-        generated_tool_call_ids.clear();

         // clear speculative decoding stats
         n_draft_total = 0;
@@ -302,23 +294,6 @@ struct server_slot {
         return timings;
     }

-    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
-        GGML_ASSERT(task);
-
-        auto previous_msg = chat_msg;
-        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
-        auto new_msg = common_chat_parse(
-            generated_text,
-            /* is_partial= */ stop != STOP_TYPE_EOS,
-            task->params.oaicompat_chat_syntax);
-        if (!new_msg.empty()) {
-            new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
-            chat_msg = new_msg;
-            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
-        }
-        return chat_msg;
-    }
-
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         GGML_ASSERT(task);

@@ -1284,8 +1259,6 @@ struct server_context_impl {
         } else {
             res->content = tkn.text_to_send;
             res->tokens = { tkn.tok };
-
-            slot.update_chat_msg(res->oaicompat_msg_diffs);
         }

         res->n_decoded = slot.n_decoded;
@@ -1338,7 +1311,6 @@ struct server_context_impl {
         res->res_type = slot.task->params.res_type;
         res->oaicompat_model = slot.task->params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
-        res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);

         // populate res.probs_output
         if (slot.task->params.sampling.n_probs > 0) {
@@ -2593,6 +2565,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     auto completion_id = gen_chatcmplid();
     auto & rd = res->rd;

+    // tracking generation state and partial tool calls
+    std::vector<task_result_state> states;
+
     try {
         std::vector<server_task> tasks;

@@ -2630,6 +2605,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
            task.params.oaicompat_model = ctx_server.model_name;

            tasks.push_back(std::move(task));
+           states.emplace_back(task.params.oaicompat_chat_syntax);
        }

        rd.post_tasks(std::move(tasks));
@@ -2652,6 +2628,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
            json arr = json::array();
            for (auto & res : all_results.results) {
                GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
+               res->update(states[res->get_index()]); // update generation state
                arr.push_back(res->to_json());
            }
            // if single request, return single object instead of array
@@ -2673,6 +2650,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
                dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
                || dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
            );
+           first_result->update(states[first_result->get_index()]); // update generation state
        }

        // next responses are streamed
@@ -2683,7 +2661,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
    }
    res->status = 200;
    res->content_type = "text/event-stream";
-   res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
+   res->next = [res_this = res.get(), res_type, &should_stop, states = std::move(states)](std::string & output) mutable -> bool {
        if (should_stop()) {
            SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
            return false; // should_stop condition met
@@ -2737,6 +2715,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
            dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
            || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
        );
+       result->update(states[result->get_index()]); // update generation state
        if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
            output = format_anthropic_sse(res_json);
        } else {
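
Net effect in this file: the slot no longer parses chat messages on the inference thread. The HTTP handler keeps one task_result_state per task, calls update() on each result before serializing it, and moves the whole states vector into the streaming callback (hence the mutable lambda). The stand-alone sketch below mirrors that ownership pattern with hypothetical stand-in types (demo_state, demo_result); it is not the server's actual handler code.

    #include <cstdio>
    #include <string>
    #include <utility>
    #include <vector>

    struct demo_state {                // stand-in for task_result_state
        std::string accumulated;
        void append(const std::string & chunk) { accumulated += chunk; }
    };

    struct demo_result {               // stand-in for server_task_result_cmpl_partial
        int index;
        std::string chunk;
        void update(std::vector<demo_state> & states) const {
            states[index].append(chunk);   // mutate the HTTP-thread-owned state
        }
    };

    int main() {
        std::vector<demo_state> states(2); // one state per task, like `states` above

        // The callback takes ownership of the states (moved in), so it can keep
        // mutating them across invocations after the setup scope has ended;
        // `mutable` is what allows the captured copy to be modified.
        auto next = [states = std::move(states)](const demo_result & r) mutable {
            r.update(states);
            std::printf("task %d so far: %s\n", r.index, states[r.index].accumulated.c_str());
        };

        next({0, "Hello"});
        next({0, ", world"});
        next({1, "partial tool call: {\"name\":"});
        return 0;
    }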

tools/server/server-task.cpp

Lines changed: 19 additions & 0 deletions
@@ -700,6 +700,25 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
    return res;
}

+common_chat_msg task_result_state::update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs) {
+    generated_text += text_added;
+    auto msg_prv_copy = chat_msg;
+    SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
+    auto new_msg = common_chat_parse(
+        generated_text,
+        is_partial,
+        oaicompat_chat_syntax);
+    if (!new_msg.empty()) {
+        new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
+        chat_msg = new_msg;
+        diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg);
+    }
+    return chat_msg;
+}
+
json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
    std::time_t t = std::time(0);
    std::string finish_reason = "length";
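
task_result_state::update_chat_msg() is essentially the old slot method with the text accumulation folded in: it appends the new chunk, re-parses the whole accumulated text, and writes into diffs only the delta since the previous parse. A minimal usage sketch, assuming the server headers are on the include path, that a default-constructed common_chat_syntax is acceptable for plain-content output, and that common_chat_msg exposes its text as .content:

    #include <cstdio>
    #include <vector>
    #include "server-task.h"   // task_result_state, common_chat_* types

    static void demo_incremental_parse() {
        task_result_state state{common_chat_syntax{}};   // per-task state, HTTP side
        std::vector<common_chat_msg_diff> diffs;

        // Each call appends the chunk, re-parses the accumulated text, and
        // reports in `diffs` only what changed since the previous parse.
        state.update_chat_msg("Hello, ", /* is_partial= */ true,  diffs);
        state.update_chat_msg("world!",  /* is_partial= */ false, diffs);

        // chat_msg now holds the fully parsed message; diffs holds the delta
        // produced by the last call (what a streaming response would forward).
        std::printf("final content: %s\n", state.chat_msg.content.c_str());
    }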

tools/server/server-task.h

Lines changed: 33 additions & 3 deletions
@@ -161,6 +161,25 @@ struct result_prompt_progress {
    json to_json() const;
};

+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+    // tracking diffs for partial tool calls
+    std::vector<common_chat_msg_diff> diffs;
+    common_chat_syntax oaicompat_chat_syntax;
+    common_chat_msg chat_msg;
+    std::string generated_text; // append new chunks of generated text here
+    std::vector<std::string> generated_tool_call_ids;
+
+    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+
+    // parse partial tool calls and update the internal state
+    common_chat_msg update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs);
+};
+
struct server_task_result {
    int id = -1;
    int id_slot = -1;
@@ -175,6 +194,9 @@ struct server_task_result {
    virtual int get_index() {
        return -1;
    }
+   virtual void update(task_result_state &) {
+       // only used by server_task_result_cmpl_*
+   }
    virtual json to_json() = 0;
    virtual ~server_task_result() = default;
};
@@ -233,9 +255,9 @@ struct server_task_result_cmpl_final : server_task_result {
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
-   common_chat_msg oaicompat_msg;
+   common_chat_msg oaicompat_msg; // to be populated by update()

-   std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+   std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()

    virtual int get_index() override {
        return index;
@@ -247,6 +269,10 @@ struct server_task_result_cmpl_final : server_task_result {

    virtual json to_json() override;

+   virtual void update(task_result_state & state) override {
+       oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
+   }
+
    json to_json_non_oaicompat();

    json to_json_oaicompat();
@@ -280,7 +306,7 @@ struct server_task_result_cmpl_partial : server_task_result {
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
-   std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+   std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()

    virtual int get_index() override {
        return index;
@@ -292,6 +318,10 @@ struct server_task_result_cmpl_partial : server_task_result {

    virtual json to_json() override;

+   virtual void update(task_result_state & state) override {
+       state.update_chat_msg(content, true, oaicompat_msg_diffs);
+   }
+
    json to_json_non_oaicompat();

    json to_json_oaicompat();
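
The default update() on server_task_result is deliberately a no-op, so only the completion result types pay for state tracking. A hypothetical example, not part of this commit, of how another result type could opt in by overriding the hook:

    // Hypothetical result type, purely illustrative: it reuses the shared
    // incremental parser by overriding update(), much like cmpl_partial does.
    struct my_streaming_result : server_task_result {
        int index = 0;
        std::string content;                               // text added by this chunk
        std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

        virtual int get_index() override { return index; }

        virtual void update(task_result_state & state) override {
            // treat this as a partial chunk; the delta lands in oaicompat_msg_diffs
            state.update_chat_msg(content, /* is_partial= */ true, oaicompat_msg_diffs);
        }

        virtual json to_json() override {
            return json {{ "index", index }, { "content", content }};
        }
    };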
