From 83319b44c26af45de4753c74f55a07df8c637a25 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Tue, 9 Dec 2025 10:40:37 -0500
Subject: [PATCH] [Compile] Fix torch warning `TensorFloat32 tensor cores for
 float32 matrix multiplication available but not enabled` (#29897)

Signed-off-by: yewentao256
---
 tests/v1/e2e/test_async_scheduling.py | 2 ++
 vllm/envs.py                          | 9 +++++++++
 vllm/v1/worker/gpu_worker.py          | 4 ++++
 3 files changed, 15 insertions(+)

diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py
index 945276376d665..838d05f0486c1 100644
--- a/tests/v1/e2e/test_async_scheduling.py
+++ b/tests/v1/e2e/test_async_scheduling.py
@@ -124,6 +124,8 @@ def run_tests(
     with monkeypatch.context() as m:
         # avoid precision errors
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
+        # lock matmul precision to full FP32
+        m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
         # m.setenv("VLLM_BATCH_INVARIANT", "1")
         outputs: list[tuple[str, list, list]] = []
         for n, (
diff --git a/vllm/envs.py b/vllm/envs.py
index 91d1b01076b11..bda9e6e423356 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -75,6 +75,7 @@ if TYPE_CHECKING:
     VLLM_MM_INPUT_CACHE_GIB: int = 4
     VLLM_TARGET_DEVICE: str = "cuda"
     VLLM_MAIN_CUDA_VERSION: str = "12.9"
+    VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
     MAX_JOBS: str | None = None
     NVCC_THREADS: str | None = None
     VLLM_USE_PRECOMPILED: bool = False
@@ -452,6 +453,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Main CUDA version of vLLM. This follows PyTorch but can be overridden.
     "VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
     or "12.9",
+    # Controls PyTorch float32 matmul precision mode within vLLM workers.
+    # Valid options mirror torch.set_float32_matmul_precision
+    "VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
+        "VLLM_FLOAT32_MATMUL_PRECISION",
+        "highest",
+        ["highest", "high", "medium"],
+        case_sensitive=False,
+    ),
     # Maximum number of compilation jobs to run in parallel.
     # By default this is the number of CPUs
     "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index a46ec2bd118fe..24a3533a169f0 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -79,6 +79,10 @@ class Worker(WorkerBase):
             is_driver_worker=is_driver_worker,
         )
 
+        # Configure float32 matmul precision according to the vLLM env var.
+        precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
+        torch.set_float32_matmul_precision(precision)
+
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils.import_utils import init_cached_hf_modules
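
For reference, below is a minimal standalone sketch (not part of the patch) of the behavior the gpu_worker change aims for, using only stock PyTorch APIs. The VLLM_FLOAT32_MATMUL_PRECISION name comes from this patch; the default of "highest" and the allowed values mirror torch.set_float32_matmul_precision.

import os

import torch

# Read the env var the patch introduces; default keeps full-FP32 matmuls.
precision = os.getenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest").lower()
assert precision in ("highest", "high", "medium")

# "highest" keeps pure float32 matmuls; "high"/"medium" allow TF32 (or
# bfloat16) tensor-core kernels, which is what the original torch warning
# "TensorFloat32 tensor cores ... available but not enabled" refers to.
torch.set_float32_matmul_precision(precision)
print(torch.get_float32_matmul_precision())  # e.g. "highest"

Setting the precision once in Worker.__init__ also silences the warning for compiled graphs, since torch picks up the global matmul precision at kernel selection time; the test only pins "highest" to avoid TF32-induced numeric drift in the async-scheduling comparison.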