diff --git a/README.md b/README.md
index 9d47fd1..3662e1f 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,7 @@ Server.
 
 ## Table of Contents
 
-| [Pre-requisites](#pre-requisites) | [Installation](#installation) | [Quickstart](#quickstart) | [Serving LLM Models](#serving-llm-models) | [Serving a vLLM Model](#serving-a-vllm-model) | [Serving a TRT-LLM Model](#serving-a-trt-llm-model) | [Serving a LLM model with OpenAI API](#serving-a-llm-model-with-openai-api) | [Additional Dependencies for Custom Environments](#additional-dependencies-for-custom-environments) | [Known Limitations](#known-limitations) |
+| [Pre-requisites](#pre-requisites) | [Installation](#installation) | [Quickstart](#quickstart) | [Serving LLM Models](#serving-llm-models) | [Serving a vLLM Model](#serving-a-vllm-model) | [Serving a TRT-LLM Model](#serving-a-trt-llm-model) | [Serving a LLM model with OpenAI API](#serving-a-llm-model-with-openai-api) | [Serving a HuggingFace LLM Model with LLM API](#serving-a-huggingface-llm-model-with-llm-api) | [Additional Dependencies for Custom Environments](#additional-dependencies-for-custom-environments) | [Known Limitations](#known-limitations) |
 
 ## Pre-requisites
 
@@ -351,7 +351,59 @@ triton start --frontend openai
 
 # Interact with model at http://localhost:9000
 curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{
   "model": "llama-3.1-8b-instruct",
-  "messages": [{"role": "user", "content": "What is machine learning?"}]
+  "messages": [{"role": "user", "content": "What is machine learning?"}],
+  "max_tokens": 256
+}'
+
+# Profile model with GenAI-Perf
+triton profile -m llama-3.1-8b-instruct --service-kind openai --endpoint-type chat --url localhost:9000 --streaming
+```
+
+## Serving a HuggingFace LLM Model with LLM API
+
+The LLM API is a high-level Python API designed for TensorRT-LLM workflows. It can convert an
+LLM model in Hugging Face format into a TensorRT-LLM engine and serve that engine through a
+single, unified Python API, without invoking separate engine build and conversion scripts.
+To use the LLM API with Triton CLI, import the model with `--backend llmapi`:
+```bash
+triton import -m "llama-3.1-8b-instruct" --backend llmapi
+```
+
+Hugging Face models are downloaded at runtime when the LLM API engine starts if they are not
+found locally in the Hugging Face cache. No offline engine build step is required, but you can
+pre-download the model in advance to avoid the download at server startup time.
+`tensorrt_llm>=0.18.0` is required.
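+
+For example, one way to pre-populate the Hugging Face cache before starting the server
+(assuming the model above maps to the `meta-llama/Llama-3.1-8B-Instruct` repository and that
+the `huggingface-cli` tool from `huggingface_hub` is installed):
+```bash
+# Pre-download the checkpoint into the local Hugging Face cache
+huggingface-cli download meta-llama/Llama-3.1-8B-Instruct
+```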
+ +#### Example + +```bash +docker run -ti \ + --gpus all \ + --network=host \ + --shm-size=1g --ulimit memlock=-1 \ + -v /tmp:/tmp \ + -v ${HOME}/models:/root/models \ + -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \ + nvcr.io/nvidia/tritonserver:25.03-trtllm-python-py3 + +# Install the Triton CLI +pip install git+https://github.com/triton-inference-server/triton_cli.git@main + +# Authenticate with huggingface for restricted models like Llama-2 and Llama-3 +huggingface-cli login + +# Build TRT LLM engine and generate a Triton model repository pointing at it +triton remove -m all +triton import -m llama-3.1-8b-instruct --backend llmapi + +# Start Triton pointing at the default model repository +triton start --frontend openai + +# Interact with model at http://localhost:9000 +curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{ + "model": "llama-3.1-8b-instruct", + "messages": [{"role": "user", "content": "What is machine learning?"}], + "max_tokens": 256 }' # Profile model with GenAI-Perf diff --git a/src/triton_cli/.gitignore b/src/triton_cli/.gitignore index 973a71d..d40f318 100644 --- a/src/triton_cli/.gitignore +++ b/src/triton_cli/.gitignore @@ -1,2 +1,5 @@ *.json *.cache + +# Except model.json from the llmapi template +!templates/llmapi/1/model.json diff --git a/src/triton_cli/common.py b/src/triton_cli/common.py index 8c126af..0f1f46a 100755 --- a/src/triton_cli/common.py +++ b/src/triton_cli/common.py @@ -55,4 +55,4 @@ class TritonCLIException(Exception): DEFAULT_MODEL_REPO: Path = Path.home() / "models" DEFAULT_HF_CACHE: Path = Path.home() / ".cache" / "huggingface" HF_CACHE: Path = Path(os.environ.get("HF_HOME", DEFAULT_HF_CACHE)) -SUPPORTED_BACKENDS: set = {"vllm", "tensorrtllm"} +SUPPORTED_BACKENDS: set = {"vllm", "tensorrtllm", "llmapi"} diff --git a/src/triton_cli/parser.py b/src/triton_cli/parser.py index 872a57b..f58ef1d 100755 --- a/src/triton_cli/parser.py +++ b/src/triton_cli/parser.py @@ -68,6 +68,7 @@ "opt125m": "hf:facebook/opt-125m", "mistral-7b": "hf:mistralai/Mistral-7B-v0.1", "falcon-7b": "hf:tiiuae/falcon-7b", + "tinyllama-1.1b-chat-v1.0": "hf:TinyLlama/TinyLlama-1.1B-Chat-v1.0", } diff --git a/src/triton_cli/repository.py b/src/triton_cli/repository.py index 7728d55..254ecf3 100644 --- a/src/triton_cli/repository.py +++ b/src/triton_cli/repository.py @@ -67,7 +67,9 @@ SOURCE_PREFIX_NGC = "ngc:" SOURCE_PREFIX_LOCAL = "local:" -TRT_TEMPLATES_PATH = Path(__file__).parent / "templates" / "trt_llm" +TEMPLATES_PATH = Path(__file__).parent / "templates" +TRTLLM_TEMPLATES_PATH = TEMPLATES_PATH / "trt_llm" +LLMAPI_TEMPLATES_PATH = TEMPLATES_PATH / "llmapi" # Support changing destination dynamically to point at # pre-downloaded checkpoints in various circumstances @@ -266,6 +268,8 @@ def __add_huggingface_model( self.remove(model, verbose=False) # Let detailed traceback be reported for TRT-LLM errors for debugging raise e + elif backend == "llmapi": + self.__generate_llmapi_model(version_dir, huggingface_id) else: # TODO: Add generic support for HuggingFace models with HF API. 
# For now, use vLLM as a means of deploying HuggingFace Transformers @@ -322,6 +326,21 @@ def __generate_vllm_model(self, huggingface_id: str): model_files = {"model.json": model_contents} return model_config, model_files + def __generate_llmapi_model(self, version_dir, huggingface_id: str): + # load the model.json from llmapi template + model_config_file = version_dir / "model.json" + with open(model_config_file) as f: + model_config_str = f.read() + + model_config_json = json.loads(model_config_str) + + # change the model id as the huggingface_id + model_config_json["model"] = huggingface_id + + # write back the model.json + with open(model_config_file, "w") as f: + f.write(json.dumps(model_config_json)) + def __generate_ngc_model(self, name: str, source: str): engines_path = ENGINE_DEST_PATH + "/" + source parse_and_substitute( @@ -392,7 +411,7 @@ def __create_model_repository( ) shutil.copytree( - TRT_TEMPLATES_PATH, + TRTLLM_TEMPLATES_PATH, self.repo, dirs_exist_ok=True, ignore=shutil.ignore_patterns("__pycache__"), @@ -402,6 +421,17 @@ def __create_model_repository( logger.debug(f"Adding TensorRT-LLM models at: {self.repo}") else: version_dir.mkdir(parents=True, exist_ok=False) + if backend == "llmapi": + shutil.copytree( + LLMAPI_TEMPLATES_PATH / "1", + version_dir, + dirs_exist_ok=True, + ignore=shutil.ignore_patterns("__pycache__"), + ) + shutil.copy( + LLMAPI_TEMPLATES_PATH / "config.pbtxt", + model_dir, + ) logger.debug(f"Adding new model to repo at: {version_dir}") except FileExistsError: logger.warning(f"Overwriting existing model in repo at: {version_dir}") diff --git a/src/triton_cli/server/server_factory.py b/src/triton_cli/server/server_factory.py index 921f5f7..c42df23 100755 --- a/src/triton_cli/server/server_factory.py +++ b/src/triton_cli/server/server_factory.py @@ -25,7 +25,7 @@ LOGGER_NAME, TritonCLIException, ) -from .server_utils import TRTLLMUtils, VLLMUtils +from .server_utils import TRTLLMUtils, VLLMUtils, LLMAPIUtils logger = logging.getLogger(LOGGER_NAME) @@ -198,11 +198,14 @@ def _get_openai_chat_template_tokenizer(config): ) trtllm_utils = TRTLLMUtils(config.model_repository) vllm_utils = VLLMUtils(config.model_repository) + llmapi_utils = LLMAPIUtils(config.model_repository) if trtllm_utils.has_trtllm_model(): tokenizer_path = trtllm_utils.get_engine_path() elif vllm_utils.has_vllm_model(): tokenizer_path = vllm_utils.get_vllm_model_huggingface_id_or_path() + elif llmapi_utils.has_llmapi_model(): + tokenizer_path = llmapi_utils.get_llmapi_model_huggingface_id_or_path() else: raise TritonCLIException( "Unable to find a tokenizer to start the Triton OpenAI RESTful API, please use '--openai-chat-template-tokenizer' to specify a tokenizer." diff --git a/src/triton_cli/server/server_utils.py b/src/triton_cli/server/server_utils.py index 98823fe..a64b1bf 100755 --- a/src/triton_cli/server/server_utils.py +++ b/src/triton_cli/server/server_utils.py @@ -48,7 +48,7 @@ def get_launch_command( Parameters ---------- server_config : TritonServerConfig - A TritonServerConfig object containing command-line arguments to run tritonserver + A TritonServerConfig object containing command-line arguments to run tritonserver. 
cmd_as_list : bool Whether the command string needs to be returned as a list of string (local requires list, docker requires str) @@ -304,3 +304,86 @@ def _find_vllm_model_huggingface_id_or_path(self) -> str: return model_id except OSError: raise Exception(f"Unable to open {model_config_json_file}") + + +class LLMAPIUtils: + """ + A utility class for handling LLMAPI specific models. + """ + + def __init__(self, model_path: Path): + self._model_repo_path = model_path + self._llmapi_model_path = self._find_llmapi_model_path() + self._is_llmapi_model = self._llmapi_model_path is not None + + def has_llmapi_model(self) -> bool: + """ + Returns + ------- + A boolean indicating whether a LLMAPI model exists in the model repo + """ + return self._is_llmapi_model + + def get_llmapi_model_huggingface_id_or_path(self) -> str: + """ + Returns + ------- + The LLMAPI model's Huggingface Id or path + """ + return self._find_llmapi_model_huggingface_id_or_path() + + def _find_llmapi_model_path(self) -> Path: + """ + Returns + ------- + A pathlib.Path object containing the path to the LLMAPI model folder. + Assumptions + ---------- + - Assumes only a single model uses the LLMAPI backend (could have multiple models) + """ + # Search the llmapi model from all models in model repository + model_dirs = [ + model_dir + for model_dir in self._model_repo_path.iterdir() + if model_dir.is_dir() + ] + for model_dir in model_dirs: + model_config_file = Path(self._model_repo_path) / model_dir / "config.pbtxt" + model_json_path = model_config_file.parent / "1" / "model.json" + # check if config.pbtxt exists + if model_config_file.is_file(): + # read the config.pbtxt file and identify the backend + with open(model_config_file) as config_file: + config = text_format.Parse(config_file.read(), mc.ModelConfig()) + json_config = json.loads( + json_format.MessageToJson( + config, preserving_proto_field_name=True + ) + ) + # check if the model.json also exists. + if json_config["backend"] == "python" and model_json_path.is_file(): + return model_config_file.parent + + return None + + def _find_llmapi_model_huggingface_id_or_path(self) -> str: + """ + Returns + ------- + The llmapi model's Huggingface Id or path + """ + assert self._is_llmapi_model, "model Huggingface Id or path cannot be parsed from a model repository that does not contain a LLMAPI model." 
+ try: + # assume the version is always "1" + model_version_path = self._llmapi_model_path / "1" + model_config_json_file = model_version_path / "model.json" + with open(model_config_json_file) as json_data: + data = json.load(json_data) + model_id = data.get("model") + if not model_id: + raise Exception( + f"Unable to parse config from {model_config_json_file}" + ) + return model_id + except OSError: + raise Exception(f"Unable to open {model_config_json_file}") diff --git a/src/triton_cli/templates/llmapi/1/model.json b/src/triton_cli/templates/llmapi/1/model.json new file mode 100644 index 0000000..d659012 --- /dev/null +++ b/src/triton_cli/templates/llmapi/1/model.json @@ -0,0 +1,28 @@ +{ + "max_batch_size": 64, + "decoupled": true, + + "model":"TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "tokenizer": null, + "tokenizer_mode": null, + "skip_tokenizer_init": null, + "trust_remote_code": null, + "tensor_parallel_size": null, + "pipeline_parallel_size": null, + "dtype": null, + "revision": null, + "tokenizer_revision": null, + "speculative_model": null, + "enable_chunked_prefill": null, + "num_instances": null, + + "use_cuda_graph": null, + "cuda_graph_batch_sizes": null, + "cuda_graph_max_batch_size": null, + "cuda_graph_padding_enabled": null, + "enable_overlap_scheduler": null, + "kv_cache_dtype": null, + "torch_compile_enabled": null, + "torch_compile_fullgraph": null, + "torch_compile_inductor_enabled": null +} diff --git a/src/triton_cli/templates/llmapi/1/model.py b/src/triton_cli/templates/llmapi/1/model.py new file mode 100644 index 0000000..87542d6 --- /dev/null +++ b/src/triton_cli/templates/llmapi/1/model.py @@ -0,0 +1,707 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import asyncio +import gc +import json +import os +import queue +import threading +from contextlib import asynccontextmanager + +import numpy as np +import triton_python_backend_utils as pb_utils +from mpi4py.futures import MPICommExecutor +from mpi4py.MPI import COMM_WORLD + +from tensorrt_llm import SamplingParams +from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig +from tensorrt_llm._utils import global_mpi_rank +from tensorrt_llm.llmapi import LLM + +_TRTLLM_ENGINE_ARGS_FILENAME = "model.json" + + +class TritonPythonModel: + # Define the expected keys for each config + # TODO: Add more keys as needed + PYTORCH_CONFIG_KEYS = { + "use_cuda_graph", + "cuda_graph_batch_sizes", + "cuda_graph_max_batch_size", + "cuda_graph_padding_enabled", + "enable_overlap_scheduler", + "kv_cache_dtype", + "torch_compile_enabled", + "torch_compile_fullgraph", + "torch_compile_inductor_enabled", + } + + LLM_ENGINE_KEYS = { + "model", + "tokenizer", + "tokenizer_mode", + "skip_tokenizer_init", + "trust_remote_code", + "tensor_parallel_size", + "pipeline_parallel_size", + "dtype", + "revision", + "tokenizer_revision", + "speculative_model", + "enable_chunked_prefill", + } + + def _get_input_scalar_by_name(self, request, name): + tensor = pb_utils.get_input_tensor_by_name(request, name) + if tensor is None: + return None + + tensor = tensor.as_numpy() + if tensor.size == 0: + return None + + return tensor.item(0) + + def _get_string_list_by_name(self, request, name): + tensor = pb_utils.get_input_tensor_by_name(request, name) + if tensor is None: + return None + + tensor = tensor.as_numpy() + if tensor.size == 0: + return None + + # Convert to list and handle bytes conversion + if isinstance(tensor, np.ndarray): + if tensor.ndim == 0: + item = tensor.item() + return [item.decode("utf-8") if isinstance(item, bytes) else str(item)] + + return [ + item.decode("utf-8") if isinstance(item, bytes) else str(item) + for item in tensor.flatten() + ] + + # Fallback case + if isinstance(tensor, bytes): + return [tensor.decode("utf-8")] + return [str(tensor)] + + def _get_sampling_config_from_request(self, request): + # TODO: Add more sampling parameters as needed + kwargs = { + "beam_width": self._get_input_scalar_by_name(request, "beam_width") or 1, + "temperature": self._get_input_scalar_by_name(request, "temperature"), + "top_k": self._get_input_scalar_by_name(request, "top_k"), + "top_p": self._get_input_scalar_by_name(request, "top_p"), + "frequency_penalty": self._get_input_scalar_by_name( + request, "frequency_penalty" + ), + "presence_penalty": self._get_input_scalar_by_name( + request, "presence_penalty" + ), + "max_tokens": self._get_input_scalar_by_name(request, "max_tokens"), + # stop_words is deprecated. Should use stop instead. + "stop": ( + self._get_string_list_by_name(request, "stop") + or self._get_string_list_by_name(request, "stop_words") + ), + # random_seed is deprecated. Should use seed instead. + "seed": ( + self._get_input_scalar_by_name(request, "seed") + or self._get_input_scalar_by_name(request, "random_seed") + ), + } + + # Adjust top_p if it's not valid + kwargs["top_p"] = ( + None if kwargs["top_p"] is None or kwargs["top_p"] <= 0 else kwargs["top_p"] + ) + + # Remove None values + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + return kwargs + + @classmethod + def auto_complete_config(cls, auto_complete_model_config): + # Add inputs/outputs to the model config. 
+ cls._auto_complete_inputs_and_outputs(auto_complete_model_config) + + # Get the max batch size and decoupled model transaction policy from the json file. + engine_args_filepath = os.path.join( + pb_utils.get_model_dir(), _TRTLLM_ENGINE_ARGS_FILENAME + ) + assert os.path.isfile( + engine_args_filepath + ), f"'{_TRTLLM_ENGINE_ARGS_FILENAME}' containing TRT-LLM engine args must be provided in '{pb_utils.get_model_dir()}'" + with open(engine_args_filepath) as file: + # The Python interpreter used to invoke this function will be destroyed upon returning from this function and as a result none of the objects created here will be available in the initialize, execute, or finalize functions. + trtllm_engine_config = json.load(file) + + model_config_keys = {"max_batch_size", "decoupled"} + auto_complete_config = { + k: v for k, v in trtllm_engine_config.items() if k in model_config_keys + } + + # Set the max batch size and decoupled model transaction policy in the model config. + is_decoupled = auto_complete_config.get("decoupled", False) + auto_complete_model_config.set_model_transaction_policy( + dict(decoupled=is_decoupled) + ) + max_batch_size = auto_complete_config.get("max_batch_size", 64) + auto_complete_model_config.set_max_batch_size(int(max_batch_size)) + + return auto_complete_model_config + + @staticmethod + def _auto_complete_inputs_and_outputs(auto_complete_model_config): + # Inputs expected by the backend. + inputs = [ + {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]}, + { + "name": "stream", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "exclude_input_in_output", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "return_finish_reason", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "return_stop_reason", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "temperature", + "data_type": "TYPE_FP32", + "dims": [1], + "optional": True, + }, + { + "name": "beam_width", + "data_type": "TYPE_INT32", + "dims": [1], + "optional": True, + }, + { + "name": "top_k", + "data_type": "TYPE_INT32", + "dims": [1], + "optional": True, + }, + { + "name": "top_p", + "data_type": "TYPE_FP32", + "dims": [1], + "optional": True, + }, + { + "name": "frequency_penalty", + "data_type": "TYPE_FP32", + "dims": [1], + "optional": True, + }, + { + "name": "presence_penalty", + "data_type": "TYPE_FP32", + "dims": [1], + "optional": True, + }, + { + "name": "max_tokens", + "data_type": "TYPE_INT32", + "dims": [1], + "optional": True, + }, + { + "name": "stop", + "data_type": "TYPE_STRING", + "dims": [-1], + "optional": True, + }, + { + # stop_words is deprecated. Should use stop instead. + "name": "stop_words", + "data_type": "TYPE_STRING", + "dims": [-1], + "optional": True, + }, + { + "name": "seed", + "data_type": "TYPE_UINT64", + "dims": [1], + "optional": True, + }, + { + # random_seed is deprecated. Should use seed instead. + "name": "random_seed", + "data_type": "TYPE_UINT64", + "dims": [1], + "optional": True, + }, + ] + # Outputs expected by the backend. + outputs = [ + {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}, + {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]}, + {"name": "stop_reason", "data_type": "TYPE_STRING", "dims": [-1]}, + {"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]}, + ] + + # Collect input and output names from the provided model config. 
+ config = auto_complete_model_config.as_dict() + input_names = [] + output_names = [] + for input in config["input"]: + input_names.append(input["name"]) + for output in config["output"]: + output_names.append(output["name"]) + + # Add missing inputs and outputs to the model config. + for input in inputs: + if input["name"] not in input_names: + auto_complete_model_config.add_input(input) + for output in outputs: + if output["name"] not in output_names: + auto_complete_model_config.add_output(output) + + def initialize(self, args): + self.model_config = json.loads(args["model_config"]) + self.decoupled = pb_utils.using_decoupled_model_transaction_policy( + self.model_config + ) + self.params = self.model_config["parameters"] + self.logger = pb_utils.Logger + + output_config = pb_utils.get_output_config_by_name( + self.model_config, "text_output" + ) + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + if global_mpi_rank() == 0: + # Initialize engine arguments + self._init_engine_args() + self.logger.log_info( + f"[trtllm] rank{global_mpi_rank()} is starting trtllm engine with args: {self.llm_engine_args}" + ) + + # Starting the TRT-LLM engine with LLM API and its event thread running the AsyncIO event loop. + self._init_engine() + + # Starting the response thread. It allows TRT-LLM to keep making progress while + # response sender(s) are sending responses to server frontend. + self._response_queue = queue.Queue() + self._response_thread = threading.Thread(target=self._response_loop) + self._response_thread.start() + else: + self.logger.log_info( + f"[trtllm] rank{global_mpi_rank()} is waiting for the leader node..." + ) + with MPICommExecutor(COMM_WORLD) as executor: + if executor is not None: + raise RuntimeError( + f"[trtllm] rank{COMM_WORLD.rank} should not have executor" + ) + return + + def _get_llm_args(self, args_dict): + pytorch_config_args = { + k: v + for k, v in args_dict.items() + if k in self.PYTORCH_CONFIG_KEYS and v is not None + } + llm_engine_args = { + k: v + for k, v in args_dict.items() + if k in self.LLM_ENGINE_KEYS and v is not None + } + if "model" not in llm_engine_args: + raise pb_utils.TritonModelException( + "Model name is required in the TRT-LLM engine config." + ) + + return pytorch_config_args, llm_engine_args + + def _init_engine_args(self): + """Initialize engine arguments from config file.""" + engine_args_filepath = os.path.join( + pb_utils.get_model_dir(), _TRTLLM_ENGINE_ARGS_FILENAME + ) + if not os.path.isfile(engine_args_filepath): + raise pb_utils.TritonModelException( + f"'{_TRTLLM_ENGINE_ARGS_FILENAME}' containing TRT-LLM engine args must be provided in '{pb_utils.get_model_dir()}'" + ) + + try: + with open(engine_args_filepath) as file: + self.trtllm_engine_config = json.load(file) + except json.JSONDecodeError as e: + raise pb_utils.TritonModelException(f"Failed to parse engine config: {e}") + + self.pytorch_config_args, self.llm_engine_args = self._get_llm_args( + self.trtllm_engine_config + ) + + def _init_engine(self): + # Run the engine in a separate thread running the AsyncIO event loop. 
+ self._llm_engine = None + self._llm_engine_start_cv = threading.Condition() + self._llm_engine_shutdown_event = asyncio.Event() + self._event_thread = threading.Thread( + target=asyncio.run, args=(self._run_llm_engine(),) + ) + self._event_thread.start() + with self._llm_engine_start_cv: + while self._llm_engine is None: + self._llm_engine_start_cv.wait() + + # The 'threading.Thread()' will not raise the exception here should the engine + # failed to start, so the exception is passed back via the engine variable. + if isinstance(self._llm_engine, Exception): + e = self._llm_engine + self.logger.log_error(f"[trtllm] Failed to start engine: {e}") + if self._event_thread is not None: + self._event_thread.join() + self._event_thread = None + raise e + + async def _run_llm_engine(self): + # Counter to keep track of ongoing request counts. + self._ongoing_request_count = 0 + + @asynccontextmanager + async def async_llm_wrapper(): + # Create LLM in a thread to avoid blocking + loop = asyncio.get_running_loop() + try: + pytorch_config = PyTorchConfig(**self.pytorch_config_args) + llm = await loop.run_in_executor( + None, + lambda: LLM( + **self.llm_engine_args, + backend="pytorch", + pytorch_backend_config=pytorch_config, + ), + ) + yield llm + finally: + if "llm" in locals(): + # Run shutdown in a thread to avoid blocking + await loop.run_in_executor(None, llm.shutdown) + + try: + async with async_llm_wrapper() as engine: + # Capture the engine event loop and make it visible to other threads. + self._event_loop = asyncio.get_running_loop() + + # Signal the engine is started and make it visible to other threads. + with self._llm_engine_start_cv: + self._llm_engine = engine + self._llm_engine_start_cv.notify_all() + + # Wait for the engine shutdown signal. + await self._llm_engine_shutdown_event.wait() + + # Wait for the ongoing requests to complete. + while self._ongoing_request_count > 0: + self.logger.log_info( + "[trtllm] Awaiting remaining {} requests".format( + self._ongoing_request_count + ) + ) + await asyncio.sleep(1) + + # Cancel all tasks in the event loop. + for task in asyncio.all_tasks(loop=self._event_loop): + if task is not asyncio.current_task(): + task.cancel() + + except Exception as e: + # Signal and pass the exception back via the engine variable if the engine + # failed to start. If the engine has started, re-raise the exception. + with self._llm_engine_start_cv: + if self._llm_engine is None: + self._llm_engine = e + self._llm_engine_start_cv.notify_all() + return + raise e + + self._llm_engine = None + self.logger.log_info("[trtllm] Shutdown complete") + + def _response_loop(self): + while True: + item = self._response_queue.get() + # To signal shutdown a None item will be added to the queue. + if item is None: + break + response_state, response, response_flag = item + response_sender = response_state["response_sender"] + try: + response_sender.send(response, response_flag) + # Stop checking for cancellation if the last response is generated. + if not response_state["last_response_generated"]: + response_state["is_cancelled"] = response_sender.is_cancelled() + except Exception as e: + self.logger.log_error( + f"An error occurred while sending a response: {e}" + ) + finally: + if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL: + self._ongoing_request_count -= 1 + + def execute(self, requests): + # TODO: Add health check here? 
+ for request in requests: + # TODO : Verify Lora + if request is not None: + assert ( + self._llm_engine_shutdown_event.is_set() is False + ), "Cannot create tasks after shutdown has been requested" + coro = self._generate(request) + asyncio.run_coroutine_threadsafe(coro, self._event_loop) + + return None + + async def _generate(self, request): + response_sender = request.get_response_sender() + response_state = { + "response_sender": response_sender, + "is_cancelled": False, + "last_response_generated": False, # last response ready but not yet sent + } + self._ongoing_request_count += 1 + decrement_ongoing_request_count = True + try: + ( + prompt, + stream, + prepend_input, + sampling_config, + additional_outputs, + ) = self._get_input_tensors(request) + + sampling_params = SamplingParams(**sampling_config) + + # Generate the response. + response_iterator = self._llm_engine.generate_async( + prompt, sampling_params, streaming=stream + ) + + request_output_state = {} + async for request_output in response_iterator: + # TODO: Add request cancellation check here + # Send each response if streaming. + if stream: + response = self._create_response( + request_output_state, + request_output, + prepend_input=False, + additional_outputs=additional_outputs, + ) + flags = 0 + if request_output.finished: + response_state["last_response_generated"] = True + flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + decrement_ongoing_request_count = False + self._response_queue.put_nowait((response_state, response, flags)) + + # Send the last response which contains all the outputs if not streaming. + if not stream: + response_sender.send( + self._create_response( + request_output_state={}, + request_output=request_output, + prepend_input=prepend_input, + additional_outputs=additional_outputs, + ), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, + ) + + except Exception as e: + self.logger.log_error(f"[trtllm] Error generating stream: {e}") + error = pb_utils.TritonError(f"Error generating stream: {e}") + text_output_tensor = pb_utils.Tensor( + "text_output", np.asarray(["N/A"], dtype=self.output_dtype) + ) + response = pb_utils.InferenceResponse( + output_tensors=[text_output_tensor], error=error + ) + response_sender.send( + response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + ) + raise e + + finally: + if decrement_ongoing_request_count: + self._ongoing_request_count -= 1 + + def _get_input_tensors(self, request): + # Parse the prompt based on the batch size. + text_input = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy() + prompt = ( + text_input[0][0] + if self.model_config["max_batch_size"] > 0 + else text_input[0] + ) + + if isinstance(prompt, bytes): + prompt = prompt.decode("utf-8") + + # stream + stream = pb_utils.get_input_tensor_by_name(request, "stream") + if stream and not self.decoupled: + raise pb_utils.TritonModelException( + "Streaming is only supported in decoupled mode." + ) + if stream: + stream = stream.as_numpy()[0] + else: + stream = False + + # prepend_input / exclude_input_in_output + prepend_input = pb_utils.get_input_tensor_by_name( + request, "exclude_input_in_output" + ) + if prepend_input: + # When `exclude_input_in_output` is False, we want to prepend input prompt + # to output, thus prepend_input should be True, and vice versa. 
+ prepend_input = not prepend_input.as_numpy()[0] + elif prepend_input is None and stream: + prepend_input = False + else: + # Default to False if not specified + prepend_input = False + if prepend_input and stream: + raise pb_utils.TritonModelException( + "When streaming, `exclude_input_in_output` = False is not allowed." + ) + + # Sampling parameters + sampling_config = self._get_sampling_config_from_request(request) + + # additional outputs + additional_outputs = { + "return_finish_reason": None, + "return_stop_reason": None, + } + for tensor_name in additional_outputs.keys(): + tensor = pb_utils.get_input_tensor_by_name(request, tensor_name) + if tensor: + tensor = bool(tensor.as_numpy()[0]) + else: + tensor = False + additional_outputs[tensor_name] = tensor + + return prompt, stream, prepend_input, sampling_config, additional_outputs + + def _create_response( + self, request_output_state, request_output, prepend_input, additional_outputs + ): + # TODO: Check if request_output has_error and handle it + output_tensors = [] + + # text_output + prepend_prompt = "" + if "prev_lens_text_output" not in request_output_state: + # this is the first response + if prepend_input: + prepend_prompt = request_output.prompt + request_output_state["prev_lens_text_output"] = [0] * len( + request_output.outputs + ) + prev_lens = request_output_state["prev_lens_text_output"] + text_output = [ + (prepend_prompt + output.text[prev_len:]).encode("utf-8") + for output, prev_len in zip(request_output.outputs, prev_lens) + ] + request_output_state["prev_lens_text_output"] = [ + len(output.text) for output in request_output.outputs + ] + + # finish_reason + if additional_outputs["return_finish_reason"]: + finish_reason = [ + str(output.finish_reason) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "finish_reason", np.asarray(finish_reason, dtype=np.object_) + ) + ) + + # stop_reason + if additional_outputs["return_stop_reason"]: + stop_reason = [ + str(output.finish_reason) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "stop_reason", np.asarray(stop_reason, dtype=np.object_) + ) + ) + + output_tensors.append( + pb_utils.Tensor( + "text_output", np.asarray(text_output, dtype=self.output_dtype) + ) + ) + + return pb_utils.InferenceResponse(output_tensors=output_tensors) + + def finalize(self): + self.logger.log_info("[trtllm] Issuing finalize to trtllm backend") + self._event_loop.call_soon_threadsafe(self._llm_engine_shutdown_event.set) + + # Shutdown the event thread. + if self._event_thread is not None: + self._event_thread.join() + self._event_thread = None + + # # Shutdown the response thread. + self._response_queue.put(None) + if self._response_thread is not None: + self._response_thread.join() + self._response_thread = None + + # When using parallel tensors, the stub process may not shutdown due to + # unreleased references, so manually run the garbage collector once. + self.logger.log_info("[trtllm] Running Garbage Collector on finalize...") + gc.collect() + self.logger.log_info("[trtllm] Garbage Collector on finalize... done") diff --git a/src/triton_cli/templates/llmapi/config.pbtxt b/src/triton_cli/templates/llmapi/config.pbtxt new file mode 100644 index 0000000..65873aa --- /dev/null +++ b/src/triton_cli/templates/llmapi/config.pbtxt @@ -0,0 +1,34 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +backend: "python" + +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py b/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py index 6be7daa..aaa7fcd 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py @@ -9,6 +9,7 @@ from typing import Any, List import numpy as np +import pandas as pd import torch import triton_python_backend_utils as pb_utils from torch import from_numpy @@ -239,7 +240,12 @@ def get_output_config_from_request(request, batch_size=1, batch_index=0): kwargs["return_generation_logits"] = get_input_scalar_by_name( request, 'return_generation_logits', batch_size, batch_index) kwargs["return_perf_metrics"] = get_input_scalar_by_name( - request, 'return_kv_cache_reuse_stats', batch_size, batch_index) + request, 'return_perf_metrics', batch_size, batch_index) + if get_input_scalar_by_name(request, 'return_kv_cache_reuse_stats', + batch_size, batch_index): + pb_utils.Logger.log_warn( + "return_kv_cache_reuse_stats is deprecated, please use return_perf_metrics instead." + ) kwargs = {k: v for k, v in kwargs.items() if v is not None} return trtllm.OutputConfig(**kwargs) @@ -427,6 +433,39 @@ def get_tensor_and_check_length(name: str, expected_length: int): return None +def get_lookahead_decoding_config_from_request(request, + executor_lookahead_config, + batch_size=1, + batch_index=0): + lookahead_window_size = get_input_tensor_by_name(request, + "lookahead_window_size", + batch_size, batch_index) + + lookahead_ngram_size = get_input_tensor_by_name(request, + "lookahead_ngram_size", + batch_size, batch_index) + + lookahead_verification_set_size = get_input_tensor_by_name( + request, "lookahead_verification_set_size", batch_size, batch_index) + + # None lookahead config for requests. + if all(x is None for x in [ + lookahead_window_size, lookahead_ngram_size, + lookahead_verification_set_size + ]): + return None + + # Have request lookahead config but no executor config. 
+ if executor_lookahead_config is None: + raise RuntimeError( + "The request lookahead decoding input tensors (window_size, ngram_size and verification_set_size) can only be set if the model instance lookahead parameters are also specified" + ) + + return trtllm.LookaheadDecodingConfig(lookahead_window_size, + lookahead_ngram_size, + lookahead_verification_set_size) + + def build_1_2_5_buckets(max_value: int) -> List[int]: """ Builds a list of buckets with increasing powers of 10 multiplied by @@ -450,7 +489,10 @@ def build_1_2_5_buckets(max_value: int) -> List[int]: exponent += 1 -def convert_request(request, exclude_input_from_output, decoupled): +def convert_request(request, + exclude_input_from_output, + decoupled, + executor_lookahead_config=None): inputs = {} input_token_ids = get_input_tensor_by_name(request, 'input_ids') if input_token_ids is None: @@ -526,6 +568,8 @@ def convert_request(request, exclude_input_from_output, decoupled): batch_index) kv_cache_retention_config = get_kv_cache_retention_config_from_request( request, batch_size, batch_index) + request_lookahead_config = get_lookahead_decoding_config_from_request( + request, executor_lookahead_config, batch_size, batch_index) # Inputs for mllama support encoder_input_features = get_input_tensor_by_name( @@ -579,6 +623,7 @@ def convert_request(request, exclude_input_from_output, decoupled): prompt_tuning_config=prompt_tuning_config, lora_config=lora_config, guided_decoding_params=guided_decoding_params, + lookahead_config=request_lookahead_config, kv_cache_retention_config=kv_cache_retention_config)) return requests @@ -674,6 +719,55 @@ def convert_response(response, np.array([kv_cache_metrics.num_total_allocated_blocks], np.int32), 0))) + timing_metrics = result.request_perf_metrics.timing_metrics + output_tensors.append( + pb_utils.Tensor( + "arrival_time_ns", + np.expand_dims( + np.array([pd.Timedelta(timing_metrics.arrival_time).value], + np.int64), 0))) + output_tensors.append( + pb_utils.Tensor( + "first_scheduled_time_ns", + np.expand_dims( + np.array([ + pd.Timedelta(timing_metrics.first_scheduled_time).value + ], np.int64), 0))) + output_tensors.append( + pb_utils.Tensor( + "first_token_time_ns", + np.expand_dims( + np.array( + [pd.Timedelta(timing_metrics.first_token_time).value], + np.int64), 0))) + output_tensors.append( + pb_utils.Tensor( + "last_token_time_ns", + np.expand_dims( + np.array( + [pd.Timedelta(timing_metrics.last_token_time).value], + np.int64), 0))) + + spec_dec_metrics = result.request_perf_metrics.speculative_decoding + output_tensors.append( + pb_utils.Tensor( + "acceptance_rate", + np.expand_dims( + np.array([spec_dec_metrics.acceptance_rate], np.float32), + 0))) + output_tensors.append( + pb_utils.Tensor( + "total_accepted_draft_tokens", + np.expand_dims( + np.array([spec_dec_metrics.total_accepted_draft_tokens], + np.int32), 0))) + output_tensors.append( + pb_utils.Tensor( + "total_draft_tokens", + np.expand_dims( + np.array([spec_dec_metrics.total_draft_tokens], np.int32), + 0))) + return pb_utils.InferenceResponse( output_tensors), result.is_final, output_lengths @@ -830,11 +924,48 @@ def get_peft_cache_config(self, model_config): float), "host_cache_size": get_parameter(model_config, "lora_cache_host_memory_bytes", int), + "lora_prefetch_dir": + get_parameter(model_config, "lora_prefetch_dir", int), } kwargs = {k: v for k, v in kwargs.items() if v is not None} return trtllm.PeftCacheConfig(**kwargs) + def get_executor_lookahead_config(self, model_config): + lookahead_window_size = 
get_parameter(model_config, + "lookahead_window_size", int) + lookahead_ngram_size = get_parameter(model_config, + "lookahead_ngram_size", int) + lookahead_verification_set_size = get_parameter( + model_config, "lookahead_verification_set_size", int) + # executor_lookahead_config is not set + if all(item is None for item in [ + lookahead_window_size, lookahead_ngram_size, + lookahead_verification_set_size + ]): + return None + + incomplete_config = None in [ + lookahead_window_size, lookahead_ngram_size, + lookahead_verification_set_size + ] + + assert ( + not incomplete_config + ), "Please set executor_lookahead_window_size, executor_lookahead_ngram_size and executor_lookahead_verification_set_size together." + + return trtllm.LookaheadDecodingConfig(lookahead_window_size, + lookahead_ngram_size, + lookahead_verification_set_size) + def get_decoding_config(self, model_config): + + decoding_mode = convert_decoding_mode( + get_parameter(model_config, "decoding_mode")) + self.executor_lookahead_config = None + if decoding_mode == trtllm.DecodingMode.Lookahead(): + # Add LAD config + self.executor_lookahead_config = self.get_executor_lookahead_config( + model_config) eagle_choices = parse_eagle_choices( get_parameter(model_config, "eagle_choices")) kwargs = { @@ -844,9 +975,10 @@ def get_decoding_config(self, model_config): "eagle_config": None if eagle_choices is None else trtllm.EagleConfig(eagle_choices), + "lookahead_decoding_config": + self.executor_lookahead_config, "decoding_mode": - convert_decoding_mode(get_parameter(model_config, - "decoding_mode")), + decoding_mode, } print(kwargs) kwargs = {k: v for k, v in kwargs.items() if v is not None} @@ -1232,7 +1364,7 @@ def execute(self, requests): try: converted_reqs = convert_request( request, self.exclude_input_from_output, - self.decoupled) + self.decoupled, self.executor_lookahead_config) except Exception as e: response_sender.send( pb_utils.InferenceResponse(error=pb_utils.TritonError( diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt b/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt index 514b82d..26898c9 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt @@ -276,7 +276,7 @@ input [ optional: true }, { - name: "return_kv_cache_reuse_stats" + name: "return_perf_metrics" data_type: TYPE_BOOL dims: [ 1 ] reshape: { shape: [ ] } @@ -444,6 +444,27 @@ input [ dims: [ 1 ] optional: true allow_ragged_batch: true + }, + { + name: "lookahead_window_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_ngram_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_verification_set_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true } ] output [ @@ -506,6 +527,41 @@ output [ name: "kv_cache_alloc_total_blocks" data_type: TYPE_INT32 dims: [ 1 ] + }, + { + name: "arrival_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_scheduled_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "last_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "acceptance_rate" + data_type: TYPE_FP32 + dims: [ 1 ] + }, + { + name: "total_accepted_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "total_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] 
} ] instance_group [ @@ -684,6 +740,12 @@ parameters: { string_value: "${lora_cache_host_memory_bytes}" } } +parameters: { + key: "lora_prefetch_dir" + value: { + string_value: "${lora_prefetch_dir}" + } +} parameters: { key: "decoding_mode" value: { @@ -696,6 +758,24 @@ parameters: { string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker" } } +parameters: { + key: "lookahead_window_size" + value: { + string_value: "${lookahead_window_size}" + } +} +parameters: { + key: "lookahead_ngram_size" + value: { + string_value: "${lookahead_ngram_size}" + } +} +parameters: { + key: "lookahead_verification_set_size" + value: { + string_value: "${lookahead_verification_set_size}" + } +} parameters: { key: "medusa_choices" value: { @@ -756,3 +836,9 @@ parameters: { string_value: "${guided_decoding_backend}" } } +parameters: { + key: "xgrammar_tokenizer_info_path" + value: { + string_value: "${xgrammar_tokenizer_info_path}" + } +} diff --git a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt index 9725a15..0144d74 100644 --- a/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt +++ b/src/triton_cli/templates/trt_llm/tensorrt_llm_bls/config.pbtxt @@ -282,7 +282,7 @@ input [ allow_ragged_batch: true }, { - name: "return_kv_cache_reuse_stats" + name: "return_perf_metrics" data_type: TYPE_BOOL dims: [ 1 ] reshape: { shape: [ ] } @@ -351,6 +351,41 @@ output [ name: "kv_cache_alloc_total_blocks" data_type: TYPE_INT32 dims: [ 1 ] + }, + { + name: "arrival_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_scheduled_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "last_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "acceptance_rate" + data_type: TYPE_FP32 + dims: [ 1 ] + }, + { + name: "total_accepted_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "total_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] } ] @@ -384,4 +419,4 @@ instance_group [ count: ${bls_instance_count} kind : KIND_CPU } -] +] \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 7527be6..3f4ca44 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,6 +63,18 @@ def vllm_server(): server.stop() +@pytest.fixture(scope="function") +def llmapi_server(): + llm_repo = None + + # llmapi models might need to be downloaded on the fly during model loading, + # so leave longer timeout in case of slow network. + server = ScopedTritonServer(repo=llm_repo, timeout=1800) + yield server + # Ensure server is cleaned up after each test + server.stop() + + @pytest.fixture(scope="function") def trtllm_openai_server(): llm_repo = None @@ -92,6 +104,18 @@ def vllm_openai_server(): server.stop() +@pytest.fixture(scope="function") +def llmapi_openai_server(): + llm_repo = None + + # llmapi models might need to be downloaded on the fly during model loading, + # so leave longer timeout in case of slow network. 
+ server = ScopedTritonServer(repo=llm_repo, timeout=1800, frontend="openai") + yield server + # Ensure server is cleaned up after each test + server.stop() + + @pytest.fixture(scope="function") def simple_server(): test_dir = os.path.dirname(os.path.realpath(__file__)) diff --git a/tests/test_cli.py b/tests/test_cli.py index 78abc37..14625e1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -38,6 +38,8 @@ CUSTOM_TRTLLM_MODEL_SOURCES = [("trtllm-model", "hf:gpt2")] +CUSTOM_LLMAPI_MODEL_SOURCES = [("llmapi-model", "hf:gpt2")] + # TODO: Add public NGC model for testing CUSTOM_NGC_MODEL_SOURCES = [("my-llm", "ngc:does-not-exist")] @@ -96,6 +98,16 @@ def test_repo_add_trtllm_build(self, model, source): TritonCommands._import(model, source=source, backend="tensorrtllm") TritonCommands._clear() + @pytest.mark.skipif( + os.environ.get("IMAGE_KIND") != "TRTLLM", reason="Only run for TRTLLM image" + ) + @pytest.mark.parametrize("model,source", CUSTOM_TRTLLM_MODEL_SOURCES) + def test_repo_add_llmapi_build(self, model, source): + # TODO: Parse repo to find TRT-LLM models and backend in config + TritonCommands._clear() + TritonCommands._import(model, source=source, backend="llmapi") + TritonCommands._clear() + @pytest.mark.skip(reason="Pre-built TRT-LLM engines not available") def test_import_trtllm_prebuilt(self, model, source): # TODO: Parse repo to find TRT-LLM models and backend in config diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 2a491ca..3f9f4b6 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -145,6 +145,63 @@ def test_vllm_openai_e2e(self, vllm_openai_server): model, service_kind="openai", endpoint_type="chat", url="localhost:9000" ) + @pytest.mark.skipif( + os.environ.get("IMAGE_KIND") != "TRTLLM", + reason="Only run for TRT-LLM image with LLM-API", + ) + @pytest.mark.skipif( + os.environ.get("TRTLLM_MODEL") == "gpt2", + reason="LLM API doesn't support gpt2's model architecture yet", + ) + @pytest.mark.parametrize( + "protocol", + [ + "grpc", + pytest.param( + "http", + # NOTE: skip because xfail was causing server to not get cleaned up by test in background + marks=pytest.mark.skip( + reason="http not supported decoupled models and model profiling yet" + ), + ), + ], + ) + @pytest.mark.timeout(LLM_TIMEOUT_SECS) + def test_llmapi_e2e(self, llmapi_server, protocol): + # NOTE: llmapi "backend" is using the same api format as tensorrtllm(tensorrt_llm_bls) backend, + # Use the tensorrtllm backend as the profiling option. + model = os.environ.get("TRTLLM_MODEL") + assert model is not None, "TRTLLM_MODEL env var must be set!" + # Source is optional if using a "known: model" + source = os.environ.get("MODEL_SOURCE") + TritonCommands._clear() + TritonCommands._import(model, source=source, backend="llmapi") + llmapi_server.start() + TritonCommands._infer(model, prompt=PROMPT, protocol=protocol) + # TODO: update this to llmapi when genai-perf supports the llmapi backend's api format + TritonCommands._profile(model, backend="tensorrtllm") + + @pytest.mark.skipif( + os.environ.get("IMAGE_KIND") != "TRTLLM", + reason="Only run for TRT-LLM image with LLM-API", + ) + @pytest.mark.skipif( + os.environ.get("TRTLLM_MODEL") == "gpt2", + reason="gpt2's tokenizer doesn't have a chat template defined", + ) + @pytest.mark.timeout(LLM_TIMEOUT_SECS) + def test_llmapi_openai_e2e(self, vllm_openai_server): + model = os.environ.get("TRTLLM_MODEL") + assert model is not None, "TRTLLM_MODEL env var must be set!" 
+ # Source is optional if using a "known: model" + source = os.environ.get("MODEL_SOURCE") + TritonCommands._clear() + TritonCommands._import(model, source=source, backend="llmapi") + vllm_openai_server.start() + TritonCommands._profile( + model, service_kind="openai", endpoint_type="chat", url="localhost:9000" + ) + @pytest.mark.skipif( os.environ.get("CI_PIPELINE") == "GITHUB_ACTIONS", reason="bandage/temporary fix", diff --git a/tests/test_model_repository.py b/tests/test_model_repository.py index 736b701..ca05bbb 100644 --- a/tests/test_model_repository.py +++ b/tests/test_model_repository.py @@ -27,7 +27,7 @@ import os import pytest from utils import TritonCommands -from triton_cli.server.server_utils import TRTLLMUtils, VLLMUtils +from triton_cli.server.server_utils import TRTLLMUtils, VLLMUtils, LLMAPIUtils from triton_cli.common import DEFAULT_MODEL_REPO # Give ample 30min timeout for tests that download models from huggingface @@ -42,8 +42,10 @@ def setup_method(self): def test_can_not_find_models(self): trtllm_utils = TRTLLMUtils(DEFAULT_MODEL_REPO) vllm_utils = VLLMUtils(DEFAULT_MODEL_REPO) + llmapi_utils = LLMAPIUtils(DEFAULT_MODEL_REPO) assert not trtllm_utils.has_trtllm_model(), f"tensorrtllm model found in model repository: '{DEFAULT_MODEL_REPO}', but the test expect the tensorrtllm model not found" assert not vllm_utils.has_vllm_model(), f"vllm model found in model repository: '{DEFAULT_MODEL_REPO}', but the test expect the vllm model not found" + assert not llmapi_utils.has_llmapi_model(), f"llmapi model found in model repository: '{DEFAULT_MODEL_REPO}', but the test expect the llmapi model not found" @pytest.mark.skipif( os.environ.get("IMAGE_KIND") != "TRTLLM", reason="Only run for TRT-LLM image" @@ -61,6 +63,22 @@ def test_can_get_tensorrtllm_engine_folder_from_model_repository(self): expected_engine_path == trtllm_utils.get_engine_path() ), f"engine path found is not as expected. Expected: {expected_engine_path}. Found: {trtllm_utils.get_engine_path()}" + @pytest.mark.skipif( + os.environ.get("IMAGE_KIND") != "TRTLLM", reason="Only run for TRT-LLM image" + ) + @pytest.mark.timeout(DOWNLOAD_TIMEOUT_SECS) + def test_can_get_llmapi_model_id_from_model_repository(self): + model = "llama-2-7b-chat" + expected_model_id = "meta-llama/Llama-2-7b-chat-hf" + TritonCommands._import(model, backend="llmapi") + llmapi_utils = LLMAPIUtils(DEFAULT_MODEL_REPO) + assert ( + llmapi_utils.has_llmapi_model() + ), f"no LLM API model found in model repository: '{DEFAULT_MODEL_REPO}'." + assert ( + expected_model_id == llmapi_utils.get_llmapi_model_huggingface_id_or_path() + ), f"model id found is not as expected. Expected: {expected_model_id}. Found: {llmapi_utils.get_llmapi_model_huggingface_id_or_path()}" + @pytest.mark.skipif( os.environ.get("IMAGE_KIND") != "VLLM", reason="Only run for VLLM image" )