[Kernel] Add nvfp4 gemm flashinfer backends (#22346)

Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
nvjullin 2025-08-15 04:03:55 +08:00 committed by GitHub
parent b8ff05361a
commit 279a5f31b3
9 changed files with 369 additions and 39 deletions

View File

@@ -669,6 +669,7 @@ steps:
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
# Fusion
- pytest -v -s tests/compile/test_fusion_all_reduce.py

View File

@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
convert_swizzled_to_linear, dequantize_nvfp4_to_dtype)
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm
if not current_platform.has_device_capability(100):
pytest.skip(
reason="NVFP4 requires compute capability 10.0 or above.",
allow_module_level=True,
)
DTYPES = [torch.float16, torch.bfloat16]
# m, n, packed_k (k = packed_k * 2 in the test)
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)
SEEDS = [42]
CUDA_DEVICES = ["cuda:0"]
def get_ref_results(
a_fp4,
b_fp4,
a_sf,
b_sf,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
device,
):
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert m_k == n_k
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
b_sf,
b_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
return torch.matmul(a_in_dtype, b_in_dtype.t())
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("backend", ["cutlass", "trtllm"])
@pytest.mark.parametrize("autotune", [False, True])
@torch.inference_mode()
def test_flashinfer_nvfp4_gemm(
dtype: torch.dtype,
shape: tuple[int, int, int],
seed: int,
device: str,
backend: str,
autotune: bool,
) -> None:
if backend == "trtllm" and dtype == torch.float16:
pytest.skip(
"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
current_platform.seed_everything(seed)
m, n, packed_k = shape
k = packed_k * 2
block_size = 16
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)
# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
# So instead of needing to swizzle for cutlass as in modelopt.py,
# we need to unswizzle for trtllm here.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
# get_ref_results unswizzles the scales internally.
expected_out = get_ref_results(
a_fp4,
b_fp4,
a_scale_interleaved,
b_scale_interleaved,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
device,
)
import flashinfer
if backend == "trtllm":
epilogue_tile_m = 128
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8),
epilogue_tile_m)
b_scale_interleaved = convert_swizzled_to_linear(
b_scale_interleaved, n, k, block_size)
b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a(
b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape(
b_scale_interleaved.shape).view(torch.float8_e4m3fn))
with flashinfer.autotune(autotune):
out = flashinfer_scaled_fp4_mm(
a_fp4,
b_fp4,
a_scale_interleaved,
b_scale_interleaved,
alpha,
dtype,
backend=backend,
)
torch.testing.assert_close(out,
expected_out.to(dtype=dtype),
atol=1e-1,
rtol=1e-1)
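
For reference, the shapes in SHAPES above carry the packed K dimension (two FP4 values per byte, hence k = packed_k * 2 in the test body), and the global scales follow the usual two-level NVFP4 scheme: a per-tensor global scale keeps the FP8 block scales in range, and alpha undoes both global scales in the GEMM epilogue. Below is a minimal sketch of just that scale arithmetic, assuming FLOAT4_E2M1_MAX = 6.0 and FLOAT8_E4M3_MAX = 448.0 (the values nvfp4_utils is expected to provide; they are not part of this diff).

import torch

FLOAT4_E2M1_MAX = 6.0    # assumed: largest magnitude representable in FP4 (E2M1)
FLOAT8_E4M3_MAX = 448.0  # assumed: largest finite float8_e4m3fn value

a = torch.randn(128, 64, dtype=torch.bfloat16)
b = torch.randn(256, 64, dtype=torch.bfloat16)

# Per-tensor global scales, mirroring the test above: sized so the product of
# an FP8 block scale and an FP4 value can reach the tensor's maximum value.
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                  torch.amax(a.flatten(), dim=-1)).to(torch.float32)
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                  torch.amax(b.flatten(), dim=-1)).to(torch.float32)

# alpha folds the inverse of both global scales into the GEMM epilogue, so the
# FP4 x FP4 matmul result comes back in the original activation range.
alpha = 1.0 / (a_global_scale * b_global_scale)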

View File

@@ -65,9 +65,12 @@ def test_nvfp4_gemm(
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
alpha = 1. / (a_global_scale * b_global_scale)
# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
# get_ref_results unswizzles the scales internally.
expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
b_scale_interleaved, a_global_scale,
b_global_scale, m, n, dtype, block_size,

View File

@@ -1101,6 +1101,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
# If set to 1, forces the use of the TRTLLM FP4 GEMM backend in flashinfer.
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
# vllm cutlass GEMM, marlin GEMM.
"VLLM_USE_TRTLLM_FP4_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))),
# Controls garbage collection during CUDA graph capture.
# If set to 0 (default), enables GC freezing to speed up capture time.
# If set to 1, allows GC to run during capture.
@@ -1208,6 +1214,7 @@ def compute_hash() -> str:
"VLLM_DP_SIZE",
"VLLM_USE_STANDALONE_COMPILE",
"VLLM_FUSED_MOE_CHUNK_SIZE",
"VLLM_USE_TRTLLM_FP4_GEMM",
]
for key in environment_variables_to_hash:
if key in environment_variables:
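
A minimal usage sketch for the new flag added above (the checkpoint path is a placeholder; the flag only takes effect when FlashInfer is installed and the model carries NVFP4 weights, e.g. via compressed-tensors W4A4 or ModelOpt NVFP4):

import os

# Assumption: the variable must be set before vLLM reads its environment.
os.environ["VLLM_USE_TRTLLM_FP4_GEMM"] = "1"

from vllm import LLM  # noqa: E402

# Placeholder checkpoint path; any NVFP4-quantized model applies here.
llm = LLM(model="path/to/your-nvfp4-checkpoint")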

View File

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
logger = init_logger(__name__)
@@ -24,6 +25,13 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
def __init__(self):
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
self.backend = "flashinfer-trtllm"
elif has_flashinfer():
self.backend = "flashinfer-cutlass"
else:
self.backend = "cutlass"
self.group_size = 16
@classmethod
@@ -108,16 +116,36 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
layer.weight_global_scale.max().to(torch.float32),
requires_grad=False)
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
if self.backend == "flashinfer-trtllm":
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
# layout in one step, but we use our own quantization, so we have
# to apply the shuffles ourselves.
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
# required by cutlass kernel; need Parameter, not ModelWeightParameter
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
weight = layer.weight_packed.data
weight_scale = layer.weight_scale.data
layer.alpha = Parameter(layer.input_global_scale *
layer.weight_global_scale,
requires_grad=False)
epilogue_tile_m = 128
weight = shuffle_matrix_a(weight.view(torch.uint8),
epilogue_tile_m)
weight_scale = (shuffle_matrix_sf_a(weight_scale.view(
torch.uint8), epilogue_tile_m).reshape(
weight_scale.shape).view(torch.float8_e4m3fn))
layer.weight_scale_swizzled = Parameter(weight_scale,
requires_grad=False)
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
layer.weight_packed = Parameter(layer.weight_packed.data,
requires_grad=False)
layer.alpha = Parameter(
1 / (layer.input_global_scale * layer.weight_global_scale),
requires_grad=False)
def apply_weights(self,
layer: torch.nn.Module,
@@ -128,7 +156,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
out = run_nvfp4_emulations(
x=x,
input_global_scale=layer.input_global_scale,
weight=layer.weight,
weight=layer.weight_packed,
weight_scale_swizzled=layer.weight_scale_swizzled,
weight_global_scale=layer.weight_global_scale)
if bias is not None:
@@ -136,14 +164,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
return out
output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]
output_shape = [x.shape[0], layer.weight_packed.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
layer.weight_scale_swizzled,
1 / layer.alpha, output_dtype)
mm_args = (x_fp4, layer.weight_packed, x_blockscale,
layer.weight_scale_swizzled, layer.alpha, output_dtype)
if self.backend == "flashinfer-trtllm":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
elif self.backend == "flashinfer-cutlass":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
else:
out = cutlass_scaled_fp4_mm(*mm_args)
if bias is not None:
out = out + bias
return out.view(*output_shape)
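
One behavioral detail worth noting in this hunk: the layer now stores alpha = 1 / (input_global_scale * weight_global_scale) and passes it straight to whichever GEMM backend is selected, whereas the previous code stored the raw product and inverted it at the call site. A toy equivalence check of the two conventions:

import torch

input_global_scale = torch.tensor(200.0)
weight_global_scale = torch.tensor(150.0)

# Old convention: store the product, pass its reciprocal to the GEMM.
alpha_old = input_global_scale * weight_global_scale
gemm_scale_old = 1 / alpha_old

# New convention: store the reciprocal once and pass it through unchanged.
alpha_new = 1 / (input_global_scale * weight_global_scale)

assert torch.allclose(gemm_scale_old, alpha_new)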

View File

@@ -38,7 +38,8 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.scalar_type import scalar_types
from vllm.utils import next_power_of_2
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer,
has_flashinfer_moe)
logger = init_logger(__name__)
@@ -724,16 +725,20 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False
if not self.cutlass_nvfp4_supported:
if is_fp4_marlin_supported():
self.use_marlin = True
else:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above.")
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
self.backend = "flashinfer-trtllm"
elif has_flashinfer():
self.backend = "flashinfer-cutlass"
elif cutlass_fp4_supported():
self.backend = "cutlass"
elif is_fp4_marlin_supported():
self.backend = "marlin"
else:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above.")
def create_weights(
self,
@@ -815,17 +820,38 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
# block_size = 16;
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Block scale must be represented as FP8-E4M3")
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
if self.backend == "flashinfer-trtllm":
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
# layout in one step, but we use our own quantization, so we have
# to apply the shuffles ourselves.
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
if self.use_marlin:
prepare_fp4_layer_for_marlin(layer)
del layer.alpha
del layer.input_scale
del layer.weight_scale_swizzled
weight = layer.weight.data
weight_scale = layer.weight_scale.data
epilogue_tile_m = 128
weight = shuffle_matrix_a(weight.view(torch.uint8),
epilogue_tile_m)
weight_scale = (shuffle_matrix_sf_a(weight_scale.view(
torch.uint8), epilogue_tile_m).reshape(
weight_scale.shape).view(torch.float8_e4m3fn))
layer.weight_scale_swizzled = Parameter(weight_scale,
requires_grad=False)
layer.weight = Parameter(weight, requires_grad=False)
else:
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
if self.backend == "marlin":
prepare_fp4_layer_for_marlin(layer)
del layer.alpha
del layer.input_scale
del layer.weight_scale_swizzled
def apply(
self,
@@ -833,7 +859,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.use_marlin:
if self.backend == "marlin":
return apply_fp4_marlin_linear(
input=x,
weight=layer.weight,
@@ -859,9 +885,21 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
assert (layer.alpha.dtype == torch.float32)
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
layer.weight_scale_swizzled, layer.alpha,
output_dtype)
mm_args = (
x_fp4,
layer.weight,
x_blockscale,
layer.weight_scale_swizzled,
layer.alpha,
output_dtype,
)
if self.backend == "flashinfer-trtllm":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
elif self.backend == "flashinfer-cutlass":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
else:
out = cutlass_scaled_fp4_mm(*mm_args)
if bias is not None:
out = out + bias
return out.view(*output_shape)
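
The backend priority introduced here is: forced TRTLLM (via VLLM_USE_TRTLLM_FP4_GEMM), then FlashInfer CUTLASS, then vLLM CUTLASS, then Marlin. A condensed sketch of that priority, for illustration only (this helper is not part of the vLLM API):

def select_nvfp4_gemm_backend(force_trtllm: bool,
                              flashinfer_available: bool,
                              cutlass_fp4_supported: bool,
                              marlin_fp4_supported: bool) -> str:
    # Mirrors the selection logic in ModelOptNvFp4LinearMethod.__init__ above.
    if force_trtllm:  # VLLM_USE_TRTLLM_FP4_GEMM=1
        assert flashinfer_available, "TRTLLM FP4 GEMM requires FlashInfer"
        return "flashinfer-trtllm"
    if flashinfer_available:
        return "flashinfer-cutlass"
    if cutlass_fp4_supported:
        return "cutlass"
    if marlin_fp4_supported:
        return "marlin"
    raise ValueError("Current platform does not support NVFP4 quantization. "
                     "Please use Blackwell and above.")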

View File

@@ -5,16 +5,53 @@ Warmup kernels used during model execution.
This is useful specifically for JIT'ed kernels as we don't want JIT'ing to
happen during model execution.
"""
from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu_worker import Worker
def kernel_warmup(model: torch.nn.Module, max_tokens: int):
def kernel_warmup(worker: "Worker"):
# Deep GEMM warmup
do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM
and is_deep_gemm_supported()
and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP)
if do_deep_gemm_warmup:
model = worker.get_model()
max_tokens = worker.scheduler_config.max_num_batched_tokens
deep_gemm_warmup(model, max_tokens)
# FlashInfer autotune for Blackwell (SM 10.0) GPUs
if has_flashinfer() and current_platform.is_device_capability(100):
flashinfer_autotune(worker.model_runner)
def flashinfer_autotune(runner: "GPUModelRunner") -> None:
"""
Autotune FlashInfer operations.
FlashInfer has many implementations of the same operation; autotuning
runs benchmarks for each implementation and stores the results. The
results are cached transparently, and future FlashInfer calls use the
best implementation. Without autotuning, FlashInfer relies on
heuristics, which may be significantly slower.
"""
from vllm.utils.flashinfer import autotune
with torch.inference_mode(), autotune():
# We skip EPLB here since we don't want to record dummy metrics.
# When autotuning with m tokens, FlashInfer tunes its operations for
# every number of tokens up to m, so running once with the max number
# of tokens is sufficient.
runner._dummy_run(runner.scheduler_config.max_num_batched_tokens,
skip_eplb=True,
is_profile=True)

View File

@@ -14,6 +14,7 @@ import os
from typing import Any, Callable, NoReturn, Optional
import requests
import torch
import vllm.envs as envs
from vllm.logger import init_logger
@@ -193,6 +194,75 @@ def use_trtllm_attention(
return use_trtllm
if has_flashinfer():
@torch.library.custom_op(
"vllm::flashinfer_mm_fp4",
mutates_args=[],
device_types="cuda",
)
def flashinfer_mm_fp4(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
g_scale: torch.Tensor,
dtype: torch.dtype,
backend: str,
) -> torch.Tensor:
from flashinfer import mm_fp4 as flashinfer_mm_fp4_
return flashinfer_mm_fp4_(A,
B,
A_scale,
B_scale,
g_scale,
dtype,
block_size=16,
backend=backend)
@torch.library.register_fake("vllm::flashinfer_mm_fp4")
def flashinfer_mm_fp4_fake(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
g_scale: torch.Tensor,
dtype: torch.dtype,
backend: str,
) -> torch.Tensor:
return torch.empty(A.shape[0],
B.shape[1],
dtype=dtype,
device=A.device)
def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
block_scale_a: torch.Tensor,
block_scale_b: torch.Tensor, alpha: torch.Tensor,
out_dtype: torch.dtype,
backend: str) -> torch.Tensor:
assert a.ndim == 2 and b.ndim == 2
assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
assert a.stride(-1) == 1 and b.stride(-1) == 1
assert a.shape[1] == b.shape[1]
assert block_scale_a.shape[1] == a.shape[1] // 8
assert block_scale_b.shape[1] == b.shape[1] // 8
if backend == "cutlass":
block_scale_a = block_scale_a.view(torch.uint8)
block_scale_b = block_scale_b.view(torch.uint8)
return flashinfer_mm_fp4(
a,
b.t(),
block_scale_a,
block_scale_b.t(),
alpha,
out_dtype,
backend=backend,
)
__all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
@@ -205,4 +275,5 @@ __all__ = [
"has_flashinfer_cutlass_fused_moe",
"has_nvidia_artifactory",
"use_trtllm_attention",
"flashinfer_scaled_fp4_mm",
]
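
The register_fake entry above is what lets torch.compile and other tracing machinery propagate output shapes for vllm::flashinfer_mm_fp4 without running FlashInfer. A sketch of exercising the fake kernel directly (the shapes are illustrative, and it assumes FlashInfer is installed so the registration above actually runs when vllm.utils.flashinfer is imported):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import vllm.utils.flashinfer  # noqa: F401  (registers vllm::flashinfer_mm_fp4)

# Illustrative sizes: m=128, n=256, k=64 -> 32 packed uint8 columns, 4 block scales per row.
with FakeTensorMode():
    A = torch.empty(128, 32, dtype=torch.uint8, device="cuda")
    B = torch.empty(32, 256, dtype=torch.uint8, device="cuda")  # already transposed, as in the wrapper
    A_scale = torch.empty(128, 4, dtype=torch.float8_e4m3fn, device="cuda")
    B_scale = torch.empty(4, 256, dtype=torch.float8_e4m3fn, device="cuda")
    g_scale = torch.empty((), dtype=torch.float32, device="cuda")

    out = torch.ops.vllm.flashinfer_mm_fp4(
        A, B, A_scale, B_scale, g_scale, torch.bfloat16, "cutlass")
    # The fake implementation only looks at A.shape[0] and B.shape[1].
    assert out.shape == (128, 256)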

View File

@@ -310,6 +310,7 @@ class Worker(WorkerBase):
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size, skip_eplb=True)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
@@ -340,8 +341,7 @@ class Worker(WorkerBase):
hidden_states=last_hidden_states)
# Warmup kernels used during model execution
kernel_warmup(self.get_model(),
max_tokens=self.scheduler_config.max_num_batched_tokens)
kernel_warmup(self)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.