[Bugfix] Use latency MOE backend as default for Flashinfer and other misc fixes (#27439)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety 2025-11-07 04:18:39 -08:00 committed by GitHub
parent e0919f331d
commit 72b1c2ae2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 47 additions and 12 deletions

View File

@ -31,6 +31,13 @@
namespace vllm {
template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
static_assert(std::is_integral_v<Int>,
"round_up argument must be integral type");
return (x + y - 1) / y * y;
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
@ -42,10 +49,21 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
int sf_m = round_up<int>(numRows, 128);
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
// Each thread writes 4 uint32_t elements.
for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
col += blockDim.x * 4) {
SFout[row * sf_n_int + col] = 0x00;
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
@ -64,7 +82,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
rowIdx, colIdx, numCols, SFout);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
}
}
}

View File

@ -168,9 +168,7 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
out, out_scale = ops.scaled_fp4_quant(x, global_scale)
scale_ans = recover_swizzled_scales(out_scale, m, n)
out_ans = cast_from_fp4(out, m, n)
torch.testing.assert_close(out_ans, out_ref)
torch.testing.assert_close(scale_ans, scale_ref)

View File

@ -1385,7 +1385,7 @@ def scaled_fp4_quant(
rounded_m = round_up(m, 128)
scale_n = n // block_size
rounded_n = round_up(scale_n, 4)
output_scale = torch.zeros(
output_scale = torch.empty(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)

View File

@ -155,7 +155,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput"
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency"
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@ -1218,7 +1218,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "latency":
# Uses TensorRT-LLM kernels optimized for low-latency inference.
"VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
"VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"]
"VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
),
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
@ -1325,7 +1325,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
"VLLM_NVFP4_GEMM_BACKEND",
None,
["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", "cutlass"],
),
# Controls garbage collection during CUDA graph capture.
# If set to 0 (default), enables GC freezing to speed up capture time.

View File

@ -50,6 +50,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
self.backend = "cutlass"
assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
if self.backend == "none":
raise ValueError(

View File

@ -138,6 +138,13 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant:
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead."
)
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
return Fp8MoeBackend.FLASHINFER_CUTLASS

View File

@ -221,7 +221,10 @@ class ModelOptFp8Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
from vllm.attention.layer import ( # Avoid circular import
Attention,
MLAAttention,
)
if isinstance(layer, LinearBase):
if self.is_layer_excluded(prefix):
@ -230,7 +233,7 @@ class ModelOptFp8Config(QuantizationConfig):
if "vision_tower" in prefix or "vision_model" in prefix:
return UnquantizedLinearMethod()
return ModelOptFp8LinearMethod(self)
elif isinstance(layer, Attention):
elif isinstance(layer, (Attention, MLAAttention)):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self, layer)
@ -888,7 +891,10 @@ class ModelOptNvFp4Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
from vllm.attention.layer import ( # Avoid circular import
Attention,
MLAAttention,
)
skip_layer = self.is_layer_excluded(prefix)
if isinstance(layer, LinearBase):
@ -898,7 +904,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
if "vision_tower" in prefix or "vision_model" in prefix:
return UnquantizedLinearMethod()
return ModelOptNvFp4LinearMethod(self)
elif isinstance(layer, Attention):
elif isinstance(layer, (Attention, MLAAttention)):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
if skip_layer:
@ -941,6 +947,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
self.backend = "cutlass"
assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
if self.backend == "none":
raise ValueError(