Skip to content

Commit 83319b4

Browse files
authored
[Compile] Fix torch warning "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled" (#29897)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 56037df commit 83319b4

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

tests/v1/e2e/test_async_scheduling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ def run_tests(
124124
with monkeypatch.context() as m:
125125
# avoid precision errors
126126
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
127+
# lock matmul precision to full FP32
128+
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
127129
# m.setenv("VLLM_BATCH_INVARIANT", "1")
128130
outputs: list[tuple[str, list, list]] = []
129131
for n, (

vllm/envs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
VLLM_MM_INPUT_CACHE_GIB: int = 4
7676
VLLM_TARGET_DEVICE: str = "cuda"
7777
VLLM_MAIN_CUDA_VERSION: str = "12.9"
78+
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
7879
MAX_JOBS: str | None = None
7980
NVCC_THREADS: str | None = None
8081
VLLM_USE_PRECOMPILED: bool = False
@@ -452,6 +453,14 @@ def get_vllm_port() -> int | None:
452453
# Main CUDA version of vLLM. This follows PyTorch but can be overridden.
453454
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
454455
or "12.9",
456+
# Controls PyTorch float32 matmul precision mode within vLLM workers.
457+
# Valid options mirror torch.set_float32_matmul_precision
458+
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
459+
"VLLM_FLOAT32_MATMUL_PRECISION",
460+
"highest",
461+
["highest", "high", "medium"],
462+
case_sensitive=False,
463+
),
455464
# Maximum number of compilation jobs to run in parallel.
456465
# By default this is the number of CPUs
457466
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),

vllm/v1/worker/gpu_worker.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ def __init__(
7979
is_driver_worker=is_driver_worker,
8080
)
8181

82+
# configure float32 matmul precision according to vLLM env.
83+
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
84+
torch.set_float32_matmul_precision(precision)
85+
8286
if self.model_config.trust_remote_code:
8387
# note: lazy import to avoid importing torch before initializing
8488
from vllm.utils.import_utils import init_cached_hf_modules

0 commit comments

Comments (0)