From 482f03102f6e6666639b134a94482436af3ae2dd Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Fri, 28 Feb 2025 13:49:29 -0800 Subject: [PATCH 1/7] initial fix_tests fix max tokens docker docker fix fix fix fix fix fix fix fix fix docker fix docker fix docker fix docker fix docker fix revert port change revert port change --- README.md | 53 +- src/triton_cli/.gitignore | 3 + src/triton_cli/common.py | 2 +- src/triton_cli/repository.py | 34 +- src/triton_cli/server/server_factory.py | 5 +- src/triton_cli/server/server_utils.py | 83 +++ src/triton_cli/templates/llmapi/1/model.json | 28 + src/triton_cli/templates/llmapi/1/model.py | 703 +++++++++++++++++++ src/triton_cli/templates/llmapi/config.pbtxt | 34 + tests/conftest.py | 24 + tests/test_cli.py | 12 + tests/test_e2e.py | 53 ++ tests/test_model_repository.py | 20 +- 13 files changed, 1047 insertions(+), 7 deletions(-) create mode 100644 src/triton_cli/templates/llmapi/1/model.json create mode 100644 src/triton_cli/templates/llmapi/1/model.py create mode 100644 src/triton_cli/templates/llmapi/config.pbtxt diff --git a/README.md b/README.md index 9d47fd1..c0d7c96 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ Server. ## Table of Contents -| [Pre-requisites](#pre-requisites) | [Installation](#installation) | [Quickstart](#quickstart) | [Serving LLM Models](#serving-llm-models) | [Serving a vLLM Model](#serving-a-vllm-model) | [Serving a TRT-LLM Model](#serving-a-trt-llm-model) | [Serving a LLM model with OpenAI API](#serving-a-llm-model-with-openai-api) | [Additional Dependencies for Custom Environments](#additional-dependencies-for-custom-environments) | [Known Limitations](#known-limitations) | +| [Pre-requisites](#pre-requisites) | [Installation](#installation) | [Quickstart](#quickstart) | [Serving LLM Models](#serving-llm-models) | [Serving a vLLM Model](#serving-a-vllm-model) | [Serving a TRT-LLM Model](#serving-a-trt-llm-model) | [Serving a LLM model with OpenAI API](#serving-a-llm-model-with-openai-api) | [Serving a HuggingFace LLM Model with LLM API](#serving-a-huggingface-llm-model-with-llm-api) | [Additional Dependencies for Custom Environments](#additional-dependencies-for-custom-environments) | [Known Limitations](#known-limitations) | ## Pre-requisites @@ -351,7 +351,56 @@ triton start --frontend openai # Interact with model at http://localhost:9000 curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{ "model": "llama-3.1-8b-instruct", - "messages": [{"role": "user", "content": "What is machine learning?"}] + "messages": [{"role": "user", "content": "What is machine learning?"}], + "max_tokens": 256 +}' + +# Profile model with GenAI-Perf +triton profile -m llama-3.1-8b-instruct --service-kind openai --endpoint-type chat --url localhost:9000 --streaming +``` + +## Serving a HuggingFace LLM Model with LLM API + +> [!NOTE] +> LLM API has not yet been integrated into the official triton server tensorrt_llm backend image yet. +> To start the LLM API functionality, the user will only + +The LLM API is a high-level Python API and designed for Tensorrt LLM workflows. It could +convert a LLM model in Hugging Face format into a Tensorrt LLM engine and serve the engine with a unified Python API without invoking different +engine build and converting scripts. 
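+
+For reference, the snippet below is a minimal sketch of what the LLM API does under the hood; it is illustrative only (Triton CLI drives these calls for you), assumes `tensorrt_llm>=0.18.0` is installed, and uses the TinyLlama checkpoint from the template `model.json` as an example:
+
+```python
+# Illustrative sketch: calling the TensorRT-LLM LLM API directly, outside of Triton.
+from tensorrt_llm import SamplingParams
+from tensorrt_llm.llmapi import LLM
+
+# Any Hugging Face model id or local path works; TinyLlama is just a small example.
+llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+outputs = llm.generate(["What is machine learning?"], SamplingParams(max_tokens=64))
+print(outputs[0].outputs[0].text)
+llm.shutdown()
+```
+
+The `llmapi` template added by this change wires these same calls into a Triton Python backend model, so in practice you never need to call the API directly.
+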
+To use the LLM API with Triton CLI, import the model with `--backend llmapi` +```bash +export MODEL_NAME="llama-3.1-8b-instruct" +export HF_ID="meta-llama/Llama-3.1-8B-Instruct" +triton import -m $MODEL_NAME --source "hf:$HF_ID" --backend llmapi +``` + +Huggingface models will be downloaded at runtime when starting the LLM API engine if not found +locally in the HuggingFace cache. No offline engine building step is required, +but you can pre-download the model in advance to avoid downloading at server +startup time. tensorrt_llm>=0.18.0 is required. + +#### Example + +```bash +# Install the Triton CLI +pip install git+https://github.com/triton-inference-server/triton_cli.git@main + +# Authenticate with huggingface for restricted models like Llama-2 and Llama-3 +huggingface-cli login + +# Build TRT LLM engine and generate a Triton model repository pointing at it +triton remove -m all +triton import -m llama-3.1-8b-instruct --backend llmapi + +# Start Triton pointing at the default model repository +triton start --frontend openai --mode docker + +# Interact with model at http://localhost:9000 +curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{ + "model": "llama-3.1-8b-instruct", + "messages": [{"role": "user", "content": "What is machine learning?"}], + "max_tokens": 256 }' # Profile model with GenAI-Perf diff --git a/src/triton_cli/.gitignore b/src/triton_cli/.gitignore index 973a71d..d40f318 100644 --- a/src/triton_cli/.gitignore +++ b/src/triton_cli/.gitignore @@ -1,2 +1,5 @@ *.json *.cache + +# Except model.json from the llmapi template +!templates/llmapi/1/model.json diff --git a/src/triton_cli/common.py b/src/triton_cli/common.py index 8c126af..0f1f46a 100755 --- a/src/triton_cli/common.py +++ b/src/triton_cli/common.py @@ -55,4 +55,4 @@ class TritonCLIException(Exception): DEFAULT_MODEL_REPO: Path = Path.home() / "models" DEFAULT_HF_CACHE: Path = Path.home() / ".cache" / "huggingface" HF_CACHE: Path = Path(os.environ.get("HF_HOME", DEFAULT_HF_CACHE)) -SUPPORTED_BACKENDS: set = {"vllm", "tensorrtllm"} +SUPPORTED_BACKENDS: set = {"vllm", "tensorrtllm", "llmapi"} diff --git a/src/triton_cli/repository.py b/src/triton_cli/repository.py index 7728d55..254ecf3 100644 --- a/src/triton_cli/repository.py +++ b/src/triton_cli/repository.py @@ -67,7 +67,9 @@ SOURCE_PREFIX_NGC = "ngc:" SOURCE_PREFIX_LOCAL = "local:" -TRT_TEMPLATES_PATH = Path(__file__).parent / "templates" / "trt_llm" +TEMPLATES_PATH = Path(__file__).parent / "templates" +TRTLLM_TEMPLATES_PATH = TEMPLATES_PATH / "trt_llm" +LLMAPI_TEMPLATES_PATH = TEMPLATES_PATH / "llmapi" # Support changing destination dynamically to point at # pre-downloaded checkpoints in various circumstances @@ -266,6 +268,8 @@ def __add_huggingface_model( self.remove(model, verbose=False) # Let detailed traceback be reported for TRT-LLM errors for debugging raise e + elif backend == "llmapi": + self.__generate_llmapi_model(version_dir, huggingface_id) else: # TODO: Add generic support for HuggingFace models with HF API. 
# For now, use vLLM as a means of deploying HuggingFace Transformers @@ -322,6 +326,21 @@ def __generate_vllm_model(self, huggingface_id: str): model_files = {"model.json": model_contents} return model_config, model_files + def __generate_llmapi_model(self, version_dir, huggingface_id: str): + # load the model.json from llmapi template + model_config_file = version_dir / "model.json" + with open(model_config_file) as f: + model_config_str = f.read() + + model_config_json = json.loads(model_config_str) + + # change the model id as the huggingface_id + model_config_json["model"] = huggingface_id + + # write back the model.json + with open(model_config_file, "w") as f: + f.write(json.dumps(model_config_json)) + def __generate_ngc_model(self, name: str, source: str): engines_path = ENGINE_DEST_PATH + "/" + source parse_and_substitute( @@ -392,7 +411,7 @@ def __create_model_repository( ) shutil.copytree( - TRT_TEMPLATES_PATH, + TRTLLM_TEMPLATES_PATH, self.repo, dirs_exist_ok=True, ignore=shutil.ignore_patterns("__pycache__"), @@ -402,6 +421,17 @@ def __create_model_repository( logger.debug(f"Adding TensorRT-LLM models at: {self.repo}") else: version_dir.mkdir(parents=True, exist_ok=False) + if backend == "llmapi": + shutil.copytree( + LLMAPI_TEMPLATES_PATH / "1", + version_dir, + dirs_exist_ok=True, + ignore=shutil.ignore_patterns("__pycache__"), + ) + shutil.copy( + LLMAPI_TEMPLATES_PATH / "config.pbtxt", + model_dir, + ) logger.debug(f"Adding new model to repo at: {version_dir}") except FileExistsError: logger.warning(f"Overwriting existing model in repo at: {version_dir}") diff --git a/src/triton_cli/server/server_factory.py b/src/triton_cli/server/server_factory.py index 921f5f7..c42df23 100755 --- a/src/triton_cli/server/server_factory.py +++ b/src/triton_cli/server/server_factory.py @@ -25,7 +25,7 @@ LOGGER_NAME, TritonCLIException, ) -from .server_utils import TRTLLMUtils, VLLMUtils +from .server_utils import TRTLLMUtils, VLLMUtils, LLMAPIUtils logger = logging.getLogger(LOGGER_NAME) @@ -198,11 +198,14 @@ def _get_openai_chat_template_tokenizer(config): ) trtllm_utils = TRTLLMUtils(config.model_repository) vllm_utils = VLLMUtils(config.model_repository) + llmapi_utils = LLMAPIUtils(config.model_repository) if trtllm_utils.has_trtllm_model(): tokenizer_path = trtllm_utils.get_engine_path() elif vllm_utils.has_vllm_model(): tokenizer_path = vllm_utils.get_vllm_model_huggingface_id_or_path() + elif llmapi_utils.has_llmapi_model(): + tokenizer_path = llmapi_utils.get_llmapi_model_huggingface_id_or_path() else: raise TritonCLIException( "Unable to find a tokenizer to start the Triton OpenAI RESTful API, please use '--openai-chat-template-tokenizer' to specify a tokenizer." diff --git a/src/triton_cli/server/server_utils.py b/src/triton_cli/server/server_utils.py index 98823fe..97d7477 100755 --- a/src/triton_cli/server/server_utils.py +++ b/src/triton_cli/server/server_utils.py @@ -304,3 +304,86 @@ def _find_vllm_model_huggingface_id_or_path(self) -> str: return model_id except OSError: raise Exception(f"Unable to open {model_config_json_file}") + + +class LLMAPIUtils: + """ + A utility class for handling LLMAPI specific models. 
+ """ + + def __init__(self, model_path: Path): + self._model_repo_path = model_path + self._llmapi_model_path = self._find_llmapi_model_path() + self._is_llmapi_model = self._llmapi_model_path is not None + + def has_llmapi_model(self) -> bool: + """ + Returns + ------- + A boolean indicating whether a LLMAPI model exists in the model repo + """ + return self._is_llmapi_model + + def get_llmapi_model_huggingface_id_or_path(self) -> str: + """ + Returns + ------- + The LLMAPI model's Huggingface Id or path + """ + return self._find_llmapi_model_huggingface_id_or_path() + + def _find_llmapi_model_path(self) -> Path: + """ + Returns + ------- + A pathlib.Path object containing the path to the LLMAPI model folder. + Assumptions + ---------- + - Assumes only a single model uses the LLMAPI backend (could have multiple models) + """ + # Search the llmapi model from all models in model repository + model_dirs = [ + model_dir + for model_dir in self._model_repo_path.iterdir() + if model_dir.is_dir() + ] + for model_dir in model_dirs: + model_config_file = Path(self._model_repo_path) / model_dir / "config.pbtxt" + model_json_path = model_config_file.parent / "1" / "model.json" + # check if config.pbtxt exists + if model_config_file.is_file(): + # read the config.pbtxt file and identify the backend + with open(model_config_file) as config_file: + config = text_format.Parse(config_file.read(), mc.ModelConfig()) + json_config = json.loads( + json_format.MessageToJson( + config, preserving_proto_field_name=True + ) + ) + # check if the model.json also exists. + if json_config["backend"] == "python" and model_json_path.is_file(): + return model_config_file.parent + + return None + + def _find_llmapi_model_huggingface_id_or_path(self) -> str: + """ + Returns + ------- + The llmapi model's Huggingface Id or path + """ + assert self._is_llmapi_model, "model Huggingface Id or path cannot be parsed from a model repository that does not contain a LLMAPI model." 
+ try: + # assume the version is always "1" + model_version_path = self._llmapi_model_path / "1" + model_config_json_file = model_version_path / "model.json" + with open(model_config_json_file) as json_data: + data = json.load(json_data) + model_id = data.get("model") + if not model_id: + raise Exception( + f"Unable to parse config from {model_config_json_file}" + ) + return model_id + except OSError: + raise Exception(f"Unable to open {model_config_json_file}") diff --git a/src/triton_cli/templates/llmapi/1/model.json b/src/triton_cli/templates/llmapi/1/model.json new file mode 100644 index 0000000..d659012 --- /dev/null +++ b/src/triton_cli/templates/llmapi/1/model.json @@ -0,0 +1,28 @@ +{ + "max_batch_size": 64, + "decoupled": true, + + "model":"TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "tokenizer": null, + "tokenizer_mode": null, + "skip_tokenizer_init": null, + "trust_remote_code": null, + "tensor_parallel_size": null, + "pipeline_parallel_size": null, + "dtype": null, + "revision": null, + "tokenizer_revision": null, + "speculative_model": null, + "enable_chunked_prefill": null, + "num_instances": null, + + "use_cuda_graph": null, + "cuda_graph_batch_sizes": null, + "cuda_graph_max_batch_size": null, + "cuda_graph_padding_enabled": null, + "enable_overlap_scheduler": null, + "kv_cache_dtype": null, + "torch_compile_enabled": null, + "torch_compile_fullgraph": null, + "torch_compile_inductor_enabled": null +} diff --git a/src/triton_cli/templates/llmapi/1/model.py b/src/triton_cli/templates/llmapi/1/model.py new file mode 100644 index 0000000..85425a3 --- /dev/null +++ b/src/triton_cli/templates/llmapi/1/model.py @@ -0,0 +1,703 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
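+
+# Triton Python backend model that serves a Hugging Face checkpoint through the
+# TensorRT-LLM LLM API. Engine and runtime options are read from the model.json
+# file shipped alongside this script (_TRTLLM_ENGINE_ARGS_FILENAME below).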
+ +import asyncio +import gc +import json +import os +import queue +import threading +from contextlib import asynccontextmanager + +import numpy as np +import triton_python_backend_utils as pb_utils +from mpi4py.futures import MPICommExecutor +from mpi4py.MPI import COMM_WORLD + +from tensorrt_llm import SamplingParams +from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig +from tensorrt_llm._utils import global_mpi_rank +from tensorrt_llm.llmapi import LLM + +_TRTLLM_ENGINE_ARGS_FILENAME = "model.json" + + +class TritonPythonModel: + # Define the expected keys for each config + # TODO: Add more keys as needed + PYTORCH_CONFIG_KEYS = { + "use_cuda_graph", + "cuda_graph_batch_sizes", + "cuda_graph_max_batch_size", + "cuda_graph_padding_enabled", + "enable_overlap_scheduler", + "kv_cache_dtype", + "torch_compile_enabled", + "torch_compile_fullgraph", + "torch_compile_inductor_enabled", + } + + LLM_ENGINE_KEYS = { + "model", + "tokenizer", + "tokenizer_mode", + "skip_tokenizer_init", + "trust_remote_code", + "tensor_parallel_size", + "pipeline_parallel_size", + "dtype", + "revision", + "tokenizer_revision", + "speculative_model", + "enable_chunked_prefill", + } + + def _get_input_scalar_by_name(self, request, name): + tensor = pb_utils.get_input_tensor_by_name(request, name) + if tensor is None: + return None + + tensor = tensor.as_numpy() + if tensor.size == 0: + return None + + return tensor.item(0) + + def _get_string_list_by_name(self, request, name): + tensor = pb_utils.get_input_tensor_by_name(request, name) + if tensor is None: + return None + + tensor = tensor.as_numpy() + if tensor.size == 0: + return None + + # Convert to list and handle bytes conversion + if isinstance(tensor, np.ndarray): + if tensor.ndim == 0: + item = tensor.item() + return [item.decode("utf-8") if isinstance(item, bytes) else str(item)] + + return [ + item.decode("utf-8") if isinstance(item, bytes) else str(item) + for item in tensor.flatten() + ] + + # Fallback case + if isinstance(tensor, bytes): + return [tensor.decode("utf-8")] + return [str(tensor)] + + def _get_sampling_config_from_request(self, request): + # TODO: Add more sampling parameters as needed + kwargs = { + "beam_width": self._get_input_scalar_by_name(request, "beam_width") or 1, + "temperature": self._get_input_scalar_by_name(request, "temperature"), + "top_k": self._get_input_scalar_by_name(request, "top_k"), + "top_p": self._get_input_scalar_by_name(request, "top_p"), + "frequency_penalty": self._get_input_scalar_by_name( + request, "frequency_penalty" + ), + "presence_penalty": self._get_input_scalar_by_name( + request, "presence_penalty" + ), + "max_tokens": self._get_input_scalar_by_name(request, "max_tokens"), + # stop_words is deprecated. Should use stop instead. + "stop": ( + self._get_string_list_by_name(request, "stop") + or self._get_string_list_by_name(request, "stop_words") + ), + # random_seed is deprecated. Should use seed instead. + "seed": ( + self._get_input_scalar_by_name(request, "seed") + or self._get_input_scalar_by_name(request, "random_seed") + ), + } + + # Adjust top_p if it's not valid + kwargs["top_p"] = ( + None if kwargs["top_p"] is None or kwargs["top_p"] <= 0 else kwargs["top_p"] + ) + + return kwargs + + @classmethod + def auto_complete_config(cls, auto_complete_model_config): + # Add inputs/outputs to the model config. + cls._auto_complete_inputs_and_outputs(auto_complete_model_config) + + # Get the max batch size and decoupled model transaction policy from the json file. 
+ engine_args_filepath = os.path.join( + pb_utils.get_model_dir(), _TRTLLM_ENGINE_ARGS_FILENAME + ) + assert os.path.isfile( + engine_args_filepath + ), f"'{_TRTLLM_ENGINE_ARGS_FILENAME}' containing TRT-LLM engine args must be provided in '{pb_utils.get_model_dir()}'" + with open(engine_args_filepath) as file: + # The Python interpreter used to invoke this function will be destroyed upon returning from this function and as a result none of the objects created here will be available in the initialize, execute, or finalize functions. + trtllm_engine_config = json.load(file) + + model_config_keys = {"max_batch_size", "decoupled"} + auto_complete_config = { + k: v for k, v in trtllm_engine_config.items() if k in model_config_keys + } + + # Set the max batch size and decoupled model transaction policy in the model config. + is_decoupled = auto_complete_config.get("decoupled", False) + auto_complete_model_config.set_model_transaction_policy( + dict(decoupled=is_decoupled) + ) + max_batch_size = auto_complete_config.get("max_batch_size", 64) + auto_complete_model_config.set_max_batch_size(int(max_batch_size)) + + return auto_complete_model_config + + @staticmethod + def _auto_complete_inputs_and_outputs(auto_complete_model_config): + # Inputs expected by the backend. + inputs = [ + {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]}, + { + "name": "stream", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "exclude_input_in_output", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "return_finish_reason", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "return_stop_reason", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "temperature", + "data_type": "TYPE_FP32", + "dims": [1], + "optional": True, + }, + { + "name": "beam_width", + "data_type": "TYPE_INT32", + "dims": [1], + "optional": True, + }, + { + "name": "top_k", + "data_type": "TYPE_INT32", + "dims": [1], + "optional": True, + }, + { + "name": "top_p", + "data_type": "TYPE_FP32", + "dims": [1], + "optional": True, + }, + { + "name": "frequency_penalty", + "data_type": "TYPE_FP32", + "dims": [1], + "optional": True, + }, + { + "name": "presence_penalty", + "data_type": "TYPE_FP32", + "dims": [1], + "optional": True, + }, + { + "name": "max_tokens", + "data_type": "TYPE_INT32", + "dims": [1], + }, + { + "name": "stop", + "data_type": "TYPE_STRING", + "dims": [-1], + "optional": True, + }, + { + # stop_words is deprecated. Should use stop instead. + "name": "stop_words", + "data_type": "TYPE_STRING", + "dims": [-1], + "optional": True, + }, + { + "name": "seed", + "data_type": "TYPE_UINT64", + "dims": [1], + "optional": True, + }, + { + # random_seed is deprecated. Should use seed instead. + "name": "random_seed", + "data_type": "TYPE_UINT64", + "dims": [1], + "optional": True, + }, + ] + # Outputs expected by the backend. + outputs = [ + {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}, + {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]}, + {"name": "stop_reason", "data_type": "TYPE_STRING", "dims": [-1]}, + {"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]}, + ] + + # Collect input and output names from the provided model config. 
+ config = auto_complete_model_config.as_dict() + input_names = [] + output_names = [] + for input in config["input"]: + input_names.append(input["name"]) + for output in config["output"]: + output_names.append(output["name"]) + + # Add missing inputs and outputs to the model config. + for input in inputs: + if input["name"] not in input_names: + auto_complete_model_config.add_input(input) + for output in outputs: + if output["name"] not in output_names: + auto_complete_model_config.add_output(output) + + def initialize(self, args): + self.model_config = json.loads(args["model_config"]) + self.decoupled = pb_utils.using_decoupled_model_transaction_policy( + self.model_config + ) + self.params = self.model_config["parameters"] + self.logger = pb_utils.Logger + + output_config = pb_utils.get_output_config_by_name( + self.model_config, "text_output" + ) + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + if global_mpi_rank() == 0: + # Initialize engine arguments + self._init_engine_args() + self.logger.log_info( + f"[trtllm] rank{global_mpi_rank()} is starting trtllm engine with args: {self.llm_engine_args}" + ) + + # Starting the TRT-LLM engine with LLM API and its event thread running the AsyncIO event loop. + self._init_engine() + + # Starting the response thread. It allows TRT-LLM to keep making progress while + # response sender(s) are sending responses to server frontend. + self._response_queue = queue.Queue() + self._response_thread = threading.Thread(target=self._response_loop) + self._response_thread.start() + else: + self.logger.log_info( + f"[trtllm] rank{global_mpi_rank()} is waiting for the leader node..." + ) + with MPICommExecutor(COMM_WORLD) as executor: + if executor is not None: + raise RuntimeError( + f"[trtllm] rank{COMM_WORLD.rank} should not have executor" + ) + return + + def _get_llm_args(self, args_dict): + pytorch_config_args = { + k: v + for k, v in args_dict.items() + if k in self.PYTORCH_CONFIG_KEYS and v is not None + } + llm_engine_args = { + k: v + for k, v in args_dict.items() + if k in self.LLM_ENGINE_KEYS and v is not None + } + if "model" not in llm_engine_args: + raise pb_utils.TritonModelException( + "Model name is required in the TRT-LLM engine config." + ) + + return pytorch_config_args, llm_engine_args + + def _init_engine_args(self): + """Initialize engine arguments from config file.""" + engine_args_filepath = os.path.join( + pb_utils.get_model_dir(), _TRTLLM_ENGINE_ARGS_FILENAME + ) + if not os.path.isfile(engine_args_filepath): + raise pb_utils.TritonModelException( + f"'{_TRTLLM_ENGINE_ARGS_FILENAME}' containing TRT-LLM engine args must be provided in '{pb_utils.get_model_dir()}'" + ) + + try: + with open(engine_args_filepath) as file: + self.trtllm_engine_config = json.load(file) + except json.JSONDecodeError as e: + raise pb_utils.TritonModelException(f"Failed to parse engine config: {e}") + + self.pytorch_config_args, self.llm_engine_args = self._get_llm_args( + self.trtllm_engine_config + ) + + def _init_engine(self): + # Run the engine in a separate thread running the AsyncIO event loop. 
+ self._llm_engine = None + self._llm_engine_start_cv = threading.Condition() + self._llm_engine_shutdown_event = asyncio.Event() + self._event_thread = threading.Thread( + target=asyncio.run, args=(self._run_llm_engine(),) + ) + self._event_thread.start() + with self._llm_engine_start_cv: + while self._llm_engine is None: + self._llm_engine_start_cv.wait() + + # The 'threading.Thread()' will not raise the exception here should the engine + # failed to start, so the exception is passed back via the engine variable. + if isinstance(self._llm_engine, Exception): + e = self._llm_engine + self.logger.log_error(f"[trtllm] Failed to start engine: {e}") + if self._event_thread is not None: + self._event_thread.join() + self._event_thread = None + raise e + + async def _run_llm_engine(self): + # Counter to keep track of ongoing request counts. + self._ongoing_request_count = 0 + + @asynccontextmanager + async def async_llm_wrapper(): + # Create LLM in a thread to avoid blocking + loop = asyncio.get_running_loop() + try: + pytorch_config = PyTorchConfig(**self.pytorch_config_args) + llm = await loop.run_in_executor( + None, + lambda: LLM( + **self.llm_engine_args, + backend="pytorch", + pytorch_backend_config=pytorch_config, + ), + ) + yield llm + finally: + if "llm" in locals(): + # Run shutdown in a thread to avoid blocking + await loop.run_in_executor(None, llm.shutdown) + + try: + async with async_llm_wrapper() as engine: + # Capture the engine event loop and make it visible to other threads. + self._event_loop = asyncio.get_running_loop() + + # Signal the engine is started and make it visible to other threads. + with self._llm_engine_start_cv: + self._llm_engine = engine + self._llm_engine_start_cv.notify_all() + + # Wait for the engine shutdown signal. + await self._llm_engine_shutdown_event.wait() + + # Wait for the ongoing requests to complete. + while self._ongoing_request_count > 0: + self.logger.log_info( + "[trtllm] Awaiting remaining {} requests".format( + self._ongoing_request_count + ) + ) + await asyncio.sleep(1) + + # Cancel all tasks in the event loop. + for task in asyncio.all_tasks(loop=self._event_loop): + if task is not asyncio.current_task(): + task.cancel() + + except Exception as e: + # Signal and pass the exception back via the engine variable if the engine + # failed to start. If the engine has started, re-raise the exception. + with self._llm_engine_start_cv: + if self._llm_engine is None: + self._llm_engine = e + self._llm_engine_start_cv.notify_all() + return + raise e + + self._llm_engine = None + self.logger.log_info("[trtllm] Shutdown complete") + + def _response_loop(self): + while True: + item = self._response_queue.get() + # To signal shutdown a None item will be added to the queue. + if item is None: + break + response_state, response, response_flag = item + response_sender = response_state["response_sender"] + try: + response_sender.send(response, response_flag) + # Stop checking for cancellation if the last response is generated. + if not response_state["last_response_generated"]: + response_state["is_cancelled"] = response_sender.is_cancelled() + except Exception as e: + self.logger.log_error( + f"An error occurred while sending a response: {e}" + ) + finally: + if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL: + self._ongoing_request_count -= 1 + + def execute(self, requests): + # TODO: Add health check here? 
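+        # Requests are handled asynchronously: each one is scheduled as a coroutine
+        # on the engine's event loop, and streaming responses are relayed back to
+        # the client by the response thread via self._response_queue.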
+ for request in requests: + # TODO : Verify Lora + if request is not None: + assert ( + self._llm_engine_shutdown_event.is_set() is False + ), "Cannot create tasks after shutdown has been requested" + coro = self._generate(request) + asyncio.run_coroutine_threadsafe(coro, self._event_loop) + + return None + + async def _generate(self, request): + response_sender = request.get_response_sender() + response_state = { + "response_sender": response_sender, + "is_cancelled": False, + "last_response_generated": False, # last response ready but not yet sent + } + self._ongoing_request_count += 1 + decrement_ongoing_request_count = True + try: + ( + prompt, + stream, + prepend_input, + sampling_config, + additional_outputs, + ) = self._get_input_tensors(request) + + sampling_params = SamplingParams(**sampling_config) + + # Generate the response. + response_iterator = self._llm_engine.generate_async( + prompt, sampling_params, streaming=stream + ) + + request_output_state = {} + async for request_output in response_iterator: + # TODO: Add request cancellation check here + # Send each response if streaming. + if stream: + response = self._create_response( + request_output_state, + request_output, + prepend_input=False, + additional_outputs=additional_outputs, + ) + flags = 0 + if request_output.finished: + response_state["last_response_generated"] = True + flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + decrement_ongoing_request_count = False + self._response_queue.put_nowait((response_state, response, flags)) + + # Send the last response which contains all the outputs if not streaming. + if not stream: + response_sender.send( + self._create_response( + request_output_state={}, + request_output=request_output, + prepend_input=prepend_input, + additional_outputs=additional_outputs, + ), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, + ) + + except Exception as e: + self.logger.log_error(f"[trtllm] Error generating stream: {e}") + error = pb_utils.TritonError(f"Error generating stream: {e}") + text_output_tensor = pb_utils.Tensor( + "text_output", np.asarray(["N/A"], dtype=self.output_dtype) + ) + response = pb_utils.InferenceResponse( + output_tensors=[text_output_tensor], error=error + ) + response_sender.send( + response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + ) + raise e + + finally: + if decrement_ongoing_request_count: + self._ongoing_request_count -= 1 + + def _get_input_tensors(self, request): + # Parse the prompt based on the batch size. + text_input = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy() + prompt = ( + text_input[0][0] + if self.model_config["max_batch_size"] > 0 + else text_input[0] + ) + + if isinstance(prompt, bytes): + prompt = prompt.decode("utf-8") + + # stream + stream = pb_utils.get_input_tensor_by_name(request, "stream") + if stream and not self.decoupled: + raise pb_utils.TritonModelException( + "Streaming is only supported in decoupled mode." + ) + if stream: + stream = stream.as_numpy()[0] + else: + stream = False + + # prepend_input / exclude_input_in_output + prepend_input = pb_utils.get_input_tensor_by_name( + request, "exclude_input_in_output" + ) + if prepend_input: + # When `exclude_input_in_output` is False, we want to prepend input prompt + # to output, thus prepend_input should be True, and vice versa. 
+ prepend_input = not prepend_input.as_numpy()[0] + elif prepend_input is None and stream: + prepend_input = False + else: + # Default to False if not specified + prepend_input = False + if prepend_input and stream: + raise pb_utils.TritonModelException( + "When streaming, `exclude_input_in_output` = False is not allowed." + ) + + # Sampling parameters + sampling_config = self._get_sampling_config_from_request(request) + + # additional outputs + additional_outputs = { + "return_finish_reason": None, + "return_stop_reason": None, + } + for tensor_name in additional_outputs.keys(): + tensor = pb_utils.get_input_tensor_by_name(request, tensor_name) + if tensor: + tensor = bool(tensor.as_numpy()[0]) + else: + tensor = False + additional_outputs[tensor_name] = tensor + + return prompt, stream, prepend_input, sampling_config, additional_outputs + + def _create_response( + self, request_output_state, request_output, prepend_input, additional_outputs + ): + # TODO: Check if request_output has_error and handle it + output_tensors = [] + + # text_output + prepend_prompt = "" + if "prev_lens_text_output" not in request_output_state: + # this is the first response + if prepend_input: + prepend_prompt = request_output.prompt + request_output_state["prev_lens_text_output"] = [0] * len( + request_output.outputs + ) + prev_lens = request_output_state["prev_lens_text_output"] + text_output = [ + (prepend_prompt + output.text[prev_len:]).encode("utf-8") + for output, prev_len in zip(request_output.outputs, prev_lens) + ] + request_output_state["prev_lens_text_output"] = [ + len(output.text) for output in request_output.outputs + ] + + # finish_reason + if additional_outputs["return_finish_reason"]: + finish_reason = [ + str(output.finish_reason) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "finish_reason", np.asarray(finish_reason, dtype=np.object_) + ) + ) + + # stop_reason + if additional_outputs["return_stop_reason"]: + stop_reason = [ + str(output.finish_reason) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "stop_reason", np.asarray(stop_reason, dtype=np.object_) + ) + ) + + output_tensors.append( + pb_utils.Tensor( + "text_output", np.asarray(text_output, dtype=self.output_dtype) + ) + ) + + return pb_utils.InferenceResponse(output_tensors=output_tensors) + + def finalize(self): + self.logger.log_info("[trtllm] Issuing finalize to trtllm backend") + self._event_loop.call_soon_threadsafe(self._llm_engine_shutdown_event.set) + + # Shutdown the event thread. + if self._event_thread is not None: + self._event_thread.join() + self._event_thread = None + + # # Shutdown the response thread. + self._response_queue.put(None) + if self._response_thread is not None: + self._response_thread.join() + self._response_thread = None + + # When using parallel tensors, the stub process may not shutdown due to + # unreleased references, so manually run the garbage collector once. + self.logger.log_info("[trtllm] Running Garbage Collector on finalize...") + gc.collect() + self.logger.log_info("[trtllm] Garbage Collector on finalize... done") diff --git a/src/triton_cli/templates/llmapi/config.pbtxt b/src/triton_cli/templates/llmapi/config.pbtxt new file mode 100644 index 0000000..65873aa --- /dev/null +++ b/src/triton_cli/templates/llmapi/config.pbtxt @@ -0,0 +1,34 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +backend: "python" + +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] diff --git a/tests/conftest.py b/tests/conftest.py index 7527be6..3f4ca44 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,6 +63,18 @@ def vllm_server(): server.stop() +@pytest.fixture(scope="function") +def llmapi_server(): + llm_repo = None + + # llmapi models might need to be downloaded on the fly during model loading, + # so leave longer timeout in case of slow network. + server = ScopedTritonServer(repo=llm_repo, timeout=1800) + yield server + # Ensure server is cleaned up after each test + server.stop() + + @pytest.fixture(scope="function") def trtllm_openai_server(): llm_repo = None @@ -92,6 +104,18 @@ def vllm_openai_server(): server.stop() +@pytest.fixture(scope="function") +def llmapi_openai_server(): + llm_repo = None + + # llmapi models might need to be downloaded on the fly during model loading, + # so leave longer timeout in case of slow network. 
+ server = ScopedTritonServer(repo=llm_repo, timeout=1800, frontend="openai") + yield server + # Ensure server is cleaned up after each test + server.stop() + + @pytest.fixture(scope="function") def simple_server(): test_dir = os.path.dirname(os.path.realpath(__file__)) diff --git a/tests/test_cli.py b/tests/test_cli.py index 78abc37..14625e1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -38,6 +38,8 @@ CUSTOM_TRTLLM_MODEL_SOURCES = [("trtllm-model", "hf:gpt2")] +CUSTOM_LLMAPI_MODEL_SOURCES = [("llmapi-model", "hf:gpt2")] + # TODO: Add public NGC model for testing CUSTOM_NGC_MODEL_SOURCES = [("my-llm", "ngc:does-not-exist")] @@ -96,6 +98,16 @@ def test_repo_add_trtllm_build(self, model, source): TritonCommands._import(model, source=source, backend="tensorrtllm") TritonCommands._clear() + @pytest.mark.skipif( + os.environ.get("IMAGE_KIND") != "TRTLLM", reason="Only run for TRTLLM image" + ) + @pytest.mark.parametrize("model,source", CUSTOM_TRTLLM_MODEL_SOURCES) + def test_repo_add_llmapi_build(self, model, source): + # TODO: Parse repo to find TRT-LLM models and backend in config + TritonCommands._clear() + TritonCommands._import(model, source=source, backend="llmapi") + TritonCommands._clear() + @pytest.mark.skip(reason="Pre-built TRT-LLM engines not available") def test_import_trtllm_prebuilt(self, model, source): # TODO: Parse repo to find TRT-LLM models and backend in config diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 2a491ca..88f0e70 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -145,6 +145,59 @@ def test_vllm_openai_e2e(self, vllm_openai_server): model, service_kind="openai", endpoint_type="chat", url="localhost:9000" ) + @pytest.mark.skipif( + os.environ.get("IMAGE_KIND") != "TRTLLM", + reason="Only run for TRT-LLM image with LLM-API", + ) + @pytest.mark.parametrize( + "protocol", + [ + "grpc", + pytest.param( + "http", + # NOTE: skip because xfail was causing server to not get cleaned up by test in background + marks=pytest.mark.skip( + reason="http not supported decoupled models and model profiling yet" + ), + ), + ], + ) + @pytest.mark.timeout(LLM_TIMEOUT_SECS) + def test_llmapi_e2e(self, llmapi_server, protocol): + # NOTE: llmapi "backend" is using the same api format as tensorrtllm(tensorrt_llm_bls) backend, + # Use the tensorrtllm backend as the profiling option. + model = os.environ.get("TRTLLM_MODEL") + assert model is not None, "TRTLLM_MODEL env var must be set!" + # Source is optional if using a "known: model" + source = os.environ.get("MODEL_SOURCE") + TritonCommands._clear() + TritonCommands._import(model, source=source, backend="llmapi") + llmapi_server.start() + TritonCommands._infer(model, prompt=PROMPT, protocol=protocol) + # TODO: update this to llmapi when genai-perf supports the llmapi backend's api format + TritonCommands._profile(model, backend="tensorrtllm") + + @pytest.mark.skipif( + os.environ.get("IMAGE_KIND") != "TRTLLM", + reason="Only run for TRT-LLM image with LLM-API", + ) + @pytest.mark.skipif( + os.environ.get("TRTLLM_MODEL") == "gpt2", + reason="gpt2's tokenizer doesn't have a chat template defined", + ) + @pytest.mark.timeout(LLM_TIMEOUT_SECS) + def test_llmapi_openai_e2e(self, vllm_openai_server): + model = os.environ.get("TRTLLM_MODEL") + assert model is not None, "TRTLLM_MODEL env var must be set!" 
+ # Source is optional if using a "known: model" + source = os.environ.get("MODEL_SOURCE") + TritonCommands._clear() + TritonCommands._import(model, source=source, backend="llmapi") + vllm_openai_server.start() + TritonCommands._profile( + model, service_kind="openai", endpoint_type="chat", url="localhost:9000" + ) + @pytest.mark.skipif( os.environ.get("CI_PIPELINE") == "GITHUB_ACTIONS", reason="bandage/temporary fix", diff --git a/tests/test_model_repository.py b/tests/test_model_repository.py index 736b701..e18efd3 100644 --- a/tests/test_model_repository.py +++ b/tests/test_model_repository.py @@ -27,7 +27,7 @@ import os import pytest from utils import TritonCommands -from triton_cli.server.server_utils import TRTLLMUtils, VLLMUtils +from triton_cli.server.server_utils import TRTLLMUtils, VLLMUtils, LLMAPIUtils from triton_cli.common import DEFAULT_MODEL_REPO # Give ample 30min timeout for tests that download models from huggingface @@ -42,8 +42,10 @@ def setup_method(self): def test_can_not_find_models(self): trtllm_utils = TRTLLMUtils(DEFAULT_MODEL_REPO) vllm_utils = VLLMUtils(DEFAULT_MODEL_REPO) + llmapi_utils = LLMAPIUtils(DEFAULT_MODEL_REPO) assert not trtllm_utils.has_trtllm_model(), f"tensorrtllm model found in model repository: '{DEFAULT_MODEL_REPO}', but the test expect the tensorrtllm model not found" assert not vllm_utils.has_vllm_model(), f"vllm model found in model repository: '{DEFAULT_MODEL_REPO}', but the test expect the vllm model not found" + assert not llmapi_utils.has_llmapi_model(), f"llmapi model found in model repository: '{DEFAULT_MODEL_REPO}', but the test expect the llmapi model not found" @pytest.mark.skipif( os.environ.get("IMAGE_KIND") != "TRTLLM", reason="Only run for TRT-LLM image" @@ -61,6 +63,22 @@ def test_can_get_tensorrtllm_engine_folder_from_model_repository(self): expected_engine_path == trtllm_utils.get_engine_path() ), f"engine path found is not as expected. Expected: {expected_engine_path}. Found: {trtllm_utils.get_engine_path()}" + @pytest.mark.skipif( + os.environ.get("IMAGE_KIND") != "TRTLLM", reason="Only run for VLLM image" + ) + @pytest.mark.timeout(DOWNLOAD_TIMEOUT_SECS) + def test_can_get_llmapi_model_id_from_model_repository(self): + model = "llama-2-7b-chat" + expected_model_id = "meta-llama/Llama-2-7b-chat-hf" + TritonCommands._import(model, backend="llmapi") + llmapi_utils = LLMAPIUtils(DEFAULT_MODEL_REPO) + assert ( + llmapi_utils.has_llmapi_model() + ), f"no LLM API model found in model repository: '{DEFAULT_MODEL_REPO}'." + assert ( + expected_model_id == llmapi_utils.get_llmapi_model_huggingface_id_or_path() + ), f"model id found is not as expected. Expected: {expected_model_id}. 
Found: {llmapi_utils.get_llmapi_model_huggingface_id_or_path()}" + @pytest.mark.skipif( os.environ.get("IMAGE_KIND") != "VLLM", reason="Only run for VLLM image" ) From 8a1cfad47d7c22cdd7862810689b19f1b34f1d23 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Mon, 3 Mar 2025 16:03:55 -0800 Subject: [PATCH 2/7] update templates --- .../templates/trt_llm/tensorrt_llm/1/model.py | 142 +++++++++++++++++- .../trt_llm/tensorrt_llm/config.pbtxt | 88 ++++++++++- .../trt_llm/tensorrt_llm_bls/config.pbtxt | 39 ++++- 3 files changed, 261 insertions(+), 8 deletions(-) diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py b/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py index 6be7daa..aaa7fcd 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py @@ -9,6 +9,7 @@ from typing import Any, List import numpy as np +import pandas as pd import torch import triton_python_backend_utils as pb_utils from torch import from_numpy @@ -239,7 +240,12 @@ def get_output_config_from_request(request, batch_size=1, batch_index=0): kwargs["return_generation_logits"] = get_input_scalar_by_name( request, 'return_generation_logits', batch_size, batch_index) kwargs["return_perf_metrics"] = get_input_scalar_by_name( - request, 'return_kv_cache_reuse_stats', batch_size, batch_index) + request, 'return_perf_metrics', batch_size, batch_index) + if get_input_scalar_by_name(request, 'return_kv_cache_reuse_stats', + batch_size, batch_index): + pb_utils.Logger.log_warn( + "return_kv_cache_reuse_stats is deprecated, please use return_perf_metrics instead." + ) kwargs = {k: v for k, v in kwargs.items() if v is not None} return trtllm.OutputConfig(**kwargs) @@ -427,6 +433,39 @@ def get_tensor_and_check_length(name: str, expected_length: int): return None +def get_lookahead_decoding_config_from_request(request, + executor_lookahead_config, + batch_size=1, + batch_index=0): + lookahead_window_size = get_input_tensor_by_name(request, + "lookahead_window_size", + batch_size, batch_index) + + lookahead_ngram_size = get_input_tensor_by_name(request, + "lookahead_ngram_size", + batch_size, batch_index) + + lookahead_verification_set_size = get_input_tensor_by_name( + request, "lookahead_verification_set_size", batch_size, batch_index) + + # None lookahead config for requests. + if all(x is None for x in [ + lookahead_window_size, lookahead_ngram_size, + lookahead_verification_set_size + ]): + return None + + # Have request lookahead config but no executor config. 
+ if executor_lookahead_config is None: + raise RuntimeError( + "The request lookahead decoding input tensors (window_size, ngram_size and verification_set_size) can only be set if the model instance lookahead parameters are also specified" + ) + + return trtllm.LookaheadDecodingConfig(lookahead_window_size, + lookahead_ngram_size, + lookahead_verification_set_size) + + def build_1_2_5_buckets(max_value: int) -> List[int]: """ Builds a list of buckets with increasing powers of 10 multiplied by @@ -450,7 +489,10 @@ def build_1_2_5_buckets(max_value: int) -> List[int]: exponent += 1 -def convert_request(request, exclude_input_from_output, decoupled): +def convert_request(request, + exclude_input_from_output, + decoupled, + executor_lookahead_config=None): inputs = {} input_token_ids = get_input_tensor_by_name(request, 'input_ids') if input_token_ids is None: @@ -526,6 +568,8 @@ def convert_request(request, exclude_input_from_output, decoupled): batch_index) kv_cache_retention_config = get_kv_cache_retention_config_from_request( request, batch_size, batch_index) + request_lookahead_config = get_lookahead_decoding_config_from_request( + request, executor_lookahead_config, batch_size, batch_index) # Inputs for mllama support encoder_input_features = get_input_tensor_by_name( @@ -579,6 +623,7 @@ def convert_request(request, exclude_input_from_output, decoupled): prompt_tuning_config=prompt_tuning_config, lora_config=lora_config, guided_decoding_params=guided_decoding_params, + lookahead_config=request_lookahead_config, kv_cache_retention_config=kv_cache_retention_config)) return requests @@ -674,6 +719,55 @@ def convert_response(response, np.array([kv_cache_metrics.num_total_allocated_blocks], np.int32), 0))) + timing_metrics = result.request_perf_metrics.timing_metrics + output_tensors.append( + pb_utils.Tensor( + "arrival_time_ns", + np.expand_dims( + np.array([pd.Timedelta(timing_metrics.arrival_time).value], + np.int64), 0))) + output_tensors.append( + pb_utils.Tensor( + "first_scheduled_time_ns", + np.expand_dims( + np.array([ + pd.Timedelta(timing_metrics.first_scheduled_time).value + ], np.int64), 0))) + output_tensors.append( + pb_utils.Tensor( + "first_token_time_ns", + np.expand_dims( + np.array( + [pd.Timedelta(timing_metrics.first_token_time).value], + np.int64), 0))) + output_tensors.append( + pb_utils.Tensor( + "last_token_time_ns", + np.expand_dims( + np.array( + [pd.Timedelta(timing_metrics.last_token_time).value], + np.int64), 0))) + + spec_dec_metrics = result.request_perf_metrics.speculative_decoding + output_tensors.append( + pb_utils.Tensor( + "acceptance_rate", + np.expand_dims( + np.array([spec_dec_metrics.acceptance_rate], np.float32), + 0))) + output_tensors.append( + pb_utils.Tensor( + "total_accepted_draft_tokens", + np.expand_dims( + np.array([spec_dec_metrics.total_accepted_draft_tokens], + np.int32), 0))) + output_tensors.append( + pb_utils.Tensor( + "total_draft_tokens", + np.expand_dims( + np.array([spec_dec_metrics.total_draft_tokens], np.int32), + 0))) + return pb_utils.InferenceResponse( output_tensors), result.is_final, output_lengths @@ -830,11 +924,48 @@ def get_peft_cache_config(self, model_config): float), "host_cache_size": get_parameter(model_config, "lora_cache_host_memory_bytes", int), + "lora_prefetch_dir": + get_parameter(model_config, "lora_prefetch_dir", int), } kwargs = {k: v for k, v in kwargs.items() if v is not None} return trtllm.PeftCacheConfig(**kwargs) + def get_executor_lookahead_config(self, model_config): + lookahead_window_size = 
get_parameter(model_config, + "lookahead_window_size", int) + lookahead_ngram_size = get_parameter(model_config, + "lookahead_ngram_size", int) + lookahead_verification_set_size = get_parameter( + model_config, "lookahead_verification_set_size", int) + # executor_lookahead_config is not set + if all(item is None for item in [ + lookahead_window_size, lookahead_ngram_size, + lookahead_verification_set_size + ]): + return None + + incomplete_config = None in [ + lookahead_window_size, lookahead_ngram_size, + lookahead_verification_set_size + ] + + assert ( + not incomplete_config + ), "Please set executor_lookahead_window_size, executor_lookahead_ngram_size and executor_lookahead_verification_set_size together." + + return trtllm.LookaheadDecodingConfig(lookahead_window_size, + lookahead_ngram_size, + lookahead_verification_set_size) + def get_decoding_config(self, model_config): + + decoding_mode = convert_decoding_mode( + get_parameter(model_config, "decoding_mode")) + self.executor_lookahead_config = None + if decoding_mode == trtllm.DecodingMode.Lookahead(): + # Add LAD config + self.executor_lookahead_config = self.get_executor_lookahead_config( + model_config) eagle_choices = parse_eagle_choices( get_parameter(model_config, "eagle_choices")) kwargs = { @@ -844,9 +975,10 @@ def get_decoding_config(self, model_config): "eagle_config": None if eagle_choices is None else trtllm.EagleConfig(eagle_choices), + "lookahead_decoding_config": + self.executor_lookahead_config, "decoding_mode": - convert_decoding_mode(get_parameter(model_config, - "decoding_mode")), + decoding_mode, } print(kwargs) kwargs = {k: v for k, v in kwargs.items() if v is not None} @@ -1232,7 +1364,7 @@ def execute(self, requests): try: converted_reqs = convert_request( request, self.exclude_input_from_output, - self.decoupled) + self.decoupled, self.executor_lookahead_config) except Exception as e: response_sender.send( pb_utils.InferenceResponse(error=pb_utils.TritonError( diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt b/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt index 514b82d..26898c9 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt @@ -276,7 +276,7 @@ input [ optional: true }, { - name: "return_kv_cache_reuse_stats" + name: "return_perf_metrics" data_type: TYPE_BOOL dims: [ 1 ] reshape: { shape: [ ] } @@ -444,6 +444,27 @@ input [ dims: [ 1 ] optional: true allow_ragged_batch: true + }, + { + name: "lookahead_window_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_ngram_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_verification_set_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true } ] output [ @@ -506,6 +527,41 @@ output [ name: "kv_cache_alloc_total_blocks" data_type: TYPE_INT32 dims: [ 1 ] + }, + { + name: "arrival_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_scheduled_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "last_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "acceptance_rate" + data_type: TYPE_FP32 + dims: [ 1 ] + }, + { + name: "total_accepted_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "total_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] 
} ] instance_group [ @@ -684,6 +740,12 @@ parameters: { string_value: "${lora_cache_host_memory_bytes}" } } +parameters: { + key: "lora_prefetch_dir" + value: { + string_value: "${lora_prefetch_dir}" + } +} parameters: { key: "decoding_mode" value: { @@ -696,6 +758,24 @@ parameters: { string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker" } } +parameters: { + key: "lookahead_window_size" + value: { + string_value: "${lookahead_window_size}" + } +} +parameters: { + key: "lookahead_ngram_size" + value: { + string_value: "${lookahead_ngram_size}" + } +} +parameters: { + key: "lookahead_verification_set_size" + value: { + string_value: "${lookahead_verification_set_size}" + } +} parameters: { key: "medusa_choices" value: { @@ -756,3 +836,9 @@ parameters: { string_value: "${guided_decoding_backend}" } } +parameters: { + key: "xgrammar_tokenizer_info_path" + value: { + string_value: "${xgrammar_tokenizer_info_path}" + } +} diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt index 9725a15..0144d74 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt @@ -282,7 +282,7 @@ input [ allow_ragged_batch: true }, { - name: "return_kv_cache_reuse_stats" + name: "return_perf_metrics" data_type: TYPE_BOOL dims: [ 1 ] reshape: { shape: [ ] } @@ -351,6 +351,41 @@ output [ name: "kv_cache_alloc_total_blocks" data_type: TYPE_INT32 dims: [ 1 ] + }, + { + name: "arrival_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_scheduled_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "last_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "acceptance_rate" + data_type: TYPE_FP32 + dims: [ 1 ] + }, + { + name: "total_accepted_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "total_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] } ] @@ -384,4 +419,4 @@ instance_group [ count: ${bls_instance_count} kind : KIND_CPU } -] +] \ No newline at end of file From 726cf93841169dc3a8a0f9525f0b5d595a37047f Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Mon, 3 Mar 2025 17:16:01 -0800 Subject: [PATCH 3/7] update --- src/triton_cli/templates/llmapi/1/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/triton_cli/templates/llmapi/1/model.py b/src/triton_cli/templates/llmapi/1/model.py index 85425a3..87542d6 100644 --- a/src/triton_cli/templates/llmapi/1/model.py +++ b/src/triton_cli/templates/llmapi/1/model.py @@ -142,6 +142,9 @@ def _get_sampling_config_from_request(self, request): None if kwargs["top_p"] is None or kwargs["top_p"] <= 0 else kwargs["top_p"] ) + # Remove None values + kwargs = {k: v for k, v in kwargs.items() if v is not None} + return kwargs @classmethod @@ -244,6 +247,7 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config): "name": "max_tokens", "data_type": "TYPE_INT32", "dims": [1], + "optional": True, }, { "name": "stop", From fe33a09d3dc2cec3d5d658d8cb9ab2d5e045d831 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Mon, 3 Mar 2025 23:10:07 -0800 Subject: [PATCH 4/7] update readme --- README.md | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index c0d7c96..3662e1f 100644 --- a/README.md +++ b/README.md @@ -361,18 +361,12 @@ triton profile -m llama-3.1-8b-instruct 
--service-kind openai --endpoint-type ch ## Serving a HuggingFace LLM Model with LLM API -> [!NOTE] -> LLM API has not yet been integrated into the official triton server tensorrt_llm backend image yet. -> To start the LLM API functionality, the user will only - The LLM API is a high-level Python API and designed for Tensorrt LLM workflows. It could convert a LLM model in Hugging Face format into a Tensorrt LLM engine and serve the engine with a unified Python API without invoking different engine build and converting scripts. To use the LLM API with Triton CLI, import the model with `--backend llmapi` ```bash -export MODEL_NAME="llama-3.1-8b-instruct" -export HF_ID="meta-llama/Llama-3.1-8B-Instruct" -triton import -m $MODEL_NAME --source "hf:$HF_ID" --backend llmapi +triton import -m "llama-3.1-8b-instruct" --backend llmapi ``` Huggingface models will be downloaded at runtime when starting the LLM API engine if not found @@ -383,6 +377,15 @@ startup time. tensorrt_llm>=0.18.0 is required. #### Example ```bash +docker run -ti \ + --gpus all \ + --network=host \ + --shm-size=1g --ulimit memlock=-1 \ + -v /tmp:/tmp \ + -v ${HOME}/models:/root/models \ + -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \ + nvcr.io/nvidia/tritonserver:25.03-trtllm-python-py3 + # Install the Triton CLI pip install git+https://github.com/triton-inference-server/triton_cli.git@main @@ -394,7 +397,7 @@ triton remove -m all triton import -m llama-3.1-8b-instruct --backend llmapi # Start Triton pointing at the default model repository -triton start --frontend openai --mode docker +triton start --frontend openai # Interact with model at http://localhost:9000 curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{ From 3aae6f060eb0a08867113de67f25bda95eb786d0 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Mon, 3 Mar 2025 23:17:53 -0800 Subject: [PATCH 5/7] change test --- tests/test_model_repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model_repository.py b/tests/test_model_repository.py index e18efd3..ca05bbb 100644 --- a/tests/test_model_repository.py +++ b/tests/test_model_repository.py @@ -64,7 +64,7 @@ def test_can_get_tensorrtllm_engine_folder_from_model_repository(self): ), f"engine path found is not as expected. Expected: {expected_engine_path}. 
Found: {trtllm_utils.get_engine_path()}" @pytest.mark.skipif( - os.environ.get("IMAGE_KIND") != "TRTLLM", reason="Only run for VLLM image" + os.environ.get("IMAGE_KIND") != "TRTLLM", reason="Only run for TRT-LLM image" ) @pytest.mark.timeout(DOWNLOAD_TIMEOUT_SECS) def test_can_get_llmapi_model_id_from_model_repository(self): From 4da12cc4a50a44c12fc75a8fe4666fe218efac50 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Thu, 6 Mar 2025 11:04:01 -0800 Subject: [PATCH 6/7] fix test --- src/triton_cli/parser.py | 1 + tests/test_e2e.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/src/triton_cli/parser.py b/src/triton_cli/parser.py index 872a57b..f58ef1d 100755 --- a/src/triton_cli/parser.py +++ b/src/triton_cli/parser.py @@ -68,6 +68,7 @@ "opt125m": "hf:facebook/opt-125m", "mistral-7b": "hf:mistralai/Mistral-7B-v0.1", "falcon-7b": "hf:tiiuae/falcon-7b", + "tinyllama-1.1b-chat-v1.0": "hf:TinyLlama/TinyLlama-1.1B-Chat-v1.0", } diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 88f0e70..3f9f4b6 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -149,6 +149,10 @@ def test_vllm_openai_e2e(self, vllm_openai_server): os.environ.get("IMAGE_KIND") != "TRTLLM", reason="Only run for TRT-LLM image with LLM-API", ) + @pytest.mark.skipif( + os.environ.get("TRTLLM_MODEL") == "gpt2", + reason="LLM API doesn't support gpt2's model architecture yet", + ) @pytest.mark.parametrize( "protocol", [ From 65ed7327e63f1a2e125f285a6b37d2af33f0dadd Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Thu, 6 Mar 2025 11:30:20 -0800 Subject: [PATCH 7/7] trigger build --- src/triton_cli/server/server_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/triton_cli/server/server_utils.py b/src/triton_cli/server/server_utils.py index 97d7477..a64b1bf 100755 --- a/src/triton_cli/server/server_utils.py +++ b/src/triton_cli/server/server_utils.py @@ -48,7 +48,7 @@ def get_launch_command( Parameters ---------- server_config : TritonServerConfig - A TritonServerConfig object containing command-line arguments to run tritonserver + A TritonServerConfig object containing command-line arguments to run tritonserver. cmd_as_list : bool Whether the command string needs to be returned as a list of string (local requires list, docker requires str)