diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
index 5575ee8e4197..6d69852bb4e4 100644
--- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu
+++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
@@ -31,6 +31,13 @@ namespace vllm {
 
+template <typename Int>
+__host__ __device__ inline Int round_up(Int x, Int y) {
+  static_assert(std::is_integral_v<Int>,
+                "round_up argument must be integral type");
+  return (x + y - 1) / y * y;
+}
+
 // Use UE4M3 by default.
 template <class Type, bool UE8M0_SF = false>
 __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
@@ -42,10 +49,21 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                 "Vec size is not matched.");
 
+  int sf_m = round_up(numRows, 128);
+  int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
+  int sf_n_int = round_up(sf_n_unpadded, 4) / 4;
+  for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
+    // Each thread writes 4 uint32_t elements.
+    for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
+         col += blockDim.x * 4) {
+      SFout[row * sf_n_int + col] = 0x00;
+    }
+  }
+
   // Get the global scaling factor, which will be applied to the SF.
   // Note SFScale is the same as next GEMM's alpha, which is
   // (448.f / (Alpha_A / 6.f)).
-  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
+  float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];
 
   // Input tensor row/col loops.
   for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
@@ -64,7 +82,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
               rowIdx, colIdx, numCols, SFout);
 
       out_pos =
-          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
     }
   }
 }
diff --git a/tests/kernels/quantization/test_nvfp4_quant.py b/tests/kernels/quantization/test_nvfp4_quant.py
index e9b091d06697..12f1008ecf27 100644
--- a/tests/kernels/quantization/test_nvfp4_quant.py
+++ b/tests/kernels/quantization/test_nvfp4_quant.py
@@ -168,9 +168,7 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
     out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
 
     out, out_scale = ops.scaled_fp4_quant(x, global_scale)
-
     scale_ans = recover_swizzled_scales(out_scale, m, n)
     out_ans = cast_from_fp4(out, m, n)
-
     torch.testing.assert_close(out_ans, out_ref)
     torch.testing.assert_close(scale_ans, scale_ref)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index cfcf534c613f..de68b3418244 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1385,7 +1385,7 @@ def scaled_fp4_quant(
     rounded_m = round_up(m, 128)
     scale_n = n // block_size
     rounded_n = round_up(scale_n, 4)
-    output_scale = torch.zeros(
+    output_scale = torch.empty(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
     )
 
diff --git a/vllm/envs.py b/vllm/envs.py
index eb50ea6e5dbe..59a6bef58c9c 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -155,7 +155,7 @@ if TYPE_CHECKING:
     VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
    VLLM_USE_FLASHINFER_MOE_FP4: bool = False
-    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput"
+    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency"
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -1218,7 +1218,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # - "latency":
     #   Uses TensorRT-LLM kernels optimized for low-latency inference.
"VLLM_FLASHINFER_MOE_BACKEND": env_with_choices( - "VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"] + "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"] ), # Control the maximum number of tokens per expert supported by the # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for @@ -1325,7 +1325,7 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_NVFP4_GEMM_BACKEND": env_with_choices( "VLLM_NVFP4_GEMM_BACKEND", None, - ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"], + ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", "cutlass"], ), # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 4127cd2d574b..b603bdb13280 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -50,6 +50,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"): self.backend = envs.VLLM_NVFP4_GEMM_BACKEND assert has_flashinfer(), f"FlashInfer is required for {self.backend}" + elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass": + self.backend = "cutlass" + assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}" if self.backend == "none": raise ValueError( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 03eca199d536..ce40645782e5 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -138,6 +138,13 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend: logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100") return Fp8MoeBackend.FLASHINFER_TRTLLM else: + if block_quant: + raise ValueError( + "FlashInfer FP8 MoE throughput backend does not " + "support block quantization. Please use " + "VLLM_FLASHINFER_MOE_BACKEND=latency " + "instead." 
+            )
         logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
         return Fp8MoeBackend.FLASHINFER_CUTLASS
 
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 0946cc171fa7..e14753c60c48 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -221,7 +221,10 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
+        from vllm.attention.layer import (  # Avoid circular import
+            Attention,
+            MLAAttention,
+        )
 
         if isinstance(layer, LinearBase):
             if self.is_layer_excluded(prefix):
@@ -230,7 +233,7 @@
             if "vision_tower" in prefix or "vision_model" in prefix:
                 return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
-        elif isinstance(layer, Attention):
+        elif isinstance(layer, (Attention, MLAAttention)):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
             return ModelOptFp8MoEMethod(self, layer)
@@ -888,7 +891,10 @@ class ModelOptNvFp4Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
+        from vllm.attention.layer import (  # Avoid circular import
+            Attention,
+            MLAAttention,
+        )
 
         skip_layer = self.is_layer_excluded(prefix)
         if isinstance(layer, LinearBase):
@@ -898,7 +904,7 @@
             if "vision_tower" in prefix or "vision_model" in prefix:
                 return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
-        elif isinstance(layer, Attention):
+        elif isinstance(layer, (Attention, MLAAttention)):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
             if skip_layer:
@@ -941,6 +947,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
             assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
+            self.backend = "cutlass"
+            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
 
         if self.backend == "none":
             raise ValueError(
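
Note (not part of the diff above): a minimal Python sketch of the scale-factor shape arithmetic that `scaled_fp4_quant` and the kernel's new zero-padding loop share. The example sizes are illustrative, and `round_up` here simply mirrors the helpers shown in the diff.

```python
def round_up(x: int, y: int) -> int:
    # Integer round-up, mirroring the CUDA helper and vllm/_custom_ops.py.
    return (x + y - 1) // y * y

m, n = 150, 1024                  # illustrative input shape; m is not a multiple of 128
block_size = 16                   # CVT_FP4_SF_VEC_SIZE: one FP8 scale per 16 FP4 values

rounded_m = round_up(m, 128)      # swizzled scale rows are padded to 128 -> 256
scale_n = n // block_size         # valid FP8 scale bytes per row -> 64
rounded_n = round_up(scale_n, 4)  # padded to a multiple of 4 bytes -> 64

# Shape that scaled_fp4_quant now allocates with torch.empty (int32 elements);
# the kernel is responsible for filling everything outside the valid region.
print((rounded_m, rounded_n // 4))  # (256, 16)
```

Because the kernel now writes the padding itself, the host side no longer pays for a `torch.zeros` memset on every call, which appears to be the motivation for relaxing the allocation to `torch.empty`.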