diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index 0af680bb5..fe1261e20 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -17,6 +17,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
+from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
 from fastchat.serve.base_model_worker import BaseModelWorker
 from fastchat.serve.model_worker import (
     logger,
@@ -66,108 +67,117 @@ def __init__(
 
     async def generate_stream(self, params):
         self.call_ct += 1
-
-        context = params.pop("prompt")
-        request_id = params.pop("request_id")
-        temperature = float(params.get("temperature", 1.0))
-        top_p = float(params.get("top_p", 1.0))
-        top_k = params.get("top_k", -1.0)
-        presence_penalty = float(params.get("presence_penalty", 0.0))
-        frequency_penalty = float(params.get("frequency_penalty", 0.0))
-        max_new_tokens = params.get("max_new_tokens", 256)
-        stop_str = params.get("stop", None)
-        stop_token_ids = params.get("stop_token_ids", None) or []
-        if self.tokenizer.eos_token_id is not None:
-            stop_token_ids.append(self.tokenizer.eos_token_id)
-        echo = params.get("echo", True)
-        use_beam_search = params.get("use_beam_search", False)
-        best_of = params.get("best_of", None)
-
-        request = params.get("request", None)
-
-        # Handle stop_str
-        stop = set()
-        if isinstance(stop_str, str) and stop_str != "":
-            stop.add(stop_str)
-        elif isinstance(stop_str, list) and stop_str != []:
-            stop.update(stop_str)
-
-        for tid in stop_token_ids:
-            if tid is not None:
-                s = self.tokenizer.decode(tid)
-                if s != "":
-                    stop.add(s)
-
-        # make sampling params in vllm
-        top_p = max(top_p, 1e-5)
-        if temperature <= 1e-5:
-            top_p = 1.0
-
-        sampling_params = SamplingParams(
-            n=1,
-            temperature=temperature,
-            top_p=top_p,
-            use_beam_search=use_beam_search,
-            stop=list(stop),
-            stop_token_ids=stop_token_ids,
-            max_tokens=max_new_tokens,
-            top_k=top_k,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-            best_of=best_of,
-        )
-        results_generator = engine.generate(context, sampling_params, request_id)
-
-        async for request_output in results_generator:
-            prompt = request_output.prompt
-            if echo:
-                text_outputs = [
-                    prompt + output.text for output in request_output.outputs
-                ]
-            else:
-                text_outputs = [output.text for output in request_output.outputs]
-            text_outputs = " ".join(text_outputs)
-
-            partial_stop = any(is_partial_stop(text_outputs, i) for i in stop)
-            # prevent yielding partial stop sequence
-            if partial_stop:
-                continue
-
-            aborted = False
-            if request and await request.is_disconnected():
-                await engine.abort(request_id)
-                request_output.finished = True
-                aborted = True
-                for output in request_output.outputs:
-                    output.finish_reason = "abort"
-
-            prompt_tokens = len(request_output.prompt_token_ids)
-            completion_tokens = sum(
-                len(output.token_ids) for output in request_output.outputs
+        try:
+            context = params.pop("prompt")
+            request_id = params.pop("request_id")
+            temperature = float(params.get("temperature", 1.0))
+            top_p = float(params.get("top_p", 1.0))
+            top_k = params.get("top_k", -1.0)
+            presence_penalty = float(params.get("presence_penalty", 0.0))
+            frequency_penalty = float(params.get("frequency_penalty", 0.0))
+            max_new_tokens = params.get("max_new_tokens", 256)
+            stop_str = params.get("stop", None)
+            stop_token_ids = params.get("stop_token_ids", None) or []
+            if self.tokenizer.eos_token_id is not None:
+                stop_token_ids.append(self.tokenizer.eos_token_id)
+            echo = params.get("echo", True)
+            use_beam_search = params.get("use_beam_search", False)
+            best_of = params.get("best_of", None)
+
+            request = params.get("request", None)
+
+            # Handle stop_str
+            stop = set()
+            if isinstance(stop_str, str) and stop_str != "":
+                stop.add(stop_str)
+            elif isinstance(stop_str, list) and stop_str != []:
+                stop.update(stop_str)
+
+            for tid in stop_token_ids:
+                if tid is not None:
+                    s = self.tokenizer.decode(tid)
+                    if s != "":
+                        stop.add(s)
+
+            # make sampling params in vllm
+            top_p = max(top_p, 1e-5)
+            if temperature <= 1e-5:
+                top_p = 1.0
+            sampling_params = SamplingParams(
+                n=1,
+                temperature=temperature,
+                top_p=top_p,
+                use_beam_search=use_beam_search,
+                stop=list(stop),
+                stop_token_ids=stop_token_ids,
+                max_tokens=max_new_tokens,
+                top_k=top_k,
+                presence_penalty=presence_penalty,
+                frequency_penalty=frequency_penalty,
+                best_of=best_of,
             )
+            results_generator = engine.generate(context, sampling_params, request_id)
+
+            async for request_output in results_generator:
+                prompt = request_output.prompt
+                if echo:
+                    text_outputs = [
+                        prompt + output.text for output in request_output.outputs
+                    ]
+                else:
+                    text_outputs = [output.text for output in request_output.outputs]
+                text_outputs = " ".join(text_outputs)
+
+                partial_stop = any(is_partial_stop(text_outputs, i) for i in stop)
+                # prevent yielding partial stop sequence
+                if partial_stop:
+                    continue
+
+                aborted = False
+                if request and await request.is_disconnected():
+                    await engine.abort(request_id)
+                    request_output.finished = True
+                    aborted = True
+                    for output in request_output.outputs:
+                        output.finish_reason = "abort"
+
+                prompt_tokens = len(request_output.prompt_token_ids)
+                completion_tokens = sum(
+                    len(output.token_ids) for output in request_output.outputs
+                )
+                ret = {
+                    "text": text_outputs,
+                    "error_code": 0,
+                    "usage": {
+                        "prompt_tokens": prompt_tokens,
+                        "completion_tokens": completion_tokens,
+                        "total_tokens": prompt_tokens + completion_tokens,
+                    },
+                    "cumulative_logprob": [
+                        output.cumulative_logprob for output in request_output.outputs
+                    ],
+                    "finish_reason": (
+                        request_output.outputs[0].finish_reason
+                        if len(request_output.outputs) == 1
+                        else [output.finish_reason for output in request_output.outputs]
+                    ),
+                }
+                # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response.
+                # This aligns with the behavior of model_worker.
+                if request_output.finished:
+                    yield (
+                        json.dumps({**ret, **{"finish_reason": None}}) + "\0"
+                    ).encode()
+                yield (json.dumps(ret) + "\0").encode()
+
+                if aborted:
+                    break
+        except ValueError as e:
             ret = {
-                "text": text_outputs,
-                "error_code": 0,
-                "usage": {
-                    "prompt_tokens": prompt_tokens,
-                    "completion_tokens": completion_tokens,
-                    "total_tokens": prompt_tokens + completion_tokens,
-                },
-                "cumulative_logprob": [
-                    output.cumulative_logprob for output in request_output.outputs
-                ],
-                "finish_reason": request_output.outputs[0].finish_reason
-                if len(request_output.outputs) == 1
-                else [output.finish_reason for output in request_output.outputs],
+                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+                "error_code": ErrorCode.INTERNAL_ERROR,
             }
-            # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response.
-            # This aligns with the behavior of model_worker.
-            if request_output.finished:
-                yield (json.dumps({**ret, **{"finish_reason": None}}) + "\0").encode()
-            yield (json.dumps(ret) + "\0").encode()
-
-            if aborted:
-                break
+            yield json.dumps(ret).encode() + b"\0"
 
     async def generate(self, params):
         async for x in self.generate_stream(params):
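
Reviewer note (illustrative only, not part of the patch): the diff changes what the worker emits on failure, so a minimal sketch of how a client might consume the b"\0"-delimited JSON chunks yielded by `generate_stream` and distinguish the new error payload (non-zero `error_code`) from normal streaming output. The `iter_worker_chunks` helper, the sample buffer, and the placeholder error code/message below are hypothetical, invented for this sketch; they are not FastChat APIs or the exact values of `ErrorCode.INTERNAL_ERROR` or `SERVER_ERROR_MSG`.

```python
# Illustrative sketch only -- not part of this diff and not a FastChat API.
# Shows how a client could split the b"\0"-framed JSON chunks produced by
# generate_stream and react to the error payload added in this patch.
import json


def iter_worker_chunks(buffer: bytes):
    """Hypothetical helper: yield each JSON chunk from a NUL-delimited buffer."""
    for raw in buffer.split(b"\0"):
        if raw:  # skip the empty tail after the final delimiter
            yield json.loads(raw.decode("utf-8"))


if __name__ == "__main__":
    # One normal chunk followed by an error-style chunk; the text and code are
    # placeholders standing in for SERVER_ERROR_MSG / ErrorCode.INTERNAL_ERROR.
    sample = (
        json.dumps({"text": "Hello", "error_code": 0, "finish_reason": None}).encode()
        + b"\0"
        + json.dumps({"text": "(server error placeholder)", "error_code": 1}).encode()
        + b"\0"
    )
    for chunk in iter_worker_chunks(sample):
        if chunk.get("error_code", 0) != 0:
            print("worker error:", chunk["text"])
        else:
            print("partial output:", chunk["text"])
```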