mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-25 01:25:01 +08:00
[Compile] Fix torch warning TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled (#29897)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
56037dfa2f
commit
83319b44c2
@ -124,6 +124,8 @@ def run_tests(
|
||||
with monkeypatch.context() as m:
|
||||
# avoid precision errors
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
# lock matmul precision to full FP32
|
||||
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
|
||||
# m.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||
outputs: list[tuple[str, list, list]] = []
|
||||
for n, (
|
||||
|
||||
@ -75,6 +75,7 @@ if TYPE_CHECKING:
|
||||
VLLM_MM_INPUT_CACHE_GIB: int = 4
|
||||
VLLM_TARGET_DEVICE: str = "cuda"
|
||||
VLLM_MAIN_CUDA_VERSION: str = "12.9"
|
||||
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
|
||||
MAX_JOBS: str | None = None
|
||||
NVCC_THREADS: str | None = None
|
||||
VLLM_USE_PRECOMPILED: bool = False
|
||||
@ -452,6 +453,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Main CUDA version of vLLM. This follows PyTorch but can be overridden.
|
||||
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
|
||||
or "12.9",
|
||||
# Controls PyTorch float32 matmul precision mode within vLLM workers.
|
||||
# Valid options mirror torch.set_float32_matmul_precision
|
||||
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
|
||||
"VLLM_FLOAT32_MATMUL_PRECISION",
|
||||
"highest",
|
||||
["highest", "high", "medium"],
|
||||
case_sensitive=False,
|
||||
),
|
||||
# Maximum number of compilation jobs to run in parallel.
|
||||
# By default this is the number of CPUs
|
||||
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
|
||||
|
||||
@ -79,6 +79,10 @@ class Worker(WorkerBase):
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
|
||||
# configure float32 matmul precision according to vLLM env.
|
||||
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
|
||||
torch.set_float32_matmul_precision(precision)
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils.import_utils import init_cached_hf_modules
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user