diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py
index 9e5aa4e4c2a89..9131582eef754 100644
--- a/vllm/distributed/device_communicators/base_device_communicator.py
+++ b/vllm/distributed/device_communicators/base_device_communicator.py
@@ -255,7 +255,7 @@ class DeviceCommunicatorBase:
             if module.__class__.__name__ == "FusedMoE"
         ]
         for module in moe_modules:
-            module.quant_method.init_prepare_finalize()
+            module.quant_method.init_prepare_finalize(module)
 
     def dispatch(
             self, hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 7c1a7b636a9c2..cab610decf901 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -450,6 +450,12 @@ class FusedMoEConfig:
             if quant_dtype is None and isinstance(quant_config, Fp8Config):
                 quant_dtype = torch.float8_e4m3fn
 
+            from vllm.model_executor.layers.quantization.mxfp4 import (
+                Mxfp4Config)
+            if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
+                    and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
+                quant_dtype = "mxfp8"
+
             from vllm.model_executor.layers.quantization.modelopt import (
                 ModelOptNvFp4Config)
             if quant_dtype is None and isinstance(quant_config,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 54406a5a2d87f..b9de03ddd216e 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -200,7 +200,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 
     # Note: init_prepare_finalize should only be called by
     # prepare_communication_buffer_for_model.
-    def init_prepare_finalize(self):
+    def init_prepare_finalize(self, layer: torch.nn.Module):
         assert self.moe is not None
         prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
 
@@ -211,7 +211,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
             assert self.fused_experts is None, \
                 f"Attempt to override experts for {id(self)}!"
             self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
-            experts = self.select_gemm_impl(prepare_finalize, self.moe)
+            experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
             self.fused_experts = FusedMoEModularKernel(
                 prepare_finalize,
                 experts,
@@ -221,6 +221,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         self,
         prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
+        layer: torch.nn.Module,
     ) -> FusedMoEPermuteExpertsUnpermute:
         # based on the all2all implementation, select the appropriate
         # gemm implementation
@@ -273,6 +274,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         prepare_finalize: FusedMoEPrepareAndFinalize,
         # TODO(bnell): Remove. Every layer should have an moe config object.
         moe: FusedMoEConfig,
+        layer: torch.nn.Module,
     ) -> FusedMoEPermuteExpertsUnpermute:
         if (prepare_finalize.activation_format ==
                 FusedMoEActivationFormat.BatchedExperts):
diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py
new file mode 100644
index 0000000000000..14dfce4b0e3aa
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py
@@ -0,0 +1,197 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceNoOP)
+from vllm.utils import next_power_of_2
+
+
+class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(
+        self,
+        moe: FusedMoEConfig,
+        gemm1_alpha,
+        gemm1_beta,
+        gemm1_clamp_limit,
+        w13_bias,
+        w2_bias,
+        max_capture_size,
+    ):
+        super().__init__(moe.quant_config)
+        self.moe = moe
+        self.gemm1_alpha = gemm1_alpha
+        self.gemm1_beta = gemm1_beta
+        self.gemm1_clamp_limit = gemm1_clamp_limit
+        self.w13_bias = w13_bias
+        self.w2_bias = w2_bias
+        self.max_capture_size = max_capture_size
+
+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
+
+    def supports_chunking(self) -> bool:
+        return True
+
+    def supports_expert_map(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        return TopKWeightAndReduceNoOP()
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        aq: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # The workspaces for this implementation are managed by flashinfer.
+        # TODO(varun): workspace1 could be used as the output tensor. This is
+        # error-prone. Allow `workspace_shapes` to return None workspaces.
+        workspace1 = (M, K)
+        workspace2 = (0, 0)
+        output = (M, K)
+        return (workspace1, workspace2, output, a.dtype)
+
+    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int,
+                             local_num_experts: int):
+        # Number of tokens in the input tensor.
+        num_tokens = x.shape[0]
+        # Factor to account for expert load imbalance. The factor equals
+        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert:
+        # 1.0 means a perfect expert distribution.
+        # > 1.0 means some experts have more tokens than the perfect
+        # distribution.
+        # < 1.0 does not make sense.
+        imbalance_factor = 1.3
+        # Calculate the number of tokens per expert assuming perfect
+        # distribution.
+        num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
+        # Apply the imbalance factor.
+        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
+        # And pad the number to the next power of 2.
+        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+        # Cap to 8-64 tokens per CTA tile as it's the range supported by the
+        # kernel.
+        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+
+        return tile_tokens_dim
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
+        topk = topk_ids.size(-1)
+        local_num_experts = w1.size(0)
+        intermediate_size = w2.size(1)
+        local_expert_offset = self.moe.ep_rank * local_num_experts
+
+        x_quant = hidden_states
+        x_scale = a1q_scale
+        if x_scale is not None:
+            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
+                *x_quant.shape[:-1], -1)
+
+        # Pack the expert id into the upper 16 bits and the bfloat16 bit
+        # pattern of the routing weight into the lower 16 bits of an int32.
+        packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
+            torch.bfloat16).view(torch.int16)
+
+        assert w1_scale is not None
+        assert w2_scale is not None
+        kwargs = {
+            "topk_ids": packed_tensor,
+            "routing_bias": None,
+            "hidden_states": x_quant,
+            "hidden_states_scale": x_scale,
+            "gemm1_weights": w1,
+            "gemm1_weights_scale": w1_scale,
+            "gemm1_bias": self.w13_bias,
+            "gemm1_alpha": self.gemm1_alpha,
+            "gemm1_beta": self.gemm1_beta,
+            "gemm1_clamp_limit": self.gemm1_clamp_limit,
+            "gemm2_weights": w2,
+            "gemm2_weights_scale": w2_scale,
+            "gemm2_bias": self.w2_bias,
+            "output1_scale_scalar": None,
+            "output1_scale_gate_scalar": None,
+            "output2_scale_scalar": None,
+            "num_experts": global_num_experts,
+            "top_k": topk,
+            "n_group": None,
+            "topk_group": None,
+            "intermediate_size": intermediate_size,
+            "local_expert_offset": local_expert_offset,
+            "local_num_experts": local_num_experts,
+            "routed_scaling_factor": None,
+            "tile_tokens_dim":
+            self._get_tile_tokens_dim(x_quant, topk, local_num_experts),
+            "routing_method_type": 1,
+            "do_finalize": True,
+            "output": output,
+            "tune_max_num_tokens": self.max_capture_size,
+        }
+
+        from flashinfer import trtllm_fp4_block_scale_routed_moe
+        trtllm_fp4_block_scale_routed_moe(**kwargs)
+        return output
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index 4c3e700ad3990..1aeb3f92bc3ea 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
     per_token_group_quant_int8, per_token_quant_int8)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     quant_dequant_mxfp4)
+from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
+    mxfp8_quantize)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv
@@ -177,6 +179,18 @@ def _mxfp4_quantize(
     return A, None
 
 
+def _mxfp8_quantize(
+    A: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    per_act_token_quant: bool,
+    block_shape: Optional[list[int]] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert A_scale is None
+    assert not per_act_token_quant
+    assert block_shape is None
+    return mxfp8_quantize(A)
+
+
 def moe_kernel_quantize_input(
     A: torch.Tensor,
     A_scale: Optional[torch.Tensor],
@@ -195,6 +209,8 @@ def moe_kernel_quantize_input(
             is_sf_swizzled_layout=is_fp4_scale_swizzled)
     elif quant_dtype == "mxfp4":
         return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
+    elif quant_dtype == "mxfp8":
+        return _mxfp8_quantize(A, A_scale, per_act_token_quant, block_shape)
     else:
         return A, A_scale
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 6279bb8b60570..af9d1c46f68f4 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         self,
         prepare_finalize: mk.FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
+        layer: torch.nn.Module,
     ) -> mk.FusedMoEPermuteExpertsUnpermute:
         """Return the appropriate GEMM experts implementation."""
         experts = select_nvfp4_gemm_impl(
@@ -719,10 +720,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             dtype=torch.int64)
 
     def select_gemm_impl(
-        self,
-        prepare_finalize: FusedMoEPrepareAndFinalize,
-        moe: FusedMoEConfig,
-    ) -> FusedMoEPermuteExpertsUnpermute:
+            self, prepare_finalize: FusedMoEPrepareAndFinalize,
+            moe: FusedMoEConfig,
+            layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
         # cutlass path
         if self.use_cutlass:
             from vllm.model_executor.layers.fused_moe import (
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index be358cfa949f0..0200b0e9ed001 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -897,6 +897,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self,
         prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
+        layer: torch.nn.Module,
     ) -> FusedMoEPermuteExpertsUnpermute:
         from vllm.model_executor.layers.fused_moe import (
             BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 72864853f7e0c..adce598c4ff1f 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -311,6 +311,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         self,
         prepare_finalize: mk.FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
+        layer: torch.nn.Module,
     ) -> mk.FusedMoEPermuteExpertsUnpermute:
         experts = select_cutlass_fp8_gemm_impl(
             moe,
@@ -1032,6 +1033,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         self,
         prepare_finalize: mk.FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
+        layer: torch.nn.Module,
     ) -> mk.FusedMoEPermuteExpertsUnpermute:
         experts = select_nvfp4_gemm_impl(
             moe,
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index bdeb169a4b97f..6724796904f01 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -10,6 +10,8 @@ from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase)
+from vllm.model_executor.layers.fused_moe import modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -445,6 +447,91 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         return tile_tokens_dim
 
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        moe: FusedMoEConfig,
+        layer: torch.nn.Module,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
+        if (prepare_finalize.activation_format ==
+                mk.FusedMoEActivationFormat.BatchedExperts):
+            raise NotImplementedError(
+                "Mxfp4 does not support batched experts format for EP")
+        else:
+            if should_use_flashinfer_mxfp4():
+                # B200 code-path
+                kwargs = {
+                    "gemm1_alpha": layer.gemm1_alpha,
+                    "gemm1_beta": layer.gemm1_beta,
+                    "gemm1_clamp_limit": layer.gemm1_clamp_limit,
+                    "w13_bias": layer.w13_bias,
+                    "w2_bias": layer.w2_bias,
+                    "max_capture_size": self.max_capture_size,
+                }
+                return TrtLlmGenExperts(moe, **kwargs)
+            else:
+                # Use matmul_ogs from triton_kernels here!
+                raise NotImplementedError(
+                    "Mxfp4 does not support non-batched experts format for EP")
+
+    def _route_and_experts(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+
+        assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count)
+
+        return self.fused_experts(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+        )
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -503,6 +590,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 activation=activation,
                 expert_map=expert_map)
 
+        if self.fused_experts is not None:
+            return self._route_and_experts(
+                layer,
+                x,
+                router_logits,
+                top_k,
+                renormalize,
+                use_grouped_topk,
+                topk_group,
+                num_expert_group,
+                global_num_experts,
+                expert_map,
+                custom_routing_function,
+                scoring_func,
+                e_score_correction_bias,
+                apply_router_weight_on_input,
+                activation,
+                enable_eplb,
+                expert_load_view,
+                logical_to_physical_map,
+                logical_replica_count,
+            )
+
         assert _can_support_mxfp4(
             use_grouped_topk, topk_group, num_expert_group, expert_map,
             custom_routing_function, e_score_correction_bias,
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index 48f9cc3737e47..3de928fea7202 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -66,11 +66,10 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
                        logical_to_physical_map: Optional[torch.Tensor] = None,
                        logical_replica_count: Optional[torch.Tensor] = None):
     return not (use_grouped_topk or topk_group or num_expert_group
-                or expert_map or custom_routing_function
-                or e_score_correction_bias or apply_router_weight_on_input
-                or scoring_func != "softmax" or activation != "swigluoai"
-                or expert_load_view or logical_to_physical_map
-                or logical_replica_count)
+                or custom_routing_function or e_score_correction_bias
+                or apply_router_weight_on_input or scoring_func != "softmax"
+                or activation != "swigluoai" or expert_load_view
+                or logical_to_physical_map or logical_replica_count)
 
 
 def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
new file mode 100644
index 0000000000000..2a6b21c918f46
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    # Import lazily so vLLM does not hard-depend on flashinfer.
+    try:
+        from flashinfer import mxfp8_quantize as fi_mxfp8_quantize
+    except ImportError as err:
+        raise ImportError("The package `flashinfer` is required to do "
+                          "MX-FP8 quantization. Please install it with "
+                          "`pip install flashinfer`.") from err
+
+    return fi_mxfp8_quantize(x, is_sf_swizzled_layout=False)
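Illustrative usage (not part of the patch): a minimal sketch of the new MX-FP8 activation-quantization helper added in mxfp8_utils.py, assuming flashinfer is installed and a GPU that supports MX formats is available; the tensor shape is made up for the example.

import torch

from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
    mxfp8_quantize)

# Hypothetical activation batch; any 2D bf16 CUDA tensor would do.
hidden_states = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")

# Returns the MX-FP8 tensor and its block scales. In the patch,
# moe_kernel_quantize_input routes here when quant_dtype == "mxfp8", and
# TrtLlmGenExperts.apply later views the scales as float8_e4m3fn and
# reshapes them to one row of scales per token.
x_quant, x_scale = mxfp8_quantize(hidden_states)
print(x_quant.shape, x_scale.shape)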
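Illustrative only (not part of the patch): the packed routing tensor built in TrtLlmGenExperts.apply stores the expert id in the upper 16 bits and the bfloat16 bit pattern of the routing weight in the lower 16 bits of an int32. The sketch below shows the same packing and how it can be decoded again, assuming non-negative routing weights (so the bf16 sign bit is zero).

import torch

topk_ids = torch.tensor([[3, 17]], dtype=torch.int32)
topk_weights = torch.tensor([[0.75, 0.25]], dtype=torch.float32)

# Same packing as in TrtLlmGenExperts.apply.
packed = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(torch.int16)

# Decoding, for illustration: upper 16 bits -> expert id, lower 16 bits
# reinterpreted as bfloat16 -> routing weight.
decoded_ids = packed >> 16
decoded_weights = (packed & 0xFFFF).to(torch.int16).view(torch.bfloat16)
print(decoded_ids)      # tensor([[ 3, 17]], dtype=torch.int32)
print(decoded_weights)  # tensor([[0.7500, 0.2500]], dtype=torch.bfloat16)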