From e69a92a1cea23b36803caac2d251d906789eed1d Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Tue, 22 Jul 2025 02:36:18 -0400
Subject: [PATCH] [Bug] DeepGemm: Fix Cuda Init Error (#21312)

Signed-off-by: yewentao256
---
 vllm/utils/deep_gemm.py | 52 ++++++++++++++++++++++++-----------------
 1 file changed, 31 insertions(+), 21 deletions(-)

diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 8b5713e02c950..09a12a8c11c5d 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -45,30 +45,36 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
     return None
 
 
-if not has_deep_gemm():
-    _fp8_gemm_nt_impl: Callable[..., Any] | None = None
-    _grouped_impl: Callable[..., Any] | None = None
-    _grouped_masked_impl: Callable[..., Any] | None = None
-    _per_block_cast_impl: Callable[..., Any] | None = None
-else:
-    _dg = importlib.import_module("deep_gemm")  # type: ignore
+_fp8_gemm_nt_impl: Callable[..., Any] | None = None
+_grouped_impl: Callable[..., Any] | None = None
+_grouped_masked_impl: Callable[..., Any] | None = None
+_per_block_cast_impl: Callable[..., Any] | None = None
 
-    _fp8_gemm_nt_impl = _resolve_symbol(
-        _dg,
-        "fp8_gemm_nt",
-        "gemm_fp8_fp8_bf16_nt",
-    )
+
+def _lazy_init() -> None:
+    """Import deep_gemm and resolve symbols on first use."""
+    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
+        _per_block_cast_impl
+
+    # fast path
+    if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
+            or _grouped_masked_impl is not None
+            or _per_block_cast_impl is not None):
+        return
+
+    if not has_deep_gemm():
+        return
+
+    _dg = importlib.import_module("deep_gemm")
+
+    _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
+                                        "gemm_fp8_fp8_bf16_nt")
     _grouped_impl = _resolve_symbol(
-        _dg,
-        "m_grouped_fp8_gemm_nt_contiguous",
-        "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
-    )
+        _dg, "m_grouped_fp8_gemm_nt_contiguous",
+        "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
     _grouped_masked_impl = _resolve_symbol(
-        _dg,
-        "fp8_m_grouped_gemm_nt_masked",
-        "m_grouped_gemm_fp8_fp8_bf16_nt_masked",
-    )
-
+        _dg, "fp8_m_grouped_gemm_nt_masked",
+        "m_grouped_gemm_fp8_fp8_bf16_nt_masked")
     # Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
     try:
         _math_mod = importlib.import_module(
@@ -80,24 +86,28 @@ else:
 
 
 def fp8_gemm_nt(*args, **kwargs):
+    _lazy_init()
     if _fp8_gemm_nt_impl is None:
         return _missing(*args, **kwargs)
     return _fp8_gemm_nt_impl(*args, **kwargs)
 
 
 def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
+    _lazy_init()
     if _grouped_impl is None:
         return _missing(*args, **kwargs)
     return _grouped_impl(*args, **kwargs)
 
 
 def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
+    _lazy_init()
     if _grouped_masked_impl is None:
         return _missing(*args, **kwargs)
     return _grouped_masked_impl(*args, **kwargs)
 
 
 def per_block_cast_to_fp8(x, *args, **kwargs):
+    _lazy_init()
     if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
         return _per_block_cast_impl(x, use_ue8m0=True)
     # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils