diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 09a12a8c11c5..169b083017e4 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -13,7 +13,8 @@ from typing import Any, Callable, NoReturn
 import torch
 
 import vllm.envs as envs
-from vllm.utils import cuda_get_device_properties, has_deep_gemm
+from vllm.platforms import current_platform
+from vllm.utils import has_deep_gemm
 
 
 @functools.cache
@@ -21,12 +22,15 @@ def is_blackwell_deep_gemm_used() -> bool:
     """Return ``True`` if vLLM is configured to use DeepGEMM on a
     Blackwell-class GPU.
     """
-
-    if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()
-            and _per_block_cast_impl is not None):
+    if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()):
         return False
 
-    return cuda_get_device_properties(0, ("major", ))[0] == 10
+    _lazy_init()
+    if _per_block_cast_impl is None:
+        return False
+
+    return (current_platform.is_cuda()
+            and current_platform.is_device_capability(100))
 
 
 def _missing(*_: Any, **__: Any) -> NoReturn: