# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""

from __future__ import annotations

import functools
import importlib
import os
from typing import Any, Callable, NoReturn

import torch

import vllm.envs as envs
from vllm.logger import logger
from vllm.platforms import current_platform
from vllm.utils import cdiv, has_deep_gemm


@functools.cache
def is_deep_gemm_supported() -> bool:
    """Return ``True`` if DeepGEMM is supported on the current platform.

    Currently, only Hopper and Blackwell GPUs are supported.
    """
    is_supported_arch = current_platform.is_cuda() and (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability(100))
    return (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
            and not envs.VLLM_USE_FLASHINFER_MOE_FP8)


@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM E8M0 scales on a
    Hopper- or Blackwell-class GPU.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.")
        return False

    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    if envs.VLLM_USE_FLASHINFER_MOE_FP8:
        logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
        return False

    if current_platform.is_device_capability(100) and \
            envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
        return True

    if current_platform.is_device_capability(90) and \
            envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER:
        logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.")
        return True

    logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
    return False
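
# Example configuration sketch (illustrative only; the actual decision follows
# the checks in is_deep_gemm_supported() / is_deep_gemm_e8m0_used() above):
#
#     # Blackwell (SM100):
#     VLLM_USE_DEEP_GEMM=1 VLLM_USE_DEEP_GEMM_E8M0=1 vllm serve <model>
#     # Hopper (SM90):
#     VLLM_USE_DEEP_GEMM=1 VLLM_USE_DEEP_GEMM_E8M0_HOPPER=1 vllm serve <model>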


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable DeepGEMM backend."""
    raise RuntimeError(
        "DeepGEMM backend is not available. Please install the `deep_gemm` "
        "package to enable FP8 kernels.")


_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None


def _lazy_init() -> None:
    """Import deep_gemm and resolve symbols on first use."""
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
    global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
    global _get_paged_mqa_logits_metadata_impl
    global _get_mn_major_tma_aligned_tensor_impl

    # fast path
    if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
            or _grouped_masked_impl is not None
            or _fp8_mqa_logits_impl is not None
            or _fp8_paged_mqa_logits_impl is not None
            or _get_paged_mqa_logits_metadata_impl is not None):
        return

    if not has_deep_gemm():
        return

    # Set up deep_gemm cache path
    DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR'
    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
            envs.VLLM_CACHE_ROOT, "deep_gemm")

    _dg = importlib.import_module("deep_gemm")

    _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
    _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
    _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
    _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
    _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
    _get_paged_mqa_logits_metadata_impl = getattr(
        _dg, "get_paged_mqa_logits_metadata", None)
    _get_mn_major_tma_aligned_tensor_impl = getattr(
        _dg, "get_mn_major_tma_aligned_tensor", None)


def get_num_sms() -> int:
    _lazy_init()
    _dg = importlib.import_module("deep_gemm")
    return int(_dg.get_num_sms())


def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
    _lazy_init()
    if _get_mn_major_tma_aligned_tensor_impl is None:
        return _missing()
    return _get_mn_major_tma_aligned_tensor_impl(x)


def fp8_gemm_nt(*args, **kwargs):
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)
    return _fp8_gemm_nt_impl(*args,
                             disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
                             **kwargs)


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    _lazy_init()
    if _grouped_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_impl(*args,
                         disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
                         **kwargs)


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    _lazy_init()
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_masked_impl(
        *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)


def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query
            position, shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query
            position, shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    _lazy_init()
    if _fp8_mqa_logits_impl is None:
        return _missing()
    return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
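
# Illustrative call sketch for fp8_mqa_logits (not executed here; the shapes
# and values are assumptions chosen to match the docstring, and a CUDA device
# with deep_gemm installed is required):
#
#     M, H, D, N = 4, 16, 64, 128
#     q = torch.randn(M, H, D, device="cuda").to(torch.float8_e4m3fn)
#     k_fp8 = torch.randn(N, D, device="cuda").to(torch.float8_e4m3fn)
#     k_scales = torch.ones(N, dtype=torch.float32, device="cuda")
#     weights = torch.rand(M, H, dtype=torch.float32, device="cuda")
#     ks = torch.zeros(M, dtype=torch.int32, device="cuda")
#     ke = torch.full((M,), N, dtype=torch.int32, device="cuda")
#     logits = fp8_mqa_logits(q, (k_fp8, k_scales), weights, ks, ke)
#     assert logits.shape == (M, N)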


def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int,
                                  num_sms: int) -> torch.Tensor:
    """Build scheduling metadata for paged MQA logits.

    Args:
        context_lens: Tensor of shape [B], dtype int32; effective context
            length per batch element.
        block_size: KV-cache block size in tokens (e.g., 64).
        num_sms: Number of SMs available (e.g., 132 on Hopper).

    Returns:
        Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
        schedule work across SMs.
    """
    _lazy_init()
    if _get_paged_mqa_logits_metadata_impl is None:
        return _missing()
    return _get_paged_mqa_logits_metadata_impl(context_lens, block_size,
                                               num_sms)


def fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using a paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block, position) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context
            length for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps
            logical block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    _lazy_init()
    if _fp8_paged_mqa_logits_impl is None:
        return _missing()
    return _fp8_paged_mqa_logits_impl(q_fp8,
                                      kv_cache_fp8,
                                      weights,
                                      context_lens,
                                      block_tables,
                                      schedule_metadata,
                                      max_model_len,
                                      clean_logits=True)
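
# Illustrative call sketch for the paged variant (not executed here; all sizes
# are assumptions consistent with the docstrings above, and a CUDA device with
# deep_gemm installed is required):
#
#     B, next_n, H, D = 2, 1, 16, 64
#     block_size, blocks_per_seq, max_model_len = 64, 4, 512
#     num_blocks = B * blocks_per_seq
#     q_fp8 = torch.randn(B, next_n, H, D,
#                         device="cuda").to(torch.float8_e4m3fn)
#     kv_cache_fp8 = torch.zeros(num_blocks, block_size, 1, D + 4,
#                                dtype=torch.uint8, device="cuda")
#     weights = torch.rand(B * next_n, H, dtype=torch.float32, device="cuda")
#     context_lens = torch.full((B,), block_size * blocks_per_seq,
#                               dtype=torch.int32, device="cuda")
#     block_tables = torch.arange(num_blocks, dtype=torch.int32,
#                                 device="cuda").view(B, blocks_per_seq)
#     meta = get_paged_mqa_logits_metadata(context_lens, block_size,
#                                          get_num_sms())
#     logits = fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens,
#                                   block_tables, meta, max_model_len)
#     assert logits.shape == (B * next_n, max_model_len)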


def _ceil_to_ue8m0(x: torch.Tensor):
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))


def _align(x: int, y: int) -> int:
    return cdiv(x, y) * y


DEFAULT_BLOCK_SIZE = [128, 128]


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size: list[int] = DEFAULT_BLOCK_SIZE,
        use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    block_m, block_n = block_size
    x_padded = torch.zeros((_align(m, block_m), _align(n, block_n)),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0
    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2))
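
# Illustrative usage of per_block_cast_to_fp8 (sizes are assumptions): a
# [512, 1024] bfloat16 weight is quantized to FP8 with one float32 scale per
# 128x128 block, so the scale tensor has shape [4, 8].
#
#     w = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda")
#     w_fp8, w_scales = per_block_cast_to_fp8(
#         w, use_ue8m0=is_deep_gemm_e8m0_used())
#     assert w_fp8.shape == (512, 1024) and w_scales.shape == (4, 8)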


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing ``torch.testing.assert_close`` to fail. Instead of checking
    every element, we compute a cosine-style similarity over the whole tensor
    and report ``1 - sim``. Once kernel accuracy improves this helper can be
    removed.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
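
# The metric computed above is 1 - 2 * sum(x * y) / (sum(x * x) + sum(y * y)),
# which is 0 for identical tensors and grows as they diverge. Illustrative
# test-style usage (the threshold is an assumption, not a project standard):
#
#     assert calc_diff(kernel_out, reference_out) < 1e-3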


def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
                                       weight: torch.Tensor):
    return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)


__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "fp8_mqa_logits",
    "fp8_paged_mqa_logits",
    "get_paged_mqa_logits_metadata",
    "per_block_cast_to_fp8",
    "is_deep_gemm_e8m0_used",
    "is_deep_gemm_supported",
    "get_num_sms",
    "should_use_deepgemm_for_fp8_linear",
    "get_col_major_tma_aligned_tensor",
]