# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.

Users of vLLM should always import **only** these wrappers.
"""

import contextlib
import functools
import importlib
import importlib.util
import os
import shutil
from collections.abc import Callable
from typing import Any, NoReturn

import requests
import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.platforms import current_platform

logger = init_logger(__name__)

# This is the storage path for the cubins; it can be replaced
# with a local path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35 # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)
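# Illustrative note: because the value is read through `os.environ.get`, the
# repository can be pointed at a local mirror for offline testing, as long as
# the variable is set before vLLM is imported (the URL below is hypothetical):
#
#     export FLASHINFER_CUBINS_REPOSITORY=http://localhost:8000/cubins/
#
# `has_nvidia_artifactory()` below probes this URL with a GET request, so a
# local mirror must answer it with HTTP 200.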


@functools.cache
def has_flashinfer_cubin() -> bool:
    """Return `True` if flashinfer-cubin package is available."""
    if envs.VLLM_HAS_FLASHINFER_CUBIN:
        return True
    if importlib.util.find_spec("flashinfer_cubin") is not None:
        return True
    logger.debug_once("flashinfer-cubin package was not found")
    return False


@functools.cache
def has_flashinfer() -> bool:
    """Return `True` if flashinfer-python package is available."""
    # Use find_spec to check if the module exists without importing it;
    # this avoids potential CUDA initialization side effects.
    if importlib.util.find_spec("flashinfer") is None:
        logger.debug_once("FlashInfer unavailable since package was not found")
        return False
    # When not using pre-downloaded cubins, also check that nvcc is available,
    # since it is required to JIT-compile FlashInfer kernels.
    if not has_flashinfer_cubin() and shutil.which("nvcc") is None:
        logger.debug_once(
            "FlashInfer unavailable since nvcc was not found "
            "and not using pre-downloaded cubins"
        )
        return False
    return True


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable FlashInfer backend."""
    raise RuntimeError(
        "FlashInfer backend is not available. Please install the package "
        "to enable FlashInfer kernels: "
        "https://github.com/flashinfer-ai/flashinfer"
    )


def _get_submodule(module_name: str) -> Any | None:
    """Safely import a submodule and return it, or None if not available."""
    try:
        return importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError):
        return None


# General lazy import wrapper
def _lazy_import_wrapper(
    module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
):
    """Create a lazy import wrapper for a specific function."""

    @functools.cache
    def _get_impl():
        if not has_flashinfer():
            return None
        mod = _get_submodule(module_name)
        return getattr(mod, attr_name, None) if mod else None

    def wrapper(*args, **kwargs):
        impl = _get_impl()
        if impl is None:
            return fallback_fn(*args, **kwargs)
        return impl(*args, **kwargs)

    return wrapper
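# Illustrative sketch of how a wrapper built by `_lazy_import_wrapper` behaves;
# the module/attribute names are real, the call site is hypothetical:
#
#     fn = _lazy_import_wrapper("flashinfer", "fp4_quantize")
#     # Nothing is imported yet. The first call resolves the implementation,
#     # caches it via functools.cache, and then either dispatches to FlashInfer
#     # or invokes the fallback (`_missing` raises RuntimeError by default).
#     out = fn(...)  # arguments depend on the wrapped FlashInfer function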


# Create lazy wrappers for each function
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "cutlass_fused_moe"
)
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe"
)

# Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
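# Illustrative sketch: `autotune` is meant to be used as a context manager
# around FlashInfer kernel calls; when FlashInfer is absent it degrades to
# `contextlib.nullcontext()`, so the `with` block still runs. The positional
# argument and the wrapped call below are assumptions, not upstream API facts:
#
#     with autotune(True):
#         flashinfer_cutlass_fused_moe(...)  # hypothetical arguments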


@functools.cache
def has_flashinfer_comm() -> bool:
    """Return `True` if FlashInfer comm module is available."""
    return has_flashinfer() and importlib.util.find_spec("flashinfer.comm") is not None


@functools.cache
def has_flashinfer_all2all() -> bool:
    """Return `True` if FlashInfer mnnvl all2all is available."""
    if not has_flashinfer_comm():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


@functools.cache
def has_flashinfer_moe() -> bool:
    """Return `True` if FlashInfer MoE module is available."""
    return (
        has_flashinfer()
        and importlib.util.find_spec("flashinfer.fused_moe") is not None
    )


@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return `True` if FlashInfer CUTLASS fused MoE is available."""
    if not has_flashinfer_moe():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return `True` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory,
    which is required for downloading certain cubin kernels like TRTLLM FMHA.
    """
    # If we have pre-downloaded cubins, we can assume they are available.
    if has_flashinfer_cubin():
        return True

    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
        accessible = response.status_code == 200
        if accessible:
            logger.debug_once("NVIDIA artifactory is accessible")
        else:
            logger.warning_once(
                "NVIDIA artifactory returned a failure status code: %d",
                response.status_code,
            )
        return accessible
    except Exception as e:
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False


@functools.cache
def supports_trtllm_attention() -> bool:
    """
    TRTLLM attention is supported if the platform is SM100,
    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
    """
    # Batch-invariant mode disables TRTLLM attention
    if vllm_is_batch_invariant():
        return False

    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    return current_platform.is_device_capability(100) and has_nvidia_artifactory()


@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
    return env_value


def force_use_trtllm_attention() -> bool | None:
    """
    Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
    return `True` if TRTLLM attention is forced to be used,
    return `False` if TRTLLM attention is forced not to be used.
    """
    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)


def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check if the current configuration supports TRTLLM attention."""
    if force_use_trtllm_attention() is False:
        return False
    has_trtllm = supports_trtllm_attention()
    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
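# Illustrative sketch: an attention backend might gate TRTLLM support on the
# head layout before building metadata. The values below are hypothetical:
#
#     trtllm_ok = can_use_trtllm_attention(num_qo_heads=32, num_kv_heads=8)
#     # True only if VLLM_USE_TRTLLM_ATTENTION is not set to 0, the platform
#     # check in supports_trtllm_attention() passes, and 32 % 8 == 0.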


def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    dcp_world_size: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return `True` if TRTLLM attention is used."""
    force_use_trtllm = force_use_trtllm_attention()

    # Environment variable is set to 0 - respect it
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # Decode context parallel is not supported
    if dcp_world_size > 1:
        logger.warning_once(
            "TRTLLM attention does not support returning LSE and as a result "
            "does not support DCP; reverting to FlashInfer"
        )
        return False

    # The platform is not supported
    if not supports_trtllm_attention():
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    # The combination of query and key heads is not supported
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized
    if q_dtype == current_platform.fp8_dtype():
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # Environment variable not set - use auto-detection
        if is_prefill:
            # Prefill auto-detection
            use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
        else:
            # Decode auto-detection
            use_trtllm = (
                num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
            )
            if use_trtllm:
                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

    # Environment variable is set to 1 - respect it
    logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
    return True
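# Illustrative sketch: a decode-time check as an attention backend might issue
# it; every argument value below is hypothetical:
#
#     if use_trtllm_attention(
#         num_qo_heads=32,
#         num_kv_heads=8,
#         num_tokens=128,
#         max_seq_len=8192,
#         dcp_world_size=1,
#         kv_cache_dtype="auto",
#         q_dtype=torch.bfloat16,
#         is_prefill=False,
#     ):
#         ...  # build TRTLLM attention metadata
#     else:
#         ...  # fall back to the regular FlashInfer path
#
# With VLLM_USE_TRTLLM_ATTENTION unset and the platform check passing, this
# call lands in the decode auto-detection branch (num_tokens <= 256,
# max_seq_len <= 131072, kv_cache_dtype == "auto") and returns True.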


if has_flashinfer():
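    # The custom ops below wrap FlashInfer GEMM kernels behind torch.library:
    # each real CUDA implementation is paired with a `register_fake` variant
    # that only computes the output shape and dtype, which lets torch.compile
    # and meta-tensor tracing handle the ops without running the kernels.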

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
        )

    @torch.library.register_fake(
        "vllm::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        from flashinfer import bmm_fp8 as bmm_fp8_

        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake(
        "vllm::bmm_fp8",
    )
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        return torch.empty(
            A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
        )


def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]

    if backend == "cutlass":
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        backend=backend,
    )


def flashinfer_scaled_fp8_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    output = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    if bias is not None:
        output = output + bias
    return output
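# Illustrative sketch; all shapes and values below are hypothetical, and
# FlashInfer's bmm_fp8 may impose additional layout requirements on `b`:
#
#     a = torch.randn(16, 128, device="cuda").to(torch.float8_e4m3fn)
#     b = torch.randn(128, 256, device="cuda").to(torch.float8_e4m3fn)
#     scale_a = torch.tensor(1.0, device="cuda")  # float32 scalar
#     scale_b = torch.tensor(1.0, device="cuda")
#     out = flashinfer_scaled_fp8_mm(a, b, scale_a, scale_b, torch.bfloat16)
#     # out has shape (16, 256) and dtype bfloat16; internally both inputs are
#     # lifted to batch size 1 and dispatched to the vllm::bmm_fp8 custom op.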


@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    """Cache the result, which depends only on the environment."""
    return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION


__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "flashinfer_fp4_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_all2all",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_disable_q_quantization",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
]