mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 09:37:03 +08:00
[NVIDIA] Add support for cudnn fp4 gemm via flashinfer (#26107)
Signed-off-by: kaixih <kaixih@nvidia.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
a1063628a4
commit
de92d916fe
17
vllm/envs.py
17
vllm/envs.py
@ -191,6 +191,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
|
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
|
||||||
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
|
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
|
||||||
VLLM_USE_TRTLLM_ATTENTION: str | None = None
|
VLLM_USE_TRTLLM_ATTENTION: str | None = None
|
||||||
|
VLLM_NVFP4_GEMM_BACKEND: str | None = None
|
||||||
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
|
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
|
||||||
VLLM_HAS_FLASHINFER_CUBIN: bool = False
|
VLLM_HAS_FLASHINFER_CUBIN: bool = False
|
||||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
|
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
|
||||||
@ -1292,11 +1293,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# If set, it means we pre-downloaded cubin files and flashinfer will
|
# If set, it means we pre-downloaded cubin files and flashinfer will
|
||||||
# read the cubin files directly.
|
# read the cubin files directly.
|
||||||
"VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
|
"VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
|
||||||
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
|
# Supported options:
|
||||||
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
|
# - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
|
||||||
# vllm cutlass GEMM, marlin GEMM.
|
# - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
|
||||||
"VLLM_USE_TRTLLM_FP4_GEMM": lambda: bool(
|
# - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
|
||||||
int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))
|
# - <none>: automatically pick an available backend
|
||||||
|
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
|
||||||
|
"VLLM_NVFP4_GEMM_BACKEND",
|
||||||
|
None,
|
||||||
|
["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
|
||||||
),
|
),
|
||||||
# Controls garbage collection during CUDA graph capture.
|
# Controls garbage collection during CUDA graph capture.
|
||||||
# If set to 0 (default), enables GC freezing to speed up capture time.
|
# If set to 0 (default), enables GC freezing to speed up capture time.
|
||||||
@ -1492,7 +1497,6 @@ def compute_hash() -> str:
|
|||||||
"VLLM_DISABLED_KERNELS",
|
"VLLM_DISABLED_KERNELS",
|
||||||
"VLLM_USE_DEEP_GEMM",
|
"VLLM_USE_DEEP_GEMM",
|
||||||
"VLLM_USE_DEEP_GEMM_E8M0",
|
"VLLM_USE_DEEP_GEMM_E8M0",
|
||||||
"VLLM_USE_TRTLLM_FP4_GEMM",
|
|
||||||
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
|
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
|
||||||
"VLLM_USE_FLASHINFER_MOE_FP16",
|
"VLLM_USE_FLASHINFER_MOE_FP16",
|
||||||
"VLLM_USE_FLASHINFER_MOE_FP8",
|
"VLLM_USE_FLASHINFER_MOE_FP8",
|
||||||
@ -1524,6 +1528,7 @@ def compute_hash() -> str:
|
|||||||
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
|
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
|
||||||
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
|
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
|
||||||
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
|
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
|
||||||
|
"VLLM_NVFP4_GEMM_BACKEND",
|
||||||
"VLLM_USE_FBGEMM",
|
"VLLM_USE_FBGEMM",
|
||||||
]
|
]
|
||||||
for key in environment_variables_to_hash:
|
for key in environment_variables_to_hash:
|
||||||
|
|||||||
@ -14,7 +14,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||||
run_nvfp4_emulations,
|
run_nvfp4_emulations,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
cutlass_fp4_supported,
|
||||||
|
swizzle_blockscale,
|
||||||
|
)
|
||||||
from vllm.model_executor.parameter import (
|
from vllm.model_executor.parameter import (
|
||||||
GroupQuantScaleParameter,
|
GroupQuantScaleParameter,
|
||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
@ -29,10 +32,12 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
|
|||||||
|
|
||||||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
self.backend = "none"
|
||||||
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
|
||||||
self.backend = "flashinfer-trtllm"
|
if has_flashinfer():
|
||||||
logger.info_once("Using flashinfer-trtllm for FP4")
|
self.backend = "flashinfer-cutlass"
|
||||||
|
elif cutlass_fp4_supported():
|
||||||
|
self.backend = "cutlass"
|
||||||
elif envs.VLLM_USE_FBGEMM:
|
elif envs.VLLM_USE_FBGEMM:
|
||||||
self.backend = "fbgemm"
|
self.backend = "fbgemm"
|
||||||
try:
|
try:
|
||||||
@ -42,12 +47,17 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
|
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
|
||||||
"Please install with: pip install fbgemm-gpu-genai"
|
"Please install with: pip install fbgemm-gpu-genai"
|
||||||
) from exc
|
) from exc
|
||||||
logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
|
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
|
||||||
elif has_flashinfer():
|
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
|
||||||
self.backend = "flashinfer-cutlass"
|
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
|
||||||
logger.info_once("Using flashinfer-cutlass for FP4")
|
|
||||||
else:
|
if self.backend == "none":
|
||||||
self.backend = "cutlass"
|
raise ValueError(
|
||||||
|
"No valid NVFP4 GEMM backend found. "
|
||||||
|
"Please check your platform capability."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
|
||||||
self.group_size = 16
|
self.group_size = 16
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -184,10 +194,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
layer.alpha,
|
layer.alpha,
|
||||||
output_dtype,
|
output_dtype,
|
||||||
)
|
)
|
||||||
if self.backend == "flashinfer-trtllm":
|
if self.backend.startswith("flashinfer-"):
|
||||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
backend_name = self.backend[len("flashinfer-") :]
|
||||||
elif self.backend == "flashinfer-cutlass":
|
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
||||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
|
||||||
elif self.backend == "fbgemm":
|
elif self.backend == "fbgemm":
|
||||||
out = torch.ops.fbgemm.f4f4bf16(
|
out = torch.ops.fbgemm.f4f4bf16(
|
||||||
x_fp4,
|
x_fp4,
|
||||||
@ -198,6 +207,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
use_mx=False,
|
use_mx=False,
|
||||||
).to(output_dtype)
|
).to(output_dtype)
|
||||||
else:
|
else:
|
||||||
|
assert self.backend == "cutlass"
|
||||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
|
|||||||
@ -926,22 +926,26 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|||||||
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
self.backend = "none"
|
||||||
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
|
||||||
self.backend = "flashinfer-trtllm"
|
if has_flashinfer():
|
||||||
elif has_flashinfer():
|
self.backend = "flashinfer-cutlass"
|
||||||
self.backend = "flashinfer-cutlass"
|
elif cutlass_fp4_supported():
|
||||||
elif cutlass_fp4_supported():
|
self.backend = "cutlass"
|
||||||
self.backend = "cutlass"
|
elif is_fp4_marlin_supported():
|
||||||
elif is_fp4_marlin_supported():
|
self.backend = "marlin"
|
||||||
self.backend = "marlin"
|
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
|
||||||
else:
|
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
|
||||||
|
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
|
||||||
|
|
||||||
|
if self.backend == "none":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Current platform does not support NVFP4"
|
"No valid NVFP4 GEMM backend found. "
|
||||||
" quantization. Please use Blackwell and"
|
"Please check your platform capability."
|
||||||
" above."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -1109,11 +1113,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|||||||
layer.alpha,
|
layer.alpha,
|
||||||
output_dtype,
|
output_dtype,
|
||||||
)
|
)
|
||||||
if self.backend == "flashinfer-trtllm":
|
if self.backend.startswith("flashinfer-"):
|
||||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
backend_name = self.backend[len("flashinfer-") :]
|
||||||
elif self.backend == "flashinfer-cutlass":
|
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
||||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
|
||||||
else:
|
else:
|
||||||
|
assert self.backend == "cutlass"
|
||||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user