diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py
index 882b034e2f230..9fd72ee152b55 100644
--- a/tests/kernels/moe/test_mxfp4_moe.py
+++ b/tests/kernels/moe/test_mxfp4_moe.py
@@ -11,6 +11,7 @@
 import torch
 from packaging import version
 
 from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer
 
 QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
@@ -19,6 +20,10 @@ QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
 TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
 ) and current_platform.is_device_capability(100)
 
+HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
+                               and current_platform.is_device_capability(90)
+                               and has_flashinfer())
+
 if TRTLLM_GEN_MXFP4_AVAILABLE:
     from flashinfer import (fp4_quantize, mxfp8_quantize,
                             next_positive_power_of_2,
@@ -542,3 +547,317 @@ def test_trtllm_gen_mxfp4_fused_moe(
         transpose_optimized=transpose_optimized)
     # relatively loose check since the mxfp4 quantization is less accurate
     check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
+
+
+def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
+    """Interleave scales on the last dimension by groups of 4, matching
+    the transformation in mxfp4.py's BF16 (Hopper) path."""
+    s = scales.to(torch.uint8)
+    s_shape = s.shape
+    assert s_shape[-1] % 4 == 0
+    s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
+    # Move the 4-group dimension before the row dimension
+    permuted = s.permute(0, 2, 1, 3)
+    # Merge the row dim with the 4-group dim
+    return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)
+
+
+@pytest.mark.parametrize("topk", [1, 4])
+@pytest.mark.parametrize("num_experts", [32])
+@pytest.mark.parametrize("num_tokens", [1, 128])
+@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
+@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
+                                              (1.702, 1.0, 7.0)])
+@pytest.mark.skipif(
+    not HOPPER_MXFP4_BF16_AVAILABLE,
+    reason="NVIDIA GPU SM90 and FlashInfer are required for this test",
+)
+def test_flashinfer_cutlass_mxfp4_fused_moe(
+    topk: int,
+    num_experts: int,
+    num_tokens: int,
+    intermediate_size: int,
+    hidden_size: int,
+    alpha: float,
+    beta: float,
+    limit: Optional[float],
+):
+    torch.manual_seed(42)
+    device = "cuda:0"
+
+    # Inputs
+    hidden_states = torch.randn(num_tokens,
+                                hidden_size,
+                                device=device,
+                                dtype=torch.bfloat16)
+    # Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
+    w13_q = torch.randint(
+        0,
+        256, (num_experts, 2 * intermediate_size, hidden_size // 2),
+        device=device,
+        dtype=torch.uint8)
+    w13_scale = torch.randint(
+        118,
+        123, (num_experts, 2 * intermediate_size, hidden_size // 32),
+        device=device,
+        dtype=torch.uint8)
+
+    w2_q = torch.randint(0,
+                         256,
+                         (num_experts, hidden_size, intermediate_size // 2),
+                         device=device,
+                         dtype=torch.uint8)
+    w2_scale = torch.randint(
+        118,
+        123, (num_experts, hidden_size, intermediate_size // 32),
+        device=device,
+        dtype=torch.uint8)
+    # Bias contiguous [b1; b3]
+    bias13 = (torch.randn(num_experts,
+                          2 * intermediate_size,
+                          device=device,
+                          dtype=torch.bfloat16) * 10)
+    bias2 = (torch.randn(
+        num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
+    router_logits = torch.rand(num_tokens,
+                               num_experts,
+                               dtype=torch.float32,
+                               device=device)
+
+    w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
+        num_experts, 2 * intermediate_size, hidden_size)
+    w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
+        num_experts, hidden_size, intermediate_size)
+    ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
+                        hidden_states.to(torch.float32), w13_ref,
+                        bias13.to(torch.float32), w2_ref,
+                        bias2.to(torch.float32), alpha, beta, limit, 'bf16')
+
+    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
+
+    # Swap halves to arrange as [w3; w1] (kernel expectation)
+    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
+    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
+
+    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
+    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
+
+    w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
+    w13_s = torch.cat([w3_s, w1_s], dim=1)
+    w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
+    w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)
+
+    routing_weights = torch.nn.functional.softmax(router_logits,
+                                                  dim=1,
+                                                  dtype=torch.float32)
+    token_final_scales, token_selected_experts = torch.topk(routing_weights,
+                                                            topk,
+                                                            dim=-1)
+    token_final_scales = (token_final_scales /
+                          token_final_scales.sum(dim=-1, keepdim=True))
+    token_selected_experts = token_selected_experts.to(torch.int).contiguous()
+
+    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
+    if alpha is not None:
+        alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
+    if beta is not None:
+        beta = torch.full((num_experts, ), beta, device=hidden_states.device)
+    if limit is not None:
+        limit = torch.full((num_experts, ), limit, device=hidden_states.device)
+
+    _ = flashinfer_cutlass_fused_moe(
+        input=hidden_states,
+        token_selected_experts=token_selected_experts,
+        token_final_scales=token_final_scales,
+        fc1_expert_weights=w13_q_swapped,
+        fc2_expert_weights=w2_q,
+        output_dtype=torch.bfloat16,
+        output=out,
+        quant_scales=[w13_s_inter.to(torch.uint8),
+                      w2_s_inter.to(torch.uint8)],
+        fc1_expert_biases=w13_b,
+        fc2_expert_biases=bias2.to(torch.bfloat16),
+        swiglu_alpha=alpha,
+        swiglu_beta=beta,
+        swiglu_limit=limit,
+        tp_size=1,
+        tp_rank=0,
+        ep_size=1,
+        ep_rank=0,
+        use_w4_group_scaling=True,
+    )
+
+    # Allow some mismatch due to MXFP4 quantization
+    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
+
+
+@pytest.mark.parametrize("topk", [1, 4])
+@pytest.mark.parametrize("num_experts", [32])
+@pytest.mark.parametrize("num_tokens", [1, 128])
+@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
+@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
+                                              (1.702, 1.0, 7.0)])
+@pytest.mark.skipif(
+    not (current_platform.is_cuda()
+         and current_platform.is_device_capability(100) and has_flashinfer()),
+    reason="NVIDIA GPU SM100 and FlashInfer are required for this test",
+)
+def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
+    topk: int,
+    num_experts: int,
+    num_tokens: int,
+    intermediate_size: int,
+    hidden_size: int,
+    alpha: Optional[float],
+    beta: Optional[float],
+    limit: Optional[float],
+):
+    torch.manual_seed(42)
+    device = "cuda:0"
+
+    # Inputs
+    hidden_states = torch.randn(num_tokens,
+                                hidden_size,
+                                device=device,
+                                dtype=torch.bfloat16)
+    # Float weights in w13 format [w1; w3]
+    w13 = (torch.randn(num_experts,
+                       2 * intermediate_size,
+                       hidden_size,
+                       device=device,
+                       dtype=torch.bfloat16) / 10)
+    w2 = (torch.randn(num_experts,
+                      hidden_size,
+                      intermediate_size,
+                      device=device,
+                      dtype=torch.bfloat16) / 10)
+    # Bias contiguous [b1; b3]
+    bias13 = (torch.randn(num_experts,
+                          2 * intermediate_size,
+                          device=device,
+                          dtype=torch.bfloat16) * 10)
+    bias2 = (torch.randn(
+        num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
+    router_logits = torch.rand(num_tokens,
+                               num_experts,
+                               dtype=torch.float32,
+                               device=device)
+
+    # Quantize weights to MXFP4 per expert (SM100 path)
+    from flashinfer import mxfp4_quantize
+
+    def quant_mxfp4_batches(a: torch.Tensor, e: int):
+        qs, sfs = [], []
+        for i in range(e):
+            q, sf = mxfp4_quantize(a[i].cuda())
+            qs.append(q)
+            sfs.append(sf)
+        return torch.stack(qs), torch.stack(sfs)
+
+    def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
+                              scale_tensor: torch.Tensor):
+        num_batches = mat_fp4.size(0)
+        scale_tensor = scale_tensor.view(num_batches, -1)
+        from flashinfer import mxfp4_dequantize
+        return torch.stack([
+            mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
+            for b in range(num_batches)
+        ])
+
+    w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
+    w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)
+
+    # Reference result using dequantized tensors and reference_moe
+    w13_ref = dequant_mxfp4_batches(
+        w13_q.view(torch.uint8),
+        w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
+            num_experts, 2 * intermediate_size, hidden_size)
+    w2_ref = dequant_mxfp4_batches(
+        w2_q.view(torch.uint8),
+        w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
+            num_experts, hidden_size, intermediate_size)
+
+    # Quantize activations to MXFP8 for the SM100 kernel path
+    hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
+    # Reference uses BF16 input but quantizes intermediate activation to MXFP8
+    ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
+                        hidden_states.to(torch.float32), w13_ref,
+                        bias13.to(torch.float32), w2_ref,
+                        bias2.to(torch.float32), alpha, beta, limit, 'mxfp8')
+
+    # Prepare inputs for FlashInfer CUTLASS fused MoE
+    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
+
+    # Swap halves to arrange as [w3; w1] (kernel expectation)
+    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
+    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
+
+    # Swap scales halves to match swapped weights
+    s1, s3 = torch.chunk(w13_scale, 2, dim=1)
+    w13_scale_swapped = torch.cat([s3, s1], dim=1)
+
+    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
+    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
+
+    # Build routing for kernel
+    routing_weights = torch.nn.functional.softmax(router_logits,
+                                                  dim=1,
+                                                  dtype=torch.float32)
+    token_final_scales, token_selected_experts = torch.topk(routing_weights,
+                                                            topk,
+                                                            dim=-1)
+    token_final_scales = (token_final_scales /
+                          token_final_scales.sum(dim=-1, keepdim=True))
+    token_selected_experts = token_selected_experts.to(torch.int).contiguous()
+
+    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
+    if alpha is not None:
+        alpha_t = torch.full((num_experts, ),
+                             alpha,
+                             device=hidden_states.device)
+    else:
+        alpha_t = None
+    if beta is not None:
+        beta_t = torch.full((num_experts, ), beta, device=hidden_states.device)
+    else:
+        beta_t = None
+    if limit is not None:
+        limit_t = torch.full((num_experts, ),
+                             limit,
+                             device=hidden_states.device)
+    else:
+        limit_t = None
+
+    # Quant scales for SM100 MXFP8+MXFP4 path
+    fake_input_scale = torch.ones(num_experts, device=device)
+    quant_scales = [
+        w13_scale_swapped.view(torch.int32),
+        fake_input_scale,
+        w2_scale.view(torch.int32),
+        fake_input_scale,
+    ]
+
+    _ = flashinfer_cutlass_fused_moe(
+        input=hidden_states_q,
+        token_selected_experts=token_selected_experts,
+        token_final_scales=token_final_scales,
+        fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
+        fc2_expert_weights=w2_q.contiguous().view(torch.long),
+        output_dtype=torch.bfloat16,
+        output=out,
+        quant_scales=quant_scales,
+        fc1_expert_biases=w13_b,
+        fc2_expert_biases=bias2.to(torch.bfloat16),
+        swiglu_alpha=alpha_t,
+        swiglu_beta=beta_t,
+        swiglu_limit=limit_t,
+        tp_size=1,
+        tp_rank=0,
+        ep_size=1,
+        ep_rank=0,
+        use_mxfp8_act_scaling=True,
+        input_sf=hidden_states_sf,
+    )
+
+    # Allow some mismatch due to MXFP4 quantization
+    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
diff --git a/vllm/envs.py b/vllm/envs.py
index d7956a2adff68..165cd32721fe5 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -166,7 +166,8 @@ if TYPE_CHECKING:
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
-    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
+    VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
+    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
     VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
@@ -1004,6 +1005,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),
 
+    # If set to 1, use the FlashInfer CUTLASS backend for
+    # MXFP8 (activation) x MXFP4 (weight) MoE.
+    # This is separate from the TRTLLM-Gen path controlled by
+    # VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8.
+    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS":
+    lambda: bool(int(
+        os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0")
+    )),
+
     # If set to 1, use the FlashInfer
     # BF16 (activation) x MXFP4 (weight) MoE backend.
     "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
@@ -1296,6 +1306,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_MOE_FP8",
         "VLLM_USE_FLASHINFER_MOE_FP4",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
+        "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
         "VLLM_USE_CUDNN_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 89676f98cb0ed..a90a71159f721 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -813,9 +813,16 @@ class FusedMoE(CustomOp):
 
         # we are padding globally so EP buffer allocation works
         if quant_config and quant_config.get_name() == "mxfp4":
-            from vllm.model_executor.layers.quantization.mxfp4 import (  # noqa: E501
-                should_use_flashinfer_mxfp4)
-            if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
+            from vllm.model_executor.layers.quantization.mxfp4 import (
+                Mxfp4Backend, get_mxfp4_backend)
+            current_mxfp4_backend = get_mxfp4_backend()
+            if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
+                    or current_mxfp4_backend
+                    == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
+                hidden_size = round_up(hidden_size, 128)
+            elif (current_platform.is_rocm() or current_mxfp4_backend
+                  == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
+                  current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
                 hidden_size = round_up(hidden_size, 256)
 
         # For smuggling this layer into the fused moe custom op
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 74922756afe56..f935bdd84124a 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from enum import Enum
 from typing import Callable, Optional, Union
 
 import torch
@@ -33,33 +34,72 @@ from vllm.utils.flashinfer import has_flashinfer
 
 logger = init_logger(__name__)
 
-def _should_use_flashinfer_mxfp4_bf16():
-    """Determine if FlashInfer MXFP4 BF16 should be used."""
-    # If explicitly set, respect the setting
-    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
-        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+# Enum for selecting the MXFP4 MoE backend
+class Mxfp4Backend(Enum):
+    NONE = 0
 
-    # Enable by default on SM100 if MXFP8 is not explicitly enabled
-    if (current_platform.is_device_capability(100) and has_flashinfer()
-            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
-        logger.info_once(
-            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
-            "For faster performance, consider setting "
-            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
-            "though this may impact accuracy.")
-        return True
+    # FlashInfer Backend
+    SM100_FI_MXFP4_MXFP8_TRTLLM = 1
+    SM100_FI_MXFP4_MXFP8_CUTLASS = 2
+    SM100_FI_MXFP4_BF16 = 3
+    SM90_FI_MXFP4_BF16 = 4
 
-    return False
+    # Marlin Backend
+    MARLIN = 5
+
+    # Triton Backend
+    TRITON = 6
 
-def _should_use_flashinfer_mxfp4_mxfp8():
-    """Determine if FlashInfer MXFP4 MXFP8 should be used."""
-    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
 
+def get_mxfp4_backend() -> Mxfp4Backend:
+    # Backend selection: prefer FlashInfer kernels when available,
+    # then fall back to Marlin or Triton.
+    if current_platform.is_cuda():
+        if (current_platform.is_device_capability(90) and has_flashinfer()
+                and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+            logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
+            return Mxfp4Backend.SM90_FI_MXFP4_BF16
+        elif (current_platform.is_device_capability(100) and has_flashinfer()
+              and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS):
+            logger.info_once(
+                "Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
+            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
+        elif (current_platform.is_device_capability(100) and has_flashinfer()
+              and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
+            logger.info_once(
+                "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100. "
+                "For high-concurrency throughput workloads, consider "
+                "setting VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 "
+                "for better performance.")
+            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+        elif current_platform.is_device_capability(100) and has_flashinfer():
+            logger.info_once(
+                "Using FlashInfer MXFP4 BF16 backend for SM100. "
+                "For faster performance, consider setting "
+                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may "
+                "impact accuracy.")
+            return Mxfp4Backend.SM100_FI_MXFP4_BF16
+        elif ((current_platform.is_device_capability(100)
+               or current_platform.is_device_capability(90))
+              and not has_flashinfer()):
+            logger.warning_once(
+                "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
+                "is not available. This may result in degraded performance. "
+                "Please `pip install vllm[flashinfer]` for best results.")
+        # No FlashInfer path was selected; fall back to Marlin or Triton
+        if current_platform.get_device_capability(
+        )[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
+                "2.8.0"):
+            logger.info_once("Using Marlin backend")
+            return Mxfp4Backend.MARLIN
+        else:
+            logger.info_once("Using Triton backend")
+            return Mxfp4Backend.TRITON
+    elif current_platform.is_rocm() and has_triton_kernels():
+        logger.info_once("Using Triton backend")
+        return Mxfp4Backend.TRITON
 
-def should_use_flashinfer_mxfp4():
-    return (_should_use_flashinfer_mxfp4_mxfp8()
-            or _should_use_flashinfer_mxfp4_bf16())
+    return Mxfp4Backend.NONE
 
 
 class Mxfp4Config(QuantizationConfig):
@@ -113,31 +153,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         super().__init__(moe)
         self.topk_indices_dtype = None
         self.moe = moe
-        self.use_marlin = self._should_use_marlin()
+        self.mxfp4_backend = get_mxfp4_backend()
         self.max_capture_size = get_current_vllm_config(
         ).compilation_config.max_capture_size
 
-        if current_platform.is_device_capability(100) and not has_flashinfer():
-            logger.warning_once(
-                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
-                "is not available. This may result in degraded performance. "
-                "Please `pip install vllm[flashinfer]` for best results.")
+        assert self.mxfp4_backend != Mxfp4Backend.NONE, (
+            "No MXFP4 MoE backend (FlashInfer/Marlin/Triton) is available. "
+            "Please check your environment and try again.")
         self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
 
-    def _should_use_marlin(self):
-        if envs.VLLM_MXFP4_USE_MARLIN is not None:
-            return envs.VLLM_MXFP4_USE_MARLIN
-        if current_platform.is_cuda() and \
-                not current_platform.is_device_capability(100):
-            if not current_platform.has_device_capability(90):
-                # marlin kernel has better performance on ampere
-                return True
-            if not has_triton_kernels():
-                return True
-            if not is_torch_equal_or_newer("2.8.0"):
-                return True
-        return False
-
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -157,7 +181,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         intermediate_size_per_partition_after_pad = \
             intermediate_size_per_partition
 
-        if self.use_marlin:
+        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
             # The moe marlin kernel requires that for each linear
             # n % 256 == 0 and k % 128 == 0.
            # In gate_up_proj:
@@ -175,16 +199,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.hidden_size = hidden_size
             layer.intermediate_size_per_partition = \
                 intermediate_size_per_partition_after_pad
-        elif should_use_flashinfer_mxfp4():
+        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+              or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
             # pad the intermediate size to be a multiple of 2 * mxfp4_block
             # for to hold non-uniform sharded tensor as well as swizzling
             # other padding to increase performance
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 256)
             hidden_size = round_up(hidden_size, 256)
-        elif current_platform.is_rocm():
+        elif current_platform.is_rocm() or (
+                self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
+                or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 128)
+            hidden_size = round_up(hidden_size, 128)
         else:
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 64)
@@ -264,9 +292,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_bias, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer):
-        if self.use_marlin:
+        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
             prepare_moe_fp4_layer_for_marlin(layer)
-        elif should_use_flashinfer_mxfp4():
+        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+              or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
             from flashinfer.fp4_quantization import (
                 nvfp4_block_scale_interleave)
             from flashinfer.fused_moe.core import (
@@ -429,7 +458,116 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
                 self.num_experts, -1),
                                       requires_grad=False)
-        else:
+        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
+              or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
+            layer.gemm1_alpha = Parameter(torch.tensor(
+                [1.702] * self.num_experts, dtype=torch.float32).cuda(),
+                                          requires_grad=False)
+            layer.gemm1_beta = Parameter(torch.tensor(
+                [1.0] * self.num_experts, dtype=torch.float32).cuda(),
+                                         requires_grad=False)
+            layer.gemm1_clamp_limit = Parameter(torch.tensor(
+                [7.0] * self.num_experts, dtype=torch.float32).cuda(),
+                                                requires_grad=False)
+
+            sf_block_size = 32  # mxfp4 block size
+
+            # Common shape assertions
+            assert (layer.w13_weight.dim() == 3
+                    and layer.w13_weight.shape[0] == self.num_experts
+                    and layer.w13_weight.shape[1] == self.intermediate_size * 2
+                    and layer.w13_weight.shape[2] == self.hidden_size // 2)
+            assert (layer.w13_weight_scale.dim() == 3
+                    and layer.w13_weight_scale.shape[0] == self.num_experts
+                    and layer.w13_weight_scale.shape[1]
+                    == self.intermediate_size * 2
+                    and layer.w13_weight_scale.shape[2]
+                    == self.hidden_size // sf_block_size)
+            assert (layer.w2_weight.dim() == 3
+                    and layer.w2_weight.shape[0] == self.num_experts
+                    and layer.w2_weight.shape[1] == self.hidden_size and
+                    layer.w2_weight.shape[2] == self.intermediate_size // 2)
+            assert (layer.w2_weight_scale.dim() == 3
+                    and layer.w2_weight_scale.shape[1] == self.hidden_size
+                    and layer.w2_weight_scale.shape[2]
+                    == self.intermediate_size // sf_block_size)
+            assert (layer.w13_bias.dim() == 2
+                    and layer.w13_bias.shape[0] == self.num_experts
+                    and layer.w13_bias.shape[1] == self.intermediate_size * 2)
+            assert (layer.w2_bias.dim() == 2
+                    and layer.w2_bias.shape[0] == self.num_experts
+                    and layer.w2_bias.shape[1] == self.hidden_size)
+
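+            # Note on layout: checkpoints store w13 with gate/up rows
+            # interleaved (gate at even indices, up at odd), while
+            # flashinfer_cutlass_fused_moe expects the two halves stacked
+            # contiguously as [w3; w1]. Weights, biases, and block scales
+            # are de-interleaved and half-swapped together below so that
+            # each scale/bias row stays aligned with its weight row.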
+            # De-interleave and swap for w13 weight, bias, and scales
+            w13_w = layer.w13_weight.data
+            gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
+            deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
+            w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
+            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
+
+            w13_b = layer.w13_bias.data.to(torch.float32)
+            gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
+            deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
+            b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
+            w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
+
+            w13_s = layer.w13_weight_scale.data
+            gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
+            deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
+            s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
+            w13_scale_swapped = torch.cat([s3, s1], dim=1)
+
+            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
+                from flashinfer import block_scale_interleave
+
+                orig_shape = w13_scale_swapped.shape
+                w13_scale_interleaved = block_scale_interleave(
+                    w13_scale_swapped.view(torch.uint8)).reshape(orig_shape)
+
+                w2_s = layer.w2_weight_scale.data
+                orig_shape = w2_s.shape
+                w2_scale_interleaved = block_scale_interleave(
+                    w2_s.view(torch.uint8)).reshape(orig_shape)
+
+                layer.w13_weight = Parameter(w13_weight_swapped,
+                                             requires_grad=False)
+                layer.w13_weight_scale = Parameter(w13_scale_interleaved,
+                                                   requires_grad=False)
+                layer.w13_bias = Parameter(w13_bias_swapped,
+                                           requires_grad=False)
+                layer.w2_weight_scale = Parameter(w2_scale_interleaved,
+                                                  requires_grad=False)
+            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
+
+                def _interleave_mxfp4_cutlass_sm90(w):
+                    w_shape = w.shape
+                    w_interleaved = w.reshape(w_shape[0], w_shape[1],
+                                              (w_shape[2] // 4), 4)
+                    w_interleaved = w_interleaved.permute(0, 2, 1, 3)
+                    w_interleaved = w_interleaved.reshape(
+                        w_shape[0], w_shape[2] // 4, w_shape[1] * 4)
+                    return w_interleaved
+
+                w31_scales = w13_scale_swapped.to(torch.uint8).view(
+                    torch.uint8)
+                w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
+                    w31_scales)
+
+                w2_weight_scale = layer.w2_weight_scale.data
+                w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
+                w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
+                    w2_scales)
+
+                layer.w13_weight = torch.nn.Parameter(torch.cat([w3_w, w1_w],
+                                                                dim=1),
+                                                      requires_grad=False)
+                layer.w13_bias = torch.nn.Parameter(w13_bias_swapped,
+                                                    requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w31_scales_interleaved, requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_scales_interleaved, requires_grad=False)
+        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
             from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
 
             w13_bias = layer.w13_bias.to(torch.float32)
@@ -464,6 +602,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.w13_weight = None
             layer.w2_weight = None
             torch.cuda.empty_cache()
+        else:
+            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
 
     def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
         # Number of tokens in the input tensor.
@@ -500,7 +640,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             raise NotImplementedError(
                 "Mxfp4 does not support batched experts format for EP")
         else:
-            if should_use_flashinfer_mxfp4():
+            if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+                    or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
                 # B200 code-path
                 kwargs = {
                     "gemm1_alpha": layer.gemm1_alpha,
@@ -601,7 +742,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if enable_eplb:
             raise NotImplementedError("EPLB is not supported for mxfp4")
 
-        if self.use_marlin:
+        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
             topk_weights, topk_ids = FusedMoE.select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
@@ -665,16 +806,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 logical_replica_count), (
                     "MXFP4 are not supported with this configuration.")
 
-        if should_use_flashinfer_mxfp4():
-            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
-            if _should_use_flashinfer_mxfp4_bf16():
+        if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+                or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
+            from flashinfer import trtllm_fp4_block_scale_moe
+            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
                 assert x.dtype == torch.bfloat16
                 x_quant = x
                 x_scale = None
-            else:
+            elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
+                from flashinfer import mxfp8_quantize
                 x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                 x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
                     *x.shape[:-1], -1)
+
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
                 None,  # routing_bias
@@ -706,7 +850,86 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 tune_max_num_tokens=self.max_capture_size,
             )[0]
             return trtllm_gen_output
-        else:
+        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
+              or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
+            from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
+
+            topk_weights, topk_ids = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias,
+            )
+
+            # Backend-specific preparation
+            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
+
+                from flashinfer import mxfp8_quantize
+
+                x_quant, x_scale = mxfp8_quantize(x, True, 32)
+
+                fake_input_scale = torch.ones(self.num_experts,
+                                              device=x.device)
+                quant_scales = [
+                    layer.w13_weight_scale.contiguous().view(torch.int32),
+                    fake_input_scale,
+                    layer.w2_weight_scale.contiguous().view(torch.int32),
+                    fake_input_scale,
+                ]
+
+                fi_input = x_quant
+                extra_kwargs = dict(
+                    use_mxfp8_act_scaling=True,
+                    input_sf=x_scale,
+                    fc1_expert_weights=layer.w13_weight.contiguous().view(
+                        torch.long),
+                    fc2_expert_weights=layer.w2_weight.contiguous().view(
+                        torch.long),
+                )
+            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
+                assert x.dtype == torch.bfloat16
+
+                quant_scales = [
+                    layer.w13_weight_scale,
+                    layer.w2_weight_scale,
+                ]
+
+                fi_input = x
+                extra_kwargs = dict(
+                    use_w4_group_scaling=True,
+                    fc1_expert_weights=layer.w13_weight,
+                    fc2_expert_weights=layer.w2_weight,
+                )
+
+            output = torch.empty_like(x, dtype=torch.bfloat16)
+            _ = flashinfer_cutlass_fused_moe(
+                input=fi_input,
+                token_selected_experts=topk_ids.to(torch.int).contiguous(),
+                token_final_scales=topk_weights,
+                output_dtype=torch.bfloat16,
+                output=output,
+                quant_scales=quant_scales,
+                fc1_expert_biases=layer.w13_bias,
+                fc2_expert_biases=layer.w2_bias,
+                swiglu_alpha=layer.gemm1_alpha,
+                swiglu_beta=layer.gemm1_beta,
+                swiglu_limit=layer.gemm1_clamp_limit,
+                tp_size=self.moe.tp_size,
+                tp_rank=self.moe.tp_rank,
+                ep_size=self.moe.ep_size,
+                ep_rank=self.moe.ep_rank,
+                tune_max_num_tokens=self.max_capture_size,
+                **extra_kwargs,
+            )
+
+            return output
+        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
             from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                 triton_kernel_moe_forward)
             return triton_kernel_moe_forward(
@@ -724,3 +947,5 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 w2_precision=self.w2_precision_config,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
+        else:
+            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py
index e42e34ebc77b9..89ce20308f447 100644
--- a/vllm/model_executor/warmup/kernel_warmup.py
+++ b/vllm/model_executor/warmup/kernel_warmup.py
@@ -33,8 +33,8 @@ def kernel_warmup(worker: "Worker"):
         max_tokens = worker.scheduler_config.max_num_batched_tokens
         deep_gemm_warmup(model, max_tokens)
 
-    # FlashInfer kernel autotune for Blackwell (SM 10.0) GPUs
-    if has_flashinfer() and current_platform.is_device_capability(100):
+    # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
+    if has_flashinfer() and current_platform.has_device_capability(90):
         flashinfer_autotune(worker.model_runner)
 
     # FlashInfer attention warmup
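
For reference, the scale swizzle shared by the test helper
`_interleave_scales_lastdim_by4` and `_interleave_mxfp4_cutlass_sm90` in
mxfp4.py reduces to a single reshape/permute over the last dimension. A
minimal standalone sketch (shapes and names here are illustrative only):

    import torch

    # (num_experts, rows, cols) uint8 block scales; cols divisible by 4
    s = torch.arange(2 * 3 * 8, dtype=torch.uint8).reshape(2, 3, 8)
    g = s.reshape(2, 3, 8 // 4, 4)     # split last dim into groups of 4
    g = g.permute(0, 2, 1, 3)          # move group index ahead of rows
    out = g.reshape(2, 8 // 4, 3 * 4)  # (num_experts, cols//4, rows*4)
    assert out.shape == (2, 2, 12)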