From d6fa943f32b2703f8386d8a2f65c16c8a902c962 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 12 Jan 2026 20:59:20 -0800 Subject: [PATCH 1/4] Add endpoint startup metrics design document Design for instrumenting model endpoint startup to capture detailed timing metrics (pod scheduling, model download, vLLM init, server ready) using OpenTelemetry with Datadog as the visualization layer. Co-Authored-By: Claude Opus 4.5 --- ...6-01-12-endpoint-startup-metrics-design.md | 532 ++++++++++++++++++ 1 file changed, 532 insertions(+) create mode 100644 model-engine/docs/plans/2026-01-12-endpoint-startup-metrics-design.md diff --git a/model-engine/docs/plans/2026-01-12-endpoint-startup-metrics-design.md b/model-engine/docs/plans/2026-01-12-endpoint-startup-metrics-design.md new file mode 100644 index 00000000..7918c41b --- /dev/null +++ b/model-engine/docs/plans/2026-01-12-endpoint-startup-metrics-design.md @@ -0,0 +1,532 @@ +# Endpoint Startup Metrics Design + +## Overview + +Instrument model endpoint startup to capture detailed timing metrics for optimization analysis. Uses OpenTelemetry for vendor-neutral instrumentation with Datadog as the visualization layer. + +## Goals + +- Understand pod scaling time for GPU endpoints +- Measure: pod initialization, model download (s5cmd), model loading, endpoint ready +- Enable aggregate analysis for optimization (p50/p95 by model, GPU type, region) +- Establish baseline metrics before implementing improvements (e.g., preloading model weights) + +## Architecture + +### Two-Component System + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Kubernetes Cluster │ +│ │ +│ ┌──────────────────────┐ ┌──────────────────────────────┐│ +│ │ K8s Event Watcher │ │ vLLM Pod ││ +│ │ (Deployment) │ │ ││ +│ │ │ │ ┌────────────────────────┐ ││ +│ │ • Watches pod events │ │ │ Startup Instrumenter │ ││ +│ │ • Emits spans: │ │ │ │ ││ +│ │ - pod_scheduled │ │ │ • Emits spans: │ ││ +│ │ - image_pulled │ │ │ - s5cmd_download │ ││ +│ │ - container_started│ │ │ - ray_cluster_init │ ││ +│ │ │ │ │ - vllm_init │ ││ +│ │ Correlates via │◄─────┼──│ - server_ready │ ││ +│ │ pod_uid │ │ │ │ ││ +│ │ │ │ └───────────┬────────────┘ ││ +│ └──────────┬───────────┘ │ │ ││ +│ │ │ │ ││ +└─────────────┼──────────────────┼──────────────┼───────────────┘ + │ │ │ + ▼ │ ▼ + ┌──────────────────────────────────────────┐ + │ OTel Collector (DaemonSet) │ + │ │ + │ Joins spans by trace_id (derived from │ + │ pod_uid) into complete startup trace │ + │ │ + │ Exports to Datadog │ + └──────────────────────────────────────────┘ +``` + +### Correlation Strategy + +- Trace ID = deterministic MD5 hash of `pod_uid` +- Both components emit spans with same trace ID +- Datadog joins them into a single trace view + +### Why Two Components + +- K8s events (scheduling, image pull) happen before container code runs +- Container instrumentation captures phases we control +- Single trace gives end-to-end visibility + +## Trace Structure + +``` +trace_id: derived from pod_uid (deterministic hash) + +K8s Event Watcher spans: +├── pod_pending [pod created → scheduled] +├── image_pull [pulling → pulled] +└── container_creating [creating → started] + +In-container spans: +└── startup (parent) + ├── s5cmd_download [download start → complete] + ├── ray_cluster_init [ray start → cluster ready] (batch only) + ├── vllm_init [from_vllm_config() call duration] + └── server_ready [server listening] +``` + +## Attributes (Low Cardinality) + +All spans include: + +| Attribute | Type | Example | Cardinality | +|-----------|------|---------|-------------| +| `endpoint_name` | string | `llama-70b-prod` | Low | +| `model_name` | string | `llama-70b` | Low | +| `gpu_type` | string | `h100`, `a100` | Low | +| `node_name` | string | `gpu-node-pool-abc123` | Medium | +| `namespace` | string | `model-endpoints` | Low | +| `num_gpus` | int | `8` | Low | +| `image_tag` | string | `v1.2.3` | Low | +| `region` | string | `us-east-1` | Low | + +## Metrics (Histograms) + +For aggregate analysis and optimization: + +``` +# Download phase +startup.download.duration_seconds {endpoint_name, model_name, region} +startup.download.throughput_mbps {endpoint_name, model_name, region} + +# vLLM init phase +startup.vllm_init.duration_seconds {endpoint_name, model_name, gpu_type, num_gpus} + +# Total startup +startup.total.duration_seconds {endpoint_name, model_name, gpu_type} + +# K8s phases +startup.pod_pending.duration_seconds {endpoint_name, namespace} +startup.image_pull.duration_seconds {endpoint_name, image_tag} +``` + +## Implementation Components + +### 1. Helm Chart Changes + +Add environment variables for telemetry context: + +```yaml +# Add to _helpers.tpl baseServiceTemplateEnv +- name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid +- name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name +- name: NODE_NAME + valueFrom: + fieldRef: + fieldPath: spec.nodeName +- name: GPU_TYPE + value: "${GPU_TYPE}" +- name: AWS_REGION + value: "${AWS_REGION}" +- name: OTEL_EXPORTER_OTLP_ENDPOINT + value: "${OTEL_COLLECTOR_ENDPOINT}" +``` + +Enable vLLM native request tracing: + +```yaml +# vLLM server args +args: + - "--otlp-traces-endpoint=$(OTEL_EXPORTER_OTLP_ENDPOINT)" +``` + +### 2. Startup Telemetry Module (New File) + +Location: `model_engine_server/inference/vllm/startup_telemetry.py` + +```python +import hashlib +import os +import time +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Generator + +from opentelemetry import trace, metrics +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter +from opentelemetry.sdk.resources import Resource +from opentelemetry.trace import SpanKind, Status, StatusCode + +_tracer: trace.Tracer | None = None +_meter: metrics.Meter | None = None +_histograms: dict = {} +_context: "StartupContext | None" = None + + +@dataclass +class StartupContext: + """Runtime context for startup telemetry.""" + endpoint_name: str + model_name: str + gpu_type: str + num_gpus: int + region: str + + +def init_startup_telemetry(ctx: StartupContext) -> None: + """Initialize OTel SDK for startup instrumentation.""" + global _tracer, _meter, _histograms, _context + + endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") + if not endpoint: + return # Telemetry disabled + + _context = ctx + + resource = Resource.create({ + "service.name": "vllm-startup", + "k8s.pod.uid": os.environ.get("POD_UID", "unknown"), + "k8s.pod.name": os.environ.get("POD_NAME", "unknown"), + "k8s.node.name": os.environ.get("NODE_NAME", "unknown"), + "endpoint_name": ctx.endpoint_name, + "model_name": ctx.model_name, + }) + + # Traces + provider = TracerProvider(resource=resource) + provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint=endpoint))) + trace.set_tracer_provider(provider) + _tracer = trace.get_tracer("vllm-startup") + + # Metrics + reader = PeriodicExportingMetricReader(OTLPMetricExporter(endpoint=endpoint)) + meter_provider = MeterProvider(resource=resource, metric_readers=[reader]) + metrics.set_meter_provider(meter_provider) + _meter = metrics.get_meter("vllm-startup") + + # Create histograms + _histograms["download_duration"] = _meter.create_histogram( + "startup.download.duration_seconds", + description="Model download duration", + ) + _histograms["vllm_init_duration"] = _meter.create_histogram( + "startup.vllm_init.duration_seconds", + description="vLLM initialization duration", + ) + _histograms["total_duration"] = _meter.create_histogram( + "startup.total.duration_seconds", + description="Total startup duration", + ) + + +def derive_trace_id(pod_uid: str) -> str: + """Deterministic trace ID from pod UID for correlation.""" + return hashlib.md5(pod_uid.encode()).hexdigest()[:32] + + +@contextmanager +def startup_span(name: str, attributes: dict | None = None) -> Generator: + """Context manager for startup phase spans.""" + if not _tracer or not _context: + yield None + return + + attrs = { + "endpoint_name": _context.endpoint_name, + "model_name": _context.model_name, + "gpu_type": _context.gpu_type, + "num_gpus": _context.num_gpus, + "region": _context.region, + **(attributes or {}), + } + + with _tracer.start_as_current_span(name, kind=SpanKind.INTERNAL) as span: + for k, v in attrs.items(): + span.set_attribute(k, v) + start = time.perf_counter() + try: + yield span + span.set_status(Status(StatusCode.OK)) + except Exception as e: + span.set_status(Status(StatusCode.ERROR, str(e))) + span.record_exception(e) + raise + finally: + duration = time.perf_counter() - start + span.set_attribute("duration_seconds", duration) + + +def record_metric(name: str, value: float, extra_attrs: dict | None = None) -> None: + """Record a histogram metric.""" + if not _context or name not in _histograms: + return + + attrs = { + "endpoint_name": _context.endpoint_name, + "model_name": _context.model_name, + "gpu_type": _context.gpu_type, + **(extra_attrs or {}), + } + _histograms[name].record(value, attrs) +``` + +### 3. Integration with vllm_batch.py + +```python +# Add imports +from model_engine_server.inference.vllm.startup_telemetry import ( + StartupContext, + init_startup_telemetry, + startup_span, + record_metric, +) + +# In handle_batch_job() +async def handle_batch_job(request: CreateBatchCompletionsEngineRequest): + # Initialize telemetry with runtime context + ctx = StartupContext( + endpoint_name=os.environ.get("ENDPOINT_NAME", "unknown"), + model_name=request.model_cfg.model, + gpu_type=os.environ.get("GPU_TYPE", "unknown"), + num_gpus=request.model_cfg.num_gpus or 1, + region=os.environ.get("AWS_REGION", "unknown"), + ) + init_startup_telemetry(ctx) + + total_start = time.perf_counter() + + with startup_span("startup") as parent_span: + # Download phase + with startup_span("s5cmd_download"): + download_start = time.perf_counter() + await download_model() + record_metric("download_duration", time.perf_counter() - download_start) + + # Ray init (multi-node only) + with startup_span("ray_cluster_init"): + init_ray() + + # vLLM init + with startup_span("vllm_init"): + init_start = time.perf_counter() + engine = await init_engine(...) + record_metric("vllm_init_duration", time.perf_counter() - init_start) + + # Server ready + with startup_span("server_ready"): + total_duration = time.perf_counter() - total_start + record_metric("total_duration", total_duration) + if parent_span: + parent_span.set_attribute("total_startup_seconds", total_duration) + + # Continue to serve (vLLM native tracing takes over) + await generate_completions(engine, ...) +``` + +### 4. K8s Event Watcher (New Deployment) + +New service that watches pod lifecycle events. + +```python +# k8s_startup_watcher.py +import hashlib +from kubernetes import client, watch +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + +LABEL_SELECTOR = "app.kubernetes.io/managed-by=model-engine" + + +def derive_trace_id(pod_uid: str) -> str: + return hashlib.md5(pod_uid.encode()).hexdigest()[:32] + + +def watch_pod_events(namespace: str): + v1 = client.CoreV1Api() + w = watch.Watch() + tracer = trace.get_tracer("k8s-startup-watcher") + + pending_pods = {} # pod_uid -> timestamps + + for event in w.stream( + v1.list_namespaced_event, + namespace=namespace, + field_selector="involvedObject.kind=Pod" + ): + obj = event["object"] + pod_uid = obj.involved_object.uid + reason = obj.reason + timestamp = obj.first_timestamp + + # Filter to managed pods only + if not is_managed_pod(obj.involved_object.name, namespace): + continue + + trace_id = derive_trace_id(pod_uid) + + if reason == "Scheduled": + emit_span(tracer, "pod_scheduled", trace_id, timestamp) + elif reason == "Pulling": + pending_pods.setdefault(pod_uid, {})["pull_start"] = timestamp + elif reason == "Pulled": + emit_span(tracer, "image_pull", trace_id, timestamp, + start=pending_pods.get(pod_uid, {}).get("pull_start")) + elif reason == "Started": + emit_span(tracer, "container_started", trace_id, timestamp) +``` + +Deployment manifest: + +```yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: startup-watcher + namespace: model-engine +spec: + replicas: 1 + selector: + matchLabels: + app: startup-watcher + template: + spec: + serviceAccountName: startup-watcher + containers: + - name: watcher + image: your-registry/startup-watcher:latest + env: + - name: OTEL_EXPORTER_OTLP_ENDPOINT + value: "localhost:4317" + - name: WATCH_NAMESPACE + value: "model-endpoints" + resources: + requests: + cpu: 100m + memory: 128Mi +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: startup-watcher +rules: + - apiGroups: [""] + resources: ["pods", "events"] + verbs: ["get", "list", "watch"] +``` + +### 5. OTel Collector (DaemonSet) + +```yaml +# otel-collector-config.yaml +receivers: + otlp: + protocols: + grpc: + endpoint: 0.0.0.0:4317 + +processors: + batch: + timeout: 10s + +exporters: + datadog: + api: + key: ${DD_API_KEY} + site: datadoghq.com + traces: + span_name_as_resource_name: true + metrics: + histograms: + mode: distributions + +service: + pipelines: + traces: + receivers: [otlp] + processors: [batch] + exporters: [datadog] + metrics: + receivers: [otlp] + processors: [batch] + exporters: [datadog] +``` + +## vLLM Native Tracing + +vLLM has built-in OTel support for request tracing (not startup). Enable with: + +```bash +vllm serve model --otlp-traces-endpoint="localhost:4317" +``` + +This traces each inference request automatically once the server is ready. + +## Limitations + +### vLLM Init is a Black Box + +Without forking vLLM, we cannot break down `vllm_init` into sub-phases: +- Weight loading +- KV cache allocation +- CUDA graph compilation + +We can only measure the total `AsyncLLM.from_vllm_config()` duration. + +### Future Granularity Options + +1. **Accept the black box** - sufficient for initial benchmarking +2. **Parse vLLM logs** - fragile but no fork required +3. **Upstream contribution** - vLLM has open issues for startup tracing +4. **Lightweight fork** - add ~15 lines to `LLMEngine.__init__()` + +## HA Path for K8s Watcher + +Start with single replica. Path to HA: + +``` +Single Replica (v1) → Leader Election (v2) + │ │ + │ ├── Add Lease-based leader election + │ ├── Bump replicas to 2-3 + │ └── No changes to span emission logic +``` + +## Implementation Order + +1. **Helm chart changes** - Add env vars (POD_UID, GPU_TYPE, etc.) +2. **startup_telemetry.py** - New module with OTel helpers +3. **vllm_batch.py integration** - Wrap startup phases +4. **OTel Collector DaemonSet** - Deploy to cluster +5. **K8s Event Watcher** - Deploy as separate service +6. **Datadog dashboards** - Build startup metrics views + +## Example Datadog Queries + +``` +# P95 total startup by model +p95:startup.total.duration_seconds{*} by {model_name} + +# Download throughput by region +avg:startup.download.throughput_mbps{*} by {region} + +# vLLM init time: H100 vs A100 +avg:startup.vllm_init.duration_seconds{*} by {gpu_type} + +# Slowest endpoints +top10:startup.total.duration_seconds{*} by {endpoint_name} +``` From df2af85c4392d7b1f097b4796f692e3108d575bd Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 30 Jan 2026 19:41:06 -0800 Subject: [PATCH 2/4] Add .worktrees to gitignore --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 276b0676..fc9380a7 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,7 @@ venv.bak/ node_modules/ site/ + +# Git worktrees +.worktrees/ +.worktrees/ From 516c235e704fbeb974bf3655ba193eec184a18a4 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 30 Jan 2026 20:03:02 -0800 Subject: [PATCH 3/4] Add support for CPU-only vLLM endpoints Changes: - Allow gpus=0 with gpu_type=None for CPU-only endpoints - Add validate_cpu_only_framework to ensure only vLLM supports CPU - Update validate_num_shards to require num_shards=0 for CPU-only - Set vLLM args for CPU mode: device=cpu, dtype=float32, tensor_parallel_size=1 - Add 'device' field to VLLMModelConfig - Fix type narrowing in execute() for mypy Co-Authored-By: Claude Opus 4.5 --- .../common/dtos/llms/vllm.py | 5 ++ .../use_cases/llm_model_endpoint_use_cases.py | 69 +++++++++++++++---- 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/model-engine/model_engine_server/common/dtos/llms/vllm.py b/model-engine/model_engine_server/common/dtos/llms/vllm.py index 473af057..98bcd0ad 100644 --- a/model-engine/model_engine_server/common/dtos/llms/vllm.py +++ b/model-engine/model_engine_server/common/dtos/llms/vllm.py @@ -123,6 +123,11 @@ class VLLMModelConfig(BaseModel): description="Data type for model weights and activations. The 'auto' option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.", ) + device: Optional[str] = Field( + None, + description="Device to run on. Options: 'auto', 'cuda', 'cpu'. Default is 'auto'.", + ) + seed: Optional[int] = Field( None, description="Random seed for reproducibility.", diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 10e1979b..cde2cf8d 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -139,10 +139,16 @@ OPENAI_CHAT_COMPLETION_PATH = "/v1/chat/completions" CHAT_TEMPLATE_MAX_LENGTH = 10_000 -CHAT_SUPPORTED_INFERENCE_FRAMEWORKS = [LLMInferenceFramework.VLLM, LLMInferenceFramework.SGLANG] +CHAT_SUPPORTED_INFERENCE_FRAMEWORKS = [ + LLMInferenceFramework.VLLM, + LLMInferenceFramework.SGLANG, +] OPENAI_COMPLETION_PATH = "/v1/completions" -OPENAI_SUPPORTED_INFERENCE_FRAMEWORKS = [LLMInferenceFramework.VLLM, LLMInferenceFramework.SGLANG] +OPENAI_SUPPORTED_INFERENCE_FRAMEWORKS = [ + LLMInferenceFramework.VLLM, + LLMInferenceFramework.SGLANG, +] LLM_METADATA_KEY = "_llm" RESERVED_METADATA_KEYS = [LLM_METADATA_KEY, CONVERTED_FROM_ARTIFACT_LIKE_KEY] @@ -270,6 +276,14 @@ def validate_model_name(_model_name: str, _inference_framework: LLMInferenceFram def validate_num_shards( num_shards: int, inference_framework: LLMInferenceFramework, gpus: int ) -> None: + # CPU-only endpoints (gpus=0) must have num_shards=0 + if gpus == 0: + if num_shards != 0: + raise ObjectHasInvalidValueException( + f"CPU-only endpoints must have num_shards=0, got {num_shards}." + ) + return + if inference_framework == LLMInferenceFramework.DEEPSPEED: if num_shards <= 1: raise ObjectHasInvalidValueException("DeepSpeed requires more than 1 GPU.") @@ -283,6 +297,20 @@ def validate_num_shards( ) +# Inference frameworks that support CPU-only execution +_CPU_SUPPORTED_FRAMEWORKS = {LLMInferenceFramework.VLLM} + + +def validate_cpu_only_framework(inference_framework: LLMInferenceFramework, gpus: int) -> None: + """Validate that the inference framework supports CPU-only execution.""" + if gpus == 0 and inference_framework not in _CPU_SUPPORTED_FRAMEWORKS: + raise ObjectHasInvalidValueException( + f"CPU-only endpoints (gpus=0) are only supported for frameworks: " + f"{[f.value for f in _CPU_SUPPORTED_FRAMEWORKS]}. " + f"Got: {inference_framework.value}" + ) + + def validate_quantization( quantize: Optional[Quantization], inference_framework: LLMInferenceFramework ) -> None: @@ -931,7 +959,13 @@ def _create_vllm_bundle_command( subcommands.append(ray_cmd) if not is_worker: - vllm_args.tensor_parallel_size = num_shards + # CPU-only mode: num_shards=0 + if num_shards == 0: + vllm_args.device = "cpu" + vllm_args.dtype = "float32" # CPU doesn't support float16 well + vllm_args.tensor_parallel_size = 1 # vLLM uses 1 for single device + else: + vllm_args.tensor_parallel_size = num_shards if vllm_args.gpu_memory_utilization is not None: vllm_args.enforce_eager = True @@ -1276,15 +1310,19 @@ async def execute( self, user: User, request: CreateLLMModelEndpointV1Request ) -> CreateLLMModelEndpointV1Response: await _fill_hardware_info(self.llm_artifact_gateway, request) - if not ( - request.gpus - and request.gpu_type - and request.cpus - and request.memory - and request.storage - and request.nodes_per_worker + # Validate hardware info is present + # Note: gpus=0 and gpu_type=None are valid for CPU-only endpoints + if ( + request.gpus is None + or request.cpus is None + or request.memory is None + or request.storage is None + or request.nodes_per_worker is None ): raise RuntimeError("Some hardware info is missing unexpectedly.") + # For GPU endpoints, gpu_type must be provided + if request.gpus > 0 and request.gpu_type is None: + raise RuntimeError("gpu_type is required when gpus > 0.") validate_deployment_resources( min_workers=request.min_workers, max_workers=request.max_workers, @@ -1303,6 +1341,7 @@ async def execute( validate_post_inference_hooks(user, request.post_inference_hooks) validate_model_name(request.model_name, request.inference_framework) validate_num_shards(request.num_shards, request.inference_framework, request.gpus) + validate_cpu_only_framework(request.inference_framework, request.gpus) validate_quantization(request.quantize, request.inference_framework) validate_chat_template(request.chat_template_override, request.inference_framework) @@ -3324,9 +3363,14 @@ async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownl async def _fill_hardware_info( llm_artifact_gateway: LLMArtifactGateway, request: CreateLLMModelEndpointV1Request ): + # For CPU-only endpoints (gpus=0), gpu_type should be None + # gpu_type is only required when gpus > 0 + gpu_type_required = request.gpus is not None and request.gpus > 0 + gpu_type_missing = gpu_type_required and request.gpu_type is None + if ( request.gpus is None - or request.gpu_type is None + or gpu_type_missing or request.cpus is None or request.memory is None or request.storage is None @@ -3341,7 +3385,8 @@ async def _fill_hardware_info( and request.nodes_per_worker is None ): raise ObjectHasInvalidValueException( - "All hardware spec fields (gpus, gpu_type, cpus, memory, storage, nodes_per_worker) must be provided if any hardware spec field is missing." + "All hardware spec fields (gpus, gpu_type, cpus, memory, storage, nodes_per_worker) must be provided if any hardware spec field is missing. " + "For CPU-only endpoints, set gpus=0 and gpu_type=None." ) checkpoint_path = get_checkpoint_path(request.model_name, request.checkpoint_path) hardware_info = await _infer_hardware( From 68e4fbfa6332bbe0e46d1210e6726977fef19a42 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 30 Jan 2026 20:08:51 -0800 Subject: [PATCH 4/4] Add tests for CPU-only endpoint validation Add tests for validate_num_shards and validate_cpu_only_framework to cover the new CPU-only endpoint support. Co-Authored-By: Claude Opus 4.5 --- .../tests/unit/domain/test_llm_use_cases.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index fbcf543c..85f336fa 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -63,6 +63,8 @@ validate_and_update_completion_params, validate_chat_template, validate_checkpoint_files, + validate_cpu_only_framework, + validate_num_shards, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @@ -3038,3 +3040,37 @@ def test_validate_chat_template(): with pytest.raises(ObjectHasInvalidValueException): validate_chat_template(good_chat_template, LLMInferenceFramework.DEEPSPEED) + + +def test_validate_num_shards_cpu_only(): + # CPU-only (gpus=0) must have num_shards=0 + validate_num_shards(0, LLMInferenceFramework.VLLM, 0) # Should pass + + # CPU-only with non-zero num_shards should fail + with pytest.raises(ObjectHasInvalidValueException): + validate_num_shards(1, LLMInferenceFramework.VLLM, 0) + + # GPU endpoint: num_shards must equal gpus + validate_num_shards(2, LLMInferenceFramework.VLLM, 2) # Should pass + + with pytest.raises(ObjectHasInvalidValueException): + validate_num_shards(2, LLMInferenceFramework.VLLM, 4) # Mismatch + + +def test_validate_cpu_only_framework(): + # vLLM supports CPU-only + validate_cpu_only_framework(LLMInferenceFramework.VLLM, 0) # Should pass + + # vLLM with GPUs should also pass (not CPU-only) + validate_cpu_only_framework(LLMInferenceFramework.VLLM, 2) # Should pass + validate_cpu_only_framework(LLMInferenceFramework.DEEPSPEED, 4) # Should pass + + # Non-vLLM frameworks don't support CPU-only + with pytest.raises(ObjectHasInvalidValueException): + validate_cpu_only_framework(LLMInferenceFramework.DEEPSPEED, 0) + + with pytest.raises(ObjectHasInvalidValueException): + validate_cpu_only_framework(LLMInferenceFramework.TEXT_GENERATION_INFERENCE, 0) + + with pytest.raises(ObjectHasInvalidValueException): + validate_cpu_only_framework(LLMInferenceFramework.TENSORRT_LLM, 0)