Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,7 @@ data:
host: "${RESOURCE_NAME}.${NAMESPACE}.svc.cluster.local"
port:
number: 80
${MCP_TIMEOUT}
{{- end }}
{{- if .Values.destinationrule.enabled }}
destination-rule.yaml: |-
Expand Down
4 changes: 4 additions & 0 deletions model-engine/model_engine_server/common/env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH",
"LAUNCH_SERVICE_TEMPLATE_FOLDER",
"LOCAL",
"MCP_TIMEOUT_SECONDS",
"SKIP_AUTH",
"WORKSPACE",
"get_boolean_env_var",
Expand Down Expand Up @@ -78,3 +79,6 @@ def get_boolean_env_var(name: str) -> bool:
GIT_TAG: str = os.environ.get("GIT_TAG", "GIT_TAG_NOT_FOUND")
if GIT_TAG == "GIT_TAG_NOT_FOUND" and "pytest" not in sys.modules:
raise ValueError("GIT_TAG environment variable must be set")

MCP_TIMEOUT_SECONDS: int = int(os.environ.get("MCP_TIMEOUT_SECONDS", "30"))
"""Timeout in seconds for MCP server Istio VirtualService. Defaults to 30 seconds."""
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from model_engine_server.common.config import hmi_config
from model_engine_server.common.dtos.model_endpoints import BrokerName, BrokerType
from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest
from model_engine_server.common.env_vars import CIRCLECI, GIT_TAG
from model_engine_server.common.env_vars import CIRCLECI, GIT_TAG, MCP_TIMEOUT_SECONDS
from model_engine_server.common.resource_limits import (
FORWARDER_CPU_USAGE,
FORWARDER_MEMORY_USAGE,
Expand Down Expand Up @@ -382,6 +382,7 @@ class VirtualServiceArguments(_BaseEndpointArguments):
"""Keyword-arguments for substituting into virtual-service templates."""

DNS_HOST_DOMAIN: str
MCP_TIMEOUT: str # Defaults to 30s, only applies to MCP servers


class LwsServiceEntryArguments(_BaseEndpointArguments):
Expand Down Expand Up @@ -1361,6 +1362,21 @@ def get_endpoint_resource_arguments_from_request(
SERVICE_NAME_OVERRIDE=service_name_override,
)
elif endpoint_resource_name == "virtual-service":
# MCP servers use passthrough forwarder and have routes containing /mcp
timeout = ""
if isinstance(flavor, RunnableImageLike) and flavor.forwarder_type == "passthrough":
all_routes = []
if flavor.predict_route:
all_routes.append(flavor.predict_route)
if flavor.routes:
all_routes.extend(flavor.routes)
if flavor.extra_routes:
all_routes.extend(flavor.extra_routes)
is_mcp_server = any("/mcp" in route.lower() for route in all_routes)

if is_mcp_server:
timeout = f"timeout: {MCP_TIMEOUT_SECONDS}s"

return VirtualServiceArguments(
# Base resource arguments
RESOURCE_NAME=k8s_resource_group_name,
Expand All @@ -1373,6 +1389,7 @@ def get_endpoint_resource_arguments_from_request(
OWNER=owner,
GIT_TAG=GIT_TAG,
DNS_HOST_DOMAIN=infra_config().dns_host_domain,
MCP_TIMEOUT=timeout,
)
elif endpoint_resource_name == "destination-rule":
return DestinationRuleArguments(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from kubernetes_asyncio.client.rest import ApiException
from model_engine_server.common.config import hmi_config
from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest
from model_engine_server.common.env_vars import GIT_TAG
from model_engine_server.common.env_vars import GIT_TAG, MCP_TIMEOUT_SECONDS
from model_engine_server.domain.entities import (
ModelBundle,
ModelEndpointConfig,
ModelEndpointType,
ModelEndpointUserConfigState,
RunnableImageFlavor,
)
from model_engine_server.domain.exceptions import EndpointResourceInfraException
from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import (
Expand All @@ -28,6 +29,7 @@
DictStrInt,
DictStrStr,
ResourceArguments,
get_endpoint_resource_arguments_from_request,
)
from tests.unit.infra.gateways.k8s_fake_objects import FakeK8sDeploymentContainer, FakeK8sEnvVar

Expand Down Expand Up @@ -976,3 +978,85 @@ def test_add_pod_metadata_env_to_container():

node_name_env = next(e for e in container["env"] if e["name"] == "NODE_NAME")
assert node_name_env["valueFrom"]["fieldRef"]["fieldPath"] == "spec.nodeName"


def test_virtual_service_mcp_timeout_mcp_server(
create_resources_request_sync_runnable_image: CreateOrUpdateResourcesRequest,
):
"""Test that MCP servers get timeout set in VirtualService arguments."""
# Modify the bundle flavor to be an MCP server (passthrough forwarder with /mcp route)
build_endpoint_request = create_resources_request_sync_runnable_image.build_endpoint_request
model_bundle = build_endpoint_request.model_endpoint_record.current_model_bundle
assert isinstance(model_bundle.flavor, RunnableImageFlavor)

# Create a new flavor with passthrough forwarder and /mcp route
mcp_flavor = RunnableImageFlavor(
flavor="runnable_image",
repository=model_bundle.flavor.repository,
tag=model_bundle.flavor.tag,
command=model_bundle.flavor.command,
predict_route="/mcp/predict", # Contains /mcp
healthcheck_route=model_bundle.flavor.healthcheck_route,
env=model_bundle.flavor.env,
protocol=model_bundle.flavor.protocol,
readiness_initial_delay_seconds=model_bundle.flavor.readiness_initial_delay_seconds,
forwarder_type="passthrough", # Required for MCP detection
)

# Create a new bundle with MCP flavor
mcp_bundle = ModelBundle(
id=model_bundle.id,
name=model_bundle.name,
created_by=model_bundle.created_by,
owner=model_bundle.owner,
created_at=model_bundle.created_at,
model_artifact_ids=model_bundle.model_artifact_ids,
metadata=model_bundle.metadata,
flavor=mcp_flavor,
location=model_bundle.location,
requirements=model_bundle.requirements,
env_params=model_bundle.env_params,
packaging_type=model_bundle.packaging_type,
app_config=model_bundle.app_config,
)

# Update the request with MCP bundle
build_endpoint_request.model_endpoint_record.current_model_bundle = mcp_bundle

# Derive k8s_resource_group_name from endpoint_id
endpoint_id = build_endpoint_request.model_endpoint_record.id
k8s_resource_group_name = f"launch-endpoint-id-{endpoint_id}".replace("_", "-")

# Get virtual service arguments
args = get_endpoint_resource_arguments_from_request(
k8s_resource_group_name=k8s_resource_group_name,
request=create_resources_request_sync_runnable_image,
sqs_queue_name="test_queue",
sqs_queue_url="https://test_queue",
endpoint_resource_name="virtual-service",
)

# Verify MCP_TIMEOUT is set correctly
assert args["MCP_TIMEOUT"] == f"timeout: {MCP_TIMEOUT_SECONDS}s"


def test_virtual_service_mcp_timeout_non_mcp_server(
create_resources_request_sync_runnable_image: CreateOrUpdateResourcesRequest,
):
"""Test that non-MCP servers don't get timeout set (use Istio default)."""
# Derive k8s_resource_group_name from endpoint_id
build_endpoint_request = create_resources_request_sync_runnable_image.build_endpoint_request
endpoint_id = build_endpoint_request.model_endpoint_record.id
k8s_resource_group_name = f"launch-endpoint-id-{endpoint_id}".replace("_", "-")

# Get virtual service arguments for a regular (non-MCP) server
args = get_endpoint_resource_arguments_from_request(
k8s_resource_group_name=k8s_resource_group_name,
request=create_resources_request_sync_runnable_image,
sqs_queue_name="test_queue",
sqs_queue_url="https://test_queue",
endpoint_resource_name="virtual-service",
)

# Verify MCP_TIMEOUT is empty (use Istio default)
assert args["MCP_TIMEOUT"] == ""