From a38b8af4c3647ca615e3f1e334779cda8b1eddd8 Mon Sep 17 00:00:00 2001 From: amirkl94 <203507526+amirkl94@users.noreply.github.com> Date: Wed, 20 Aug 2025 01:01:53 +0300 Subject: [PATCH] [NVIDIA] Add SM100 Flashinfer Cutlass MoE fp8 backend (#22357) Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> --- .buildkite/test-pipeline.yaml | 2 + tests/kernels/moe/test_flashinfer.py | 248 ++++++++++++++++++ .../fused_moe/flashinfer_cutlass_moe.py | 51 ++-- .../model_executor/layers/quantization/fp8.py | 221 ++++++++++------ .../layers/quantization/modelopt.py | 118 ++++++--- .../quantization/utils/flashinfer_utils.py | 112 ++++++++ 6 files changed, 613 insertions(+), 139 deletions(-) create mode 100644 tests/kernels/moe/test_flashinfer.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 265e6ad72a5f..781b8e0fa009 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -630,6 +630,7 @@ steps: - vllm/model_executor/layers/fused_moe/cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - vllm/compilation/fusion.py - vllm/compilation/fusion_attn.py @@ -650,6 +651,7 @@ steps: # Fusion - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern + - pytest -v -s tests/kernels/moe/test_flashinfer.py ##### 1 GPU test ##### ##### multi gpus test ##### diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py new file mode 100644 index 000000000000..52a3d2ca3b42 --- /dev/null +++ b/tests/kernels/moe/test_flashinfer.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import pytest +import torch + +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + swap_w13_to_w31) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + input_to_float8) +from vllm.model_executor.models.llama4 import Llama4MoE +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe + +if not has_flashinfer_cutlass_fused_moe( +) or not current_platform.has_device_capability(100): + pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support", + allow_module_level=True) + +NUM_EXPERTS = [16] +TOP_KS = [1] + +MNK_FACTORS = [ + (256, 8192, 5120), + (256, 4096, 5120), + (127, 8192, 5120), + (127, 4096, 5120), + (10, 8192, 5120), + (10, 4096, 5120), + (1, 8192, 5120), + (1, 4096, 5120), +] + +vllm_config = VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1)) +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + + +def quant_fp8_per_tensor_batches(a): + num_batches = a.size(0) + a_quant = [] + a_scales = [] + + for i in range(num_batches): + a_fp8, a_global_sf = input_to_float8(a[i]) + a_global_sf = 1.0 / a_global_sf + 
a_quant.append(a_fp8) + a_scales.append(a_global_sf) + + result_a_quant = torch.stack(a_quant) + result_a_scales = torch.stack(a_scales) + + return result_a_quant, result_a_scales + + +@dataclass +class TestData: + hidden_states: torch.Tensor + w13_quantized: torch.Tensor + w2_quantized: torch.Tensor + a1_scale: torch.Tensor + a2_scale: torch.Tensor + w13_weight_scale: torch.Tensor + w2_weight_scale: torch.Tensor + layer: torch.nn.Module + + @staticmethod + def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, + reorder: bool) -> "TestData": + hidden_states = torch.randn( + (m, k), device="cuda", dtype=torch.bfloat16) / 10 + w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) + w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) + + # Scale to fp8 + _, a1_scale = input_to_float8(hidden_states) + a1_scale = 1.0 / a1_scale + a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to( + dtype=torch.float32) + w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13) + w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2) + + layer = torch.nn.Module() + layer.w13_weight = w13_quantized.clone() + layer.w2_weight = w2_quantized.clone() + layer.w13_input_scale = a1_scale + layer.w2_input_scale = a2_scale + layer.w13_weight_scale = w13_weight_scale + layer.w2_weight_scale = w2_weight_scale + + register_moe_scaling_factors(layer) + + # flashinfer expects swapped rows for w13 + layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) + if reorder: + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, + layer.w2_weight) + layer.custom_routing_function = Llama4MoE.custom_routing_function + layer.intermediate_size_per_partition = n + layer.ep_rank = 0 + layer.local_num_experts = e + + return TestData( + hidden_states=hidden_states, + w13_quantized=w13_quantized, + w2_quantized=w2_quantized, + a1_scale=a1_scale, + a2_scale=a2_scale, + w13_weight_scale=w13_weight_scale, + w2_weight_scale=w2_weight_scale, + layer=layer, + ) + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +def test_flashinfer_per_tensor_moe_fp8_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + monkeypatch, +): + current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + with set_current_vllm_config(vllm_config): + td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True) + + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=td.hidden_states, + router_logits=score, + use_grouped_topk=False, + top_k=topk, + renormalize=False, + custom_routing_function=Llama4MoE.custom_routing_function, + scoring_func="softmax") + + output = fused_experts( + td.hidden_states, + td.w13_quantized, + td.w2_quantized, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation="silu", + use_fp8_w8a8=True, + per_channel_quant=False, + global_num_experts=e, + expert_map=None, + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + apply_router_weight_on_input=True, + ) + + flashinfer_output = apply_flashinfer_per_tensor_scale_fp8( + layer=td.layer, + hidden_states=td.hidden_states, + router_logits=score, + routing_bias=None, + global_num_experts=e, + top_k=topk, + num_expert_group=None, + topk_group=None, + apply_router_weight_on_input=True) + + torch.testing.assert_close(output, + flashinfer_output, + 
atol=5.5e-2, + rtol=1e-2) + + +@pytest.mark.skip( + "Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472" +) +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +def test_flashinfer_cutlass_moe_fp8_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + monkeypatch, +): + current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + with set_current_vllm_config(vllm_config): + td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False) + + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=td.hidden_states, + router_logits=score, + use_grouped_topk=False, + top_k=topk, + renormalize=False, + custom_routing_function=Llama4MoE.custom_routing_function, + scoring_func="softmax") + + output = fused_experts( + td.hidden_states, + td.w13_quantized, + td.w2_quantized, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation="silu", + use_fp8_w8a8=True, + per_channel_quant=False, + global_num_experts=e, + expert_map=None, + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + apply_router_weight_on_input=True, + ) + + td.layer.dp_size = 1 + + flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8( + td.hidden_states, + td.layer, + topk_weights, + topk_ids, + activation="silu", + global_num_experts=e, + expert_map=None, + apply_router_weight_on_input=True, + ) + + torch.testing.assert_close(output, + flashinfer_cutlass_output, + atol=5.5e-2, + rtol=1e-2) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 6a9c28b53cd8..feab3f74cac5 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -61,8 +61,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): per_act_token_quant=False, block_shape=None, )) - assert quant_dtype == "nvfp4", ("Only nvfp4 quantization is " - "currently supported.") + assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), ( + "Only nvfp4,fp8 quantization are currently supported.") self.ep_rank = ep_rank self.ep_size = ep_size self.tp_rank = tp_rank @@ -122,7 +122,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): """ aq_m, aq_n = aq.shape workspace2 = () - output_shape = (aq_m, aq_n * 2) + output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \ + torch.float8_e4m3fn else (aq_m, aq_n) workspace_dtype = a.dtype workspace1 = output_shape # The workspace is determined by `aq`, since it comes after any @@ -151,29 +152,39 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: Optional[bool], ): - # Flashinfer CUTLASS kernel takes scalar global scales, - # min because inv_scale. 
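+        # The fp8 per-tensor path below passes four scalar global scales in
+        # the order [g1_alphas, a2_gscale, g2_alphas, a1_gscale] and no
+        # input_sf; the nvfp4 path keeps the six-element quant_scales layout
+        # plus the int32/long weight views the FlashInfer API expects.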
+ if self.quant_dtype == torch.float8_e4m3fn: + quant_scales = [ + self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale + ] - # Ensure w1_scale and w2_scale are not None before calling view - assert w1_scale is not None and w2_scale is not None, ( - "w1_scale and w2_scale must not " - "be None for FlashInferExperts") + a1q_scale = None # not passing input_sf in fp8 + fc1_expert_weights = w1 + fc2_expert_weights = w2 + else: + # Ensure w1_scale and w2_scale are not None before calling view + assert w1_scale is not None and w2_scale is not None, ( + "w1_scale and w2_scale must not " + "be None for FlashInferExperts") + # Flashinfer CUTLASS kernel takes scalar global scales, + # min because inv_scale. + quant_scales = [ + self.a1_gscale, + w1_scale.view(torch.int32), + self.g1_alphas, + self.a2_gscale, + w2_scale.view(torch.int32), + self.g2_alphas, + ] + # FlashInfer API requires weight to be long for nvfp4 + fc1_expert_weights = w1.view(torch.long) + fc2_expert_weights = w2.view(torch.long) - quant_scales = [ - self.a1_gscale, - w1_scale.view(torch.int32), - self.g1_alphas, - self.a2_gscale, - w2_scale.view(torch.int32), - self.g2_alphas, - ] _ = flashinfer_cutlass_fused_moe( input=hidden_states, token_selected_experts=topk_ids.to(torch.int), token_final_scales=topk_weights, - # FlashInfer API requires weight to be long for nvfp4 - fc1_expert_weights=w1.view(torch.long), - fc2_expert_weights=w2.view(torch.long), + fc1_expert_weights=fc1_expert_weights, + fc2_expert_weights=fc2_expert_weights, output_dtype=self.out_dtype, quant_scales=quant_scales, input_sf=a1q_scale, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f07be0855492..7c447c2a5348 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -9,6 +9,7 @@ from torch.nn import Module from torch.nn.parameter import Parameter import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -23,8 +24,11 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors, - rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) + FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, + build_flashinfer_fp8_cutlass_moe_prepare_finalize, + flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( @@ -145,7 +149,7 @@ class Fp8Config(QuantizationConfig): return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): - return Fp8MoEMethod(self, layer.moe_config) + return Fp8MoEMethod(self, layer) elif isinstance(layer, Attention): return Fp8KVCacheMethod(self) return None @@ -482,16 +486,20 @@ class Fp8MoEMethod(FusedMoEMethodBase): quant_config: The quantization config. 
""" - def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig): - super().__init__(moe) + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + super().__init__(layer.moe_config) + self.layer = layer self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None - self.flashinfer_moe_enabled = False + self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None + self.fused_experts: Optional[ + mk.FusedMoEModularKernel] = None # type: ignore if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( - "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.") - self.flashinfer_moe_enabled = True + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + ) # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.use_marlin = (not current_platform.has_device_capability(89) @@ -531,6 +539,20 @@ class Fp8MoEMethod(FusedMoEMethodBase): "CutlassBlockScaledGroupedGemm not supported on the current " "platform.") + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: + return super().maybe_make_prepare_finalize(moe) + + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe, + layer=self.layer, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -678,7 +700,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): normalize_e4m3fn_to_e4m3fnuz( layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale) - elif self.flashinfer_moe_enabled: + elif self.flashinfer_moe_backend is not None: # NOTE: weights have to be swapped since the activation is # applied on different half for flashinfer vs vllm w13_weight = swap_w13_to_w31(layer.w13_weight.data) @@ -686,9 +708,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_weight_scale_inv.data) w2_weight = layer.w2_weight.data w2_weight_scale_inv = layer.w2_weight_scale_inv.data - if not self.block_quant: - register_moe_scaling_factors(layer) - rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) else: w13_weight = layer.w13_weight.data w13_weight_scale_inv = layer.w13_weight_scale_inv.data @@ -834,6 +853,17 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + if self.flashinfer_moe_backend is not None: + # NOTE: weights have to be swapped since the activation is + # applied on different half for flashinfer vs vllm + assert not self.block_quant + register_moe_scaling_factors(layer) + w13_weight = swap_w13_to_w31(layer.w13_weight.data) + if self.flashinfer_moe_backend == \ + FlashinferMoeBackend.TENSORRT_LLM: + rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) + layer.w13_weight.data = w13_weight.data + if self.use_marlin: prepare_moe_fp8_layer_for_marlin(layer, False) # Activations not quantized for marlin. 
@@ -892,6 +922,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): per_act_token_quant=False, allow_deep_gemm=self.allow_deep_gemm, ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + experts = select_cutlass_fp8_gemm_impl( + moe, + self.layer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts else: logger.debug( "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", @@ -930,25 +967,66 @@ class Fp8MoEMethod(FusedMoEMethodBase): assert logical_to_physical_map is not None assert logical_replica_count is not None assert isinstance(layer, FusedMoE) - if not self.flashinfer_moe_enabled: - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) + + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + assert scoring_func == 'sigmoid', ( + f"Expected 'sigmoid' scoring func but got {scoring_func}") + if self.block_quant: + assert (renormalize and use_grouped_topk + and custom_routing_function is None) + + return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( + routing_logits=router_logits.to(torch.float32), + routing_bias=e_score_correction_bias, + x=x, + w13_weight=layer.w13_weight, + w13_weight_scale_inv=layer.w13_weight_scale_inv, + w2_weight=layer.w2_weight, + w2_weight_scale_inv=layer.w2_weight_scale_inv, + global_num_experts=global_num_experts, + top_k=top_k, + num_expert_group=num_expert_group, + topk_group=topk_group, + intermediate_size=layer.intermediate_size_per_partition, + expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + block_shape=self.quant_config.weight_block_size, + routed_scaling=1.0, + ) + else: + assert (not renormalize + and custom_routing_function is not None) + return apply_flashinfer_per_tensor_scale_fp8( + layer=layer, + hidden_states=x, + router_logits=router_logits, + routing_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + top_k=top_k, + num_expert_group=num_expert_group, + topk_group=topk_group, + apply_router_weight_on_input=apply_router_weight_on_input) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 @@ -988,63 +1066,38 @@ class Fp8MoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=apply_router_weight_on_input, 
global_num_experts=global_num_experts, expert_map=expert_map) - elif self.flashinfer_moe_enabled: - assert activation == 'silu' - assert scoring_func == 'sigmoid' - if self.block_quant: - assert (renormalize and use_grouped_topk - and custom_routing_function is None) - - return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32), - routing_bias=e_score_correction_bias, - x=x, - w13_weight=layer.w13_weight, - w13_weight_scale_inv=layer.w13_weight_scale_inv, - w2_weight=layer.w2_weight, - w2_weight_scale_inv=layer.w2_weight_scale_inv, + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + assert self.block_quant is None + assert (not renormalize and custom_routing_function is not None) + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + assert scoring_func == 'sigmoid', ( + f"Expected 'sigmoid' scoring func but got {scoring_func}") + if self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=False, + activation=activation, global_num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, - intermediate_size=layer.intermediate_size_per_partition, - expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - block_shape=self.quant_config.weight_block_size, - routed_scaling=1.0, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, ) else: - assert (not renormalize - and custom_routing_function is not None) - return apply_flashinfer_per_tensor_scale_fp8( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=e_score_correction_bias, + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, global_num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, - apply_router_weight_on_input=apply_router_weight_on_input) - elif self.fused_experts is not None: - return self.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, - w1_scale=(layer.w13_weight_scale_inv - if self.block_quant else layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale_inv - if self.block_quant else layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) else: from vllm.model_executor.layers.fused_moe import fused_experts return fused_experts( diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 28f16d108834..046234057f04 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from enum import Enum from typing import Any, Callable, Optional, Union import torch @@ -27,8 +26,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, select_nvfp4_gemm_impl) from 
vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors, - rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) + FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, + build_flashinfer_fp8_cutlass_moe_prepare_finalize, + flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) @@ -49,11 +51,6 @@ QUANT_ALGOS = ["FP8", "NVFP4"] KV_CACHE_QUANT_ALGOS = ["FP8"] -class FlashinferMoeBackend(Enum): - TENSORRT_LLM = "TensorRT-LLM" - CUTLASS = "CUTLASS" - - class ModelOptFp8Config(QuantizationConfig): """Config class for ModelOpt FP8.""" @@ -179,7 +176,7 @@ class ModelOptFp8Config(QuantizationConfig): elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) elif isinstance(layer, FusedMoE): - return ModelOptFp8MoEMethod(self, layer.moe_config) + return ModelOptFp8MoEMethod(self, layer) return None @@ -278,18 +275,49 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): def __init__( self, quant_config: ModelOptFp8Config, - moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> None: - super().__init__(moe) + super().__init__(layer.moe_config) + self.layer = layer self.quant_config = quant_config from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_fp8_supported) self.cutlass_fp8_supported = cutlass_fp8_supported() - self.flashinfer_moe_enabled = False + self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None + self.fused_experts: Optional[ + mk.FusedMoEModularKernel] = None # type: ignore if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( - "Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.") - self.flashinfer_moe_enabled = True + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + ) + + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.fused_experts is not None or \ + self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: + return super().maybe_make_prepare_finalize(moe) + + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe, + layer=self.layer, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + experts = select_cutlass_fp8_gemm_impl( + moe, + self.layer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts def create_weights( self, @@ -433,11 +461,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer.w2_input_scale = Parameter(layer.w2_input_scale.max(), requires_grad=False) - if self.flashinfer_moe_enabled: + if self.flashinfer_moe_backend is not None: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - rotate_flashinfer_fp8_moe_weights(layer.w13_weight, - layer.w2_weight) register_moe_scaling_factors(layer) + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, + layer.w2_weight) def apply( self, @@ -461,14 +490,13 @@ class 
ModelOptFp8MoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet.") - if self.flashinfer_moe_enabled: - assert activation == 'silu' + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") assert not renormalize return apply_flashinfer_per_tensor_scale_fp8( layer=layer, @@ -495,6 +523,36 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, ) + + if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + assert not renormalize + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + if self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts) return fused_experts( @@ -951,20 +1009,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): self.flashinfer_moe_backend = None if self.allow_flashinfer: - flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND - if flashinfer_moe_backend == "throughput": - self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS - logger.info_once("Using FlashInfer CUTLASS kernels for " - "ModelOptNvFp4FusedMoE.") - elif flashinfer_moe_backend == "latency": - self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM - logger.info_once("Using FlashInfer TensorRT-LLM kernels for " - "ModelOptNvFp4FusedMoE.") - else: - allowed_backends = ["throughput", "latency"] - raise ValueError( - f"Unknown flashinfer moe backend: {flashinfer_moe_backend}" - f" expected one of {allowed_backends}") + self.flashinfer_moe_backend = get_flashinfer_moe_backend() + logger.info_once( + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + " for ModelOptNvFp4FusedMoE.") def maybe_make_prepare_finalize( self, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 278ee5232f47..9889808f0760 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -1,9 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum from typing import Optional import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import envs +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + 
FlashInferCutlassMoEPrepareAndFinalize) + +logger = init_logger(__name__) + + +class FlashinferMoeBackend(Enum): + TENSORRT_LLM = "TensorRT-LLM" + CUTLASS = "CUTLASS" + def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): @@ -144,3 +161,98 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None: layer.register_parameter( 'output2_scales_scalar', torch.nn.Parameter(output2_scales, requires_grad=False)) + layer.register_parameter( + 'w2_input_scale_inv', + torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False)) + + +def build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe: Optional[FusedMoEConfig], + layer: torch.nn.Module, +) -> mk.FusedMoEPrepareAndFinalize: + """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" + use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False + return FlashInferCutlassMoEPrepareAndFinalize( + use_dp, a1_gscale=layer.w13_input_scale) + + +def select_cutlass_fp8_gemm_impl( + moe: Optional[FusedMoEConfig], + layer: torch.nn.Module, + out_dtype: Optional[torch.dtype] = None, +) -> mk.FusedMoEPermuteExpertsUnpermute: + """Return a GEMM *experts* implementation for fused-MoE layers""" + + from vllm.model_executor.models.llama4 import Llama4MoE + assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ + "FusedMoE flashinfer kernels are only supported for Llama4" + + if moe is not None: + return FlashInferExperts( + g1_alphas=layer.output1_scales_gate_scalar, + g2_alphas=layer.output2_scales_scalar, + a1_gscale=layer.w13_input_scale, + a2_gscale=layer.w2_input_scale_inv, + out_dtype=moe.in_dtype, + quant_dtype=torch.float8_e4m3fn, + ep_rank=moe.moe_parallel_config.ep_rank, + ep_size=moe.moe_parallel_config.ep_size, + tp_rank=moe.moe_parallel_config.tp_rank, + tp_size=moe.moe_parallel_config.tp_size, + ) + + assert out_dtype is not None, ( + "If moe config is None, out_dtype must be passed") + return FlashInferExperts( + g1_alphas=layer.output1_scales_gate_scalar, + g2_alphas=layer.output2_scales_scalar, + a1_gscale=layer.w13_input_scale, + a2_gscale=layer.w2_input_scale_inv, + out_dtype=out_dtype, + quant_dtype=torch.float8_e4m3fn, + ) + + +def flashinfer_cutlass_moe_fp8( + hidden_states: torch.Tensor, + layer: torch.nn.Module, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + fused_experts = mk.FusedMoEModularKernel( + build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None, + layer=layer), + select_cutlass_fp8_gemm_impl(moe=None, + layer=layer, + out_dtype=hidden_states.dtype)) + + return fused_experts( + hidden_states, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +def get_flashinfer_moe_backend() -> FlashinferMoeBackend: + flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND + if flashinfer_moe_backend == "throughput": + return FlashinferMoeBackend.CUTLASS + elif flashinfer_moe_backend == "latency": + return FlashinferMoeBackend.TENSORRT_LLM + + allowed_backends = ["throughput", "latency"] + raise ValueError( + f"Unknown flashinfer moe backend: {flashinfer_moe_backend}" + f" expected one of {allowed_backends}")
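
For reviewers, a minimal end-to-end sketch of how the new path is enabled at runtime. The environment variables and the "throughput"/"latency" mapping come from this patch (see get_flashinfer_moe_backend above); the CUTLASS experts implementation additionally asserts Llama4 routing, and the new test gates on SM100 plus FlashInfer CUTLASS fused-MoE availability. The checkpoint name below is a placeholder, and the offline LLM API is vLLM's standard entry point, not something introduced by this change.

# Illustrative usage sketch, assuming an SM100 GPU, a FlashInfer build with the
# CUTLASS fused-MoE kernels, and an FP8 Llama-4 MoE checkpoint.
import os

# Opt in to the FlashInfer FP8 MoE kernels extended by this patch.
os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"
# "throughput" selects FlashinferMoeBackend.CUTLASS (this PR's SM100 path);
# "latency" selects FlashinferMoeBackend.TENSORRT_LLM.
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "throughput"

from vllm import LLM, SamplingParams

llm = LLM(model="your-org/llama-4-fp8-moe-checkpoint")  # placeholder model id
out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)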