# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.
|
|
|
|
Users of vLLM should always import **only** these wrappers.
|
|
"""
from __future__ import annotations

import functools
import importlib
from typing import Any, Callable, NoReturn

import torch

import vllm.envs as envs
from vllm.utils import cuda_get_device_properties, has_deep_gemm


@functools.cache
def is_blackwell_deep_gemm_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM on a
    Blackwell-class GPU.
    """
    if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()):
        return False

    # Resolve the backend symbols before inspecting them; otherwise
    # ``_per_block_cast_impl`` is still ``None`` on the first call and the
    # cached result would wrongly stay ``False``.
    _lazy_init()
    if _per_block_cast_impl is None:
        return False

    # Compute capability major == 10 corresponds to Blackwell-class GPUs
    # (e.g. B200).
    return cuda_get_device_properties(0, ("major", ))[0] == 10


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for an unavailable DeepGEMM backend."""
    raise RuntimeError(
        "DeepGEMM backend is not available. Please install the `deep_gemm` "
        "package to enable FP8 kernels.")


def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
    """Return the *new* symbol if it exists, otherwise the *old* one."""
    if hasattr(module, new):
        return getattr(module, new)
    if hasattr(module, old):
        return getattr(module, old)
    return None
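# Illustrative resolution (grounded in the mapping used by _lazy_init below):
# on recent DeepGEMM releases
#
#     _resolve_symbol(_dg, "fp8_gemm_nt", "gemm_fp8_fp8_bf16_nt")
#
# returns the new-style entry point, on older releases it falls back to the
# legacy name, and if neither exists it returns None so the public wrapper
# raises via _missing().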
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None


def _lazy_init() -> None:
    """Import deep_gemm and resolve symbols on first use."""
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
        _per_block_cast_impl

    # Fast path: symbols are already resolved.
    if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
            or _grouped_masked_impl is not None
            or _per_block_cast_impl is not None):
        return

    if not has_deep_gemm():
        return

    _dg = importlib.import_module("deep_gemm")

    _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
                                        "gemm_fp8_fp8_bf16_nt")
    _grouped_impl = _resolve_symbol(
        _dg, "m_grouped_fp8_gemm_nt_contiguous",
        "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
    _grouped_masked_impl = _resolve_symbol(
        _dg, "fp8_m_grouped_gemm_nt_masked",
        "m_grouped_gemm_fp8_fp8_bf16_nt_masked")
    # Try to get per_block_cast_to_fp8 from DeepGEMM math utils; older
    # releases do not ship this module.
    try:
        _math_mod = importlib.import_module(
            "deep_gemm.utils.math")  # type: ignore
        _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
                                       None)
    except ModuleNotFoundError:
        _per_block_cast_impl = None
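# Name mapping resolved above (new DeepGEMM symbol <- legacy fallback):
#   fp8_gemm_nt                      <- gemm_fp8_fp8_bf16_nt
#   m_grouped_fp8_gemm_nt_contiguous <- m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
#   fp8_m_grouped_gemm_nt_masked     <- m_grouped_gemm_fp8_fp8_bf16_nt_masked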
def fp8_gemm_nt(*args, **kwargs):
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)
    return _fp8_gemm_nt_impl(*args, **kwargs)
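# Illustrative dispatch (a sketch; the exact argument layout depends on the
# installed DeepGEMM version): both generations of the kernel are reached
# through the same wrapper, e.g.
#
#     fp8_gemm_nt((a_fp8, a_scales), (b_fp8, b_scales), out)
#
# where the wrapper forwards *args/**kwargs verbatim to whichever symbol
# _lazy_init() resolved.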
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    _lazy_init()
    if _grouped_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_impl(*args, **kwargs)


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    _lazy_init()
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_masked_impl(*args, **kwargs)


def per_block_cast_to_fp8(x, *args, **kwargs):
    _lazy_init()
    if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
        # Native DeepGEMM cast; Blackwell kernels expect UE8M0 scales.
        return _per_block_cast_impl(x, use_ue8m0=True)
    # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
    from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
    return _pbcf(x, *args, **kwargs)
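# Illustrative quantization (a sketch; shapes and the tuple return are
# assumptions): cast a bf16 tile to FP8 with per-block scales, taking the
# DeepGEMM path only on Blackwell.
#
#     x = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
#     x_fp8, x_scales = per_block_cast_to_fp8(x)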
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable
    per-element error, causing ``torch.testing.assert_close`` to fail.
    Instead of checking every element, we compute a cosine-style similarity
    over the whole tensor and report ``1 - sim``. Once kernel accuracy
    improves this helper can be removed.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
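# Illustrative behaviour (a sketch): identical tensors score 0 and perfectly
# anti-correlated tensors score 2, since sim ranges over [-1, 1].
#
#     a = torch.randn(128, 128)
#     calc_diff(a, a)   # tensor(0., dtype=torch.float64)
#     calc_diff(a, -a)  # tensor(2., dtype=torch.float64)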
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "per_block_cast_to_fp8",
    "is_blackwell_deep_gemm_used",
]