diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 740be2bc8770..942a8d3f9bfd 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -669,6 +669,7 @@ steps: - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py # Fusion - pytest -v -s tests/compile/test_fusion_all_reduce.py diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py new file mode 100644 index 000000000000..131086a5f703 --- /dev/null +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, + convert_swizzled_to_linear, dequantize_nvfp4_to_dtype) + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) + +DTYPES = [torch.float16, torch.bfloat16] +# m, n, k +SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] +PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] +SHAPES.extend(PAD_SHAPES) + +SEEDS = [42] +CUDA_DEVICES = ["cuda:0"] + + +def get_ref_results( + a_fp4, + b_fp4, + a_sf, + b_sf, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, +): + _, m_k = a_fp4.shape + _, n_k = b_fp4.shape + assert m_k == n_k + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_sf, + a_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, + b_sf, + b_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + return torch.matmul(a_in_dtype, b_in_dtype.t()) + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("backend", ["cutlass", "trtllm"]) +@pytest.mark.parametrize("autotune", [False, True]) +@torch.inference_mode() +def test_flashinfer_nvfp4_gemm( + dtype: torch.dtype, + shape: tuple[int, int, int], + seed: int, + device: str, + backend: str, + autotune: bool, +) -> None: + if backend == "trtllm" and dtype == torch.float16: + pytest.skip( + "Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations") + + current_platform.seed_everything(seed) + m, n, packed_k = shape + k = packed_k * 2 + block_size = 16 + a_dtype = torch.randn((m, k), dtype=dtype, device=device) + b_dtype = torch.randn((n, k), dtype=dtype, device=device) + + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32) + b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) + alpha = 1.0 / (a_global_scale * b_global_scale) + # ops.scaled_fp4_quant returns swizzled scales, while weights + # from checkpoints are in linear scales. 
+ # So instead of needing to swizzle for cutlass as in modelopt.py, + # we need to unswizzle for trtllm here. + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) + b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) + + # get_ref_results unswizzles the scales internally. + expected_out = get_ref_results( + a_fp4, + b_fp4, + a_scale_interleaved, + b_scale_interleaved, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, + ) + + import flashinfer + + if backend == "trtllm": + epilogue_tile_m = 128 + b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), + epilogue_tile_m) + + b_scale_interleaved = convert_swizzled_to_linear( + b_scale_interleaved, n, k, block_size) + b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a( + b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape( + b_scale_interleaved.shape).view(torch.float8_e4m3fn)) + + with flashinfer.autotune(autotune): + out = flashinfer_scaled_fp4_mm( + a_fp4, + b_fp4, + a_scale_interleaved, + b_scale_interleaved, + alpha, + dtype, + backend=backend, + ) + + torch.testing.assert_close(out, + expected_out.to(dtype=dtype), + atol=1e-1, + rtol=1e-1) diff --git a/tests/kernels/quantization/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py index 0b45c2298175..67e041f2b71c 100644 --- a/tests/kernels/quantization/test_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_nvfp4_scaled_mm.py @@ -65,9 +65,12 @@ def test_nvfp4_gemm( b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) alpha = 1. / (a_global_scale * b_global_scale) + # ops.scaled_fp4_quant returns swizzled scales, while weights + # from checkpoints are in linear scales. a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) + # get_ref_results unswizzles the scales internally. expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, a_global_scale, b_global_scale, m, n, dtype, block_size, diff --git a/vllm/envs.py b/vllm/envs.py index 110bb542b120..2f0bafa01cc2 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1101,6 +1101,12 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_TRTLLM_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), + # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer. + # Otherwise, uses the first available of: flashinfer cutlass GEMM, + # vllm cutlass GEMM, marlin GEMM. + "VLLM_USE_TRTLLM_FP4_GEMM": + lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))), + # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. # If set to 1, allows GC to run during capture. 
@@ -1208,6 +1214,7 @@ def compute_hash() -> str: "VLLM_DP_SIZE", "VLLM_USE_STANDALONE_COMPILE", "VLLM_FUSED_MOE_CHUNK_SIZE", + "VLLM_USE_TRTLLM_FP4_GEMM", ] for key in environment_variables_to_hash: if key in environment_variables: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 8ba72162921a..63bfe565b121 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import from vllm.model_executor.parameter import (GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) +from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer logger = init_logger(__name__) @@ -24,6 +25,13 @@ __all__ = ["CompressedTensorsW4A4Fp4"] class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): def __init__(self): + if envs.VLLM_USE_TRTLLM_FP4_GEMM: + assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer" + self.backend = "flashinfer-trtllm" + elif has_flashinfer(): + self.backend = "flashinfer-cutlass" + else: + self.backend = "cutlass" self.group_size = 16 @classmethod @@ -108,16 +116,36 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): layer.weight_global_scale.max().to(torch.float32), requires_grad=False) - swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) + if self.backend == "flashinfer-trtllm": + # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. + # FlashInfer provides nvfp4_quantize to quantize + shuffle the + # layout but we use our own quantization so we have to call + # shuffles ourselves. 
+ from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a - # required by cutlass kernel; need Parameter, not ModelWeightParameter - layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) + weight = layer.weight_packed.data + weight_scale = layer.weight_scale.data - layer.alpha = Parameter(layer.input_global_scale * - layer.weight_global_scale, - requires_grad=False) + epilogue_tile_m = 128 + weight = shuffle_matrix_a(weight.view(torch.uint8), + epilogue_tile_m) + weight_scale = (shuffle_matrix_sf_a(weight_scale.view( + torch.uint8), epilogue_tile_m).reshape( + weight_scale.shape).view(torch.float8_e4m3fn)) + + layer.weight_scale_swizzled = Parameter(weight_scale, + requires_grad=False) + layer.weight_packed = Parameter(weight, requires_grad=False) + else: + swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) + layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, + requires_grad=False) + layer.weight_packed = Parameter(layer.weight_packed.data, + requires_grad=False) + + layer.alpha = Parameter( + 1 / (layer.input_global_scale * layer.weight_global_scale), + requires_grad=False) def apply_weights(self, layer: torch.nn.Module, @@ -128,7 +156,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): out = run_nvfp4_emulations( x=x, input_global_scale=layer.input_global_scale, - weight=layer.weight, + weight=layer.weight_packed, weight_scale_swizzled=layer.weight_scale_swizzled, weight_global_scale=layer.weight_global_scale) if bias is not None: @@ -136,14 +164,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): return out output_dtype = x.dtype - output_shape = [x.shape[0], layer.weight.shape[0]] + output_shape = [x.shape[0], layer.weight_packed.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) - out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, - layer.weight_scale_swizzled, - 1 / layer.alpha, output_dtype) + mm_args = (x_fp4, layer.weight_packed, x_blockscale, + layer.weight_scale_swizzled, layer.alpha, output_dtype) + if self.backend == "flashinfer-trtllm": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") + elif self.backend == "flashinfer-cutlass": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") + else: + out = cutlass_scaled_fp4_mm(*mm_args) + if bias is not None: out = out + bias return out.view(*output_shape) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 8868c623796a..8f9ca73bc505 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -38,7 +38,8 @@ from vllm.model_executor.parameter import (ModelWeightParameter, PerTensorScaleParameter) from vllm.scalar_type import scalar_types from vllm.utils import next_power_of_2 -from vllm.utils.flashinfer import has_flashinfer_moe +from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer, + has_flashinfer_moe) logger = init_logger(__name__) @@ -724,16 +725,20 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config) -> None: self.quant_config = quant_config - self.cutlass_nvfp4_supported = cutlass_fp4_supported() - self.use_marlin = False - if not self.cutlass_nvfp4_supported: - if is_fp4_marlin_supported(): - self.use_marlin = True - else: - raise ValueError("Current platform does not support NVFP4" - " quantization. 
Please use Blackwell and" - " above.") + if envs.VLLM_USE_TRTLLM_FP4_GEMM: + assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer" + self.backend = "flashinfer-trtllm" + elif has_flashinfer(): + self.backend = "flashinfer-cutlass" + elif cutlass_fp4_supported(): + self.backend = "cutlass" + elif is_fp4_marlin_supported(): + self.backend = "marlin" + else: + raise ValueError("Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above.") def create_weights( self, @@ -815,17 +820,38 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): # block_size = 16; assert (layer.weight_scale.dtype == torch.float8_e4m3fn), ( "Weight Block scale must be represented as FP8-E4M3") - swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) - layer.weight = Parameter(layer.weight.data, requires_grad=False) + if self.backend == "flashinfer-trtllm": + # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. + # FlashInfer provides nvfp4_quantize to quantize + shuffle the + # layout but we use our own quantization so we have to call + # shuffles ourselves. + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a - if self.use_marlin: - prepare_fp4_layer_for_marlin(layer) - del layer.alpha - del layer.input_scale - del layer.weight_scale_swizzled + weight = layer.weight.data + weight_scale = layer.weight_scale.data + + epilogue_tile_m = 128 + weight = shuffle_matrix_a(weight.view(torch.uint8), + epilogue_tile_m) + weight_scale = (shuffle_matrix_sf_a(weight_scale.view( + torch.uint8), epilogue_tile_m).reshape( + weight_scale.shape).view(torch.float8_e4m3fn)) + + layer.weight_scale_swizzled = Parameter(weight_scale, + requires_grad=False) + layer.weight = Parameter(weight, requires_grad=False) + else: + swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) + layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, + requires_grad=False) + layer.weight = Parameter(layer.weight.data, requires_grad=False) + + if self.backend == "marlin": + prepare_fp4_layer_for_marlin(layer) + del layer.alpha + del layer.input_scale + del layer.weight_scale_swizzled def apply( self, @@ -833,7 +859,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if self.use_marlin: + if self.backend == "marlin": return apply_fp4_marlin_linear( input=x, weight=layer.weight, @@ -859,9 +885,21 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn) assert (layer.alpha.dtype == torch.float32) - out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, - layer.weight_scale_swizzled, layer.alpha, - output_dtype) + mm_args = ( + x_fp4, + layer.weight, + x_blockscale, + layer.weight_scale_swizzled, + layer.alpha, + output_dtype, + ) + if self.backend == "flashinfer-trtllm": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") + elif self.backend == "flashinfer-cutlass": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") + else: + out = cutlass_scaled_fp4_mm(*mm_args) + if bias is not None: out = out + bias return out.view(*output_shape) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 10f2dc0252a1..761172e4d361 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -5,16 +5,53 @@ Warmup kernels used during model 
execution. This is useful specifically for JIT'ed kernels as we don't want JIT'ing to happen during model execution. """ +from typing import TYPE_CHECKING + import torch import vllm.envs as envs from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup +from vllm.platforms import current_platform from vllm.utils.deep_gemm import is_deep_gemm_supported +from vllm.utils.flashinfer import has_flashinfer + +if TYPE_CHECKING: + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + from vllm.v1.worker.gpu_worker import Worker -def kernel_warmup(model: torch.nn.Module, max_tokens: int): +def kernel_warmup(worker: "Worker"): + # Deep GEMM warmup do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM and is_deep_gemm_supported() and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP) if do_deep_gemm_warmup: + model = worker.get_model() + max_tokens = worker.scheduler_config.max_num_batched_tokens deep_gemm_warmup(model, max_tokens) + + # FlashInfer autotune for Blackwell (SM 10.0) GPUs + if has_flashinfer() and current_platform.is_device_capability(100): + flashinfer_autotune(worker.model_runner) + + +def flashinfer_autotune(runner: "GPUModelRunner") -> None: + """ + Autotune FlashInfer operations. + FlashInfer have many implementations for the same operation, + autotuning runs benchmarks for each implementation and stores + the results. The results are cached transparently and + future calls to FlashInfer will use the best implementation. + Without autotuning, FlashInfer will rely on heuristics, which may + be significantly slower. + """ + from vllm.utils.flashinfer import autotune + + with torch.inference_mode(), autotune(): + # We skip EPLB here since we don't want to record dummy metrics + # When autotuning with number of tokens m, flashinfer will autotune + # operations for all number of tokens up to m. + # So we only need to run with the max number of tokens. 
+ runner._dummy_run(runner.scheduler_config.max_num_batched_tokens, + skip_eplb=True, + is_profile=True) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 6b23ed426806..0d7d4b694f07 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -14,6 +14,7 @@ import os from typing import Any, Callable, NoReturn, Optional import requests +import torch import vllm.envs as envs from vllm.logger import init_logger @@ -193,6 +194,75 @@ def use_trtllm_attention( return use_trtllm +if has_flashinfer(): + + @torch.library.custom_op( + "vllm::flashinfer_mm_fp4", + mutates_args=[], + device_types="cuda", + ) + def flashinfer_mm_fp4( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + g_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + from flashinfer import mm_fp4 as flashinfer_mm_fp4_ + return flashinfer_mm_fp4_(A, + B, + A_scale, + B_scale, + g_scale, + dtype, + block_size=16, + backend=backend) + + @torch.library.register_fake("vllm::flashinfer_mm_fp4", ) + def flashinfer_mm_fp4_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + g_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + return torch.empty(A.shape[0], + B.shape[1], + dtype=dtype, + device=A.device) + + +def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, alpha: torch.Tensor, + out_dtype: torch.dtype, + backend: str) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 + assert a.stride(-1) == 1 and b.stride(-1) == 1 + assert a.shape[1] == b.shape[1] + assert block_scale_a.shape[1] == a.shape[1] // 8 + assert block_scale_b.shape[1] == b.shape[1] // 8 + + if backend == "cutlass": + block_scale_a = block_scale_a.view(torch.uint8) + block_scale_b = block_scale_b.view(torch.uint8) + + return flashinfer_mm_fp4( + a, + b.t(), + block_scale_a, + block_scale_b.t(), + alpha, + out_dtype, + backend=backend, + ) + + __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", @@ -205,4 +275,5 @@ __all__ = [ "has_flashinfer_cutlass_fused_moe", "has_nvidia_artifactory", "use_trtllm_attention", + "flashinfer_scaled_fp4_mm", ] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0ea23921a080..84f065f25f2e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -310,6 +310,7 @@ class Worker(WorkerBase): for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size, skip_eplb=True) + if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -340,8 +341,7 @@ class Worker(WorkerBase): hidden_states=last_hidden_states) # Warmup kernels used during model execution - kernel_warmup(self.get_model(), - max_tokens=self.scheduler_config.max_num_batched_tokens) + kernel_warmup(self) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling.
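
A minimal usage sketch of the new FlashInfer NVFP4 GEMM path with the "trtllm" backend. It mirrors tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py added above and assumes FlashInfer is installed on a Blackwell (SM 10.0) GPU; the nvfp4_utils import path is illustrative (it refers to the test helper module), and the shapes and epilogue_tile_m value are taken from the test rather than from any additional API guarantee.

# Sketch only: mirrors test_flashinfer_nvfp4_scaled_mm.py; assumes FlashInfer
# is installed and the device is Blackwell (SM 10.0). The nvfp4_utils import
# points at the test helpers and is illustrative.
import torch
import flashinfer

from tests.kernels.quantization.nvfp4_utils import (
    FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, convert_swizzled_to_linear)
from vllm import _custom_ops as ops
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm

m, n, k, block_size = 128, 128, 128, 16
a = torch.randn((m, k), dtype=torch.bfloat16, device="cuda")  # activations
b = torch.randn((n, k), dtype=torch.bfloat16, device="cuda")  # weight

# Per-tensor global scales and the combined alpha, as in the test.
a_gs = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
        torch.amax(a.flatten(), dim=-1)).to(torch.float32)
b_gs = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
        torch.amax(b.flatten(), dim=-1)).to(torch.float32)
alpha = 1.0 / (a_gs * b_gs)

# ops.scaled_fp4_quant returns packed FP4 data plus swizzled block scales.
a_fp4, a_scale = ops.scaled_fp4_quant(a, a_gs)
b_fp4, b_scale = ops.scaled_fp4_quant(b, b_gs)

# The trtllm backend expects the weight and its block scales in FlashInfer's
# shuffled layout: unswizzle the scales first, then apply the shuffles.
epilogue_tile_m = 128
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m)
b_scale = convert_swizzled_to_linear(b_scale, n, k, block_size)
b_scale = (flashinfer.shuffle_matrix_sf_a(b_scale.view(torch.uint8),
                                          epilogue_tile_m)
           .reshape(b_scale.shape).view(torch.float8_e4m3fn))

# Autotuning is optional; without it FlashInfer falls back to heuristics.
with flashinfer.autotune(True):
    out = flashinfer_scaled_fp4_mm(a_fp4, b_fp4, a_scale, b_scale, alpha,
                                   torch.bfloat16, backend="trtllm")

At the model level, the same path is selected by setting VLLM_USE_TRTLLM_FP4_GEMM=1 (backend "flashinfer-trtllm"); otherwise the linear methods fall back to the first available of FlashInfer cutlass, vLLM cutlass, or Marlin, as described in the envs.py comment above.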