diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index 638741e91619..a6977f222408 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -11,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
     fp8_w8a8_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     apply_flashinfer_per_tensor_scale_fp8,
     flashinfer_cutlass_moe_fp8,
@@ -151,14 +150,11 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
     td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
 
-    topk_weights, topk_ids, _ = FusedMoE.select_experts(
+    topk_weights, topk_ids = Llama4MoE.custom_routing_function(
         hidden_states=td.hidden_states,
-        router_logits=score,
-        use_grouped_topk=False,
-        top_k=topk,
+        gating_output=score,
+        topk=topk,
         renormalize=False,
-        custom_routing_function=Llama4MoE.custom_routing_function,
-        scoring_func="softmax",
     )
 
     quant_config = fp8_w8a8_moe_quant_config(
@@ -219,14 +215,11 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
     )
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
 
-    topk_weights, topk_ids, _ = FusedMoE.select_experts(
+    topk_weights, topk_ids = Llama4MoE.custom_routing_function(
         hidden_states=td.hidden_states,
-        router_logits=score,
-        use_grouped_topk=False,
-        top_k=topk,
+        gating_output=score,
+        topk=topk,
         renormalize=False,
-        custom_routing_function=Llama4MoE.custom_routing_function,
-        scoring_func="softmax",
     )
 
     quant_config = fp8_w8a8_moe_quant_config(
diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py
index 5a162fa8f791..e8826eb441a2 100644
--- a/tests/test_routing_simulator.py
+++ b/tests/test_routing_simulator.py
@@ -9,9 +9,16 @@ different routing strategies and analyze their performance, including
 integration tests with FusedMoE layer.
""" +import tempfile + import pytest import torch +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.distributed import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.fused_moe.routing_simulator import ( DistributionBasedRouting, RoutingSimulator, @@ -89,6 +96,28 @@ def test_routing_strategy_integration(monkeypatch, device): # Test different routing strategies strategies = RoutingSimulator.get_available_strategies() + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + local_rank=0, + distributed_init_method=f"file://{temp_file}", + ) + initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + ) + fused_moe = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=0, + use_grouped_topk=False, + renormalize=True, + ) + for strategy in strategies: # Set environment variable env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY" @@ -98,13 +127,9 @@ def test_routing_strategy_integration(monkeypatch, device): envs.environment_variables[env_name] = lambda s=strategy: s # Test the select_experts method - topk_weights, topk_ids, _ = FusedMoE.select_experts( + topk_weights, topk_ids, _ = fused_moe.select_experts( hidden_states=hidden_states, router_logits=router_logits, - top_k=top_k, - use_grouped_topk=False, - renormalize=True, - indices_type=torch.long, ) # Verify output shapes diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 073e90a4e680..ef7090c349fc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -90,10 +90,14 @@ class FusedMoEMethodBase(QuantizeMethodBase): def allow_inplace(self) -> bool: return False + @property + def method_name(self) -> str: + return self.__class__.__name__ + @abstractmethod def apply( self, - layer: torch.nn.Module, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, top_k: int, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index c6dc95acdb63..c23c41df226f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -66,6 +66,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): def allow_inplace(self) -> bool: return self.old_quant_method.allow_inplace + @property + def method_name(self) -> str: + return self.old_quant_method.method_name + def create_weights( self, layer: torch.nn.Module, @@ -84,7 +88,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): def apply( self, - layer: torch.nn.Module, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, top_k: int, @@ -105,42 +109,9 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - # Is getattr needed? 
-        zero_expert_num = getattr(layer, "zero_expert_num", 0)
-        zero_expert_type = getattr(layer, "zero_expert_type", None)
-
-        if enable_eplb:
-            if self.supports_eplb:
-                assert expert_load_view is not None
-                assert logical_to_physical_map is not None
-                assert logical_replica_count is not None
-            else:
-                raise NotImplementedError(
-                    "EPLB is not supported for "
-                    f"{self.old_quant_method.__class__.__name__}."
-                )
-
         topk_weights, topk_ids, zero_expert_result = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
-            global_num_experts=global_num_experts,
-            zero_expert_num=zero_expert_num,
-            zero_expert_type=zero_expert_type,
         )
 
         result = self.fused_experts(
@@ -156,7 +127,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
             expert_map=None if self.disable_expert_map else expert_map,
         )
 
-        if zero_expert_num != 0 and zero_expert_type is not None:
+        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
             assert not isinstance(result, tuple), (
                 "Shared + zero experts are mutually exclusive not yet supported"
             )
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 6619b64b2bbc..0ef3130b2633 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1510,30 +1510,11 @@ class FusedMoE(CustomOp):
             logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
         )
 
-    @staticmethod
     def select_experts(
+        self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        use_grouped_topk: bool,
-        renormalize: bool,
-        topk_group: int | None = None,
-        num_expert_group: int | None = None,
-        custom_routing_function: Callable | None = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: torch.Tensor | None = None,
-        indices_type: torch.dtype | None = None,
-        enable_eplb: bool = False,
-        expert_map: torch.Tensor | None = None,
-        expert_load_view: torch.Tensor | None = None,
-        logical_to_physical_map: torch.Tensor | None = None,
-        logical_replica_count: torch.Tensor | None = None,
-        global_num_experts: int | None = None,
-        zero_expert_num: int | None = None,
-        zero_expert_type: str | None = None,
-        num_fused_shared_experts: int = 0,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
         """
         Route the input hidden states to the top-k experts based on the
         router logits.
@@ -1552,6 +1533,27 @@ class FusedMoE(CustomOp):
             fused_topk_bias,
         )
 
+        if self.enable_eplb:
+            if self.quant_method.supports_eplb:
+                if self.expert_load_view is None:
+                    raise ValueError(
+                        "enable_eplb=True requires expert_load_view != None"
+                    )
+                if self.logical_to_physical_map is None:
+                    raise ValueError(
+                        "enable_eplb=True requires logical_to_physical_map != None"
+                    )
+                if self.logical_replica_count is None:
+                    raise ValueError(
+                        "enable_eplb=True requires logical_replica_count != None"
+                    )
+            else:
+                raise NotImplementedError(
+                    f"EPLB is not supported for {self.quant_method.method_name}."
+                )
+
+        indices_type = self.quant_method.topk_indices_dtype
+
         # Check if we should use a routing simulation strategy
         routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
         if routing_strategy != "":
@@ -1559,20 +1561,20 @@ class FusedMoE(CustomOp):
                 hidden_states=hidden_states,
                 router_logits=router_logits,
                 strategy_name=routing_strategy,
-                top_k=top_k,
+                top_k=self.top_k,
                 indices_type=indices_type,
             )
 
         # DeepSeekv2 uses grouped_top_k
-        elif use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
+        elif self.use_grouped_topk:
+            assert self.topk_group is not None
+            assert self.num_expert_group is not None
             if rocm_aiter_ops.is_fused_moe_enabled():
                 if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
-                    assert num_fused_shared_experts == 0
+                    assert self.num_fused_shared_experts == 0
                 grouped_topk_impl = partial(
                     rocm_aiter_grouped_topk,
-                    num_fused_shared_experts=num_fused_shared_experts,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
                 )
             else:
                 grouped_topk_impl = grouped_topk
@@ -1580,50 +1582,46 @@ class FusedMoE(CustomOp):
             topk_weights, topk_ids = grouped_topk_impl(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-                scoring_func=scoring_func,
-                routed_scaling_factor=routed_scaling_factor,
-                e_score_correction_bias=e_score_correction_bias,
+                topk=self.top_k,
+                renormalize=self.renormalize,
+                num_expert_group=self.num_expert_group,
+                topk_group=self.topk_group,
+                scoring_func=self.scoring_func,
+                routed_scaling_factor=self.routed_scaling_factor,
+                e_score_correction_bias=self.e_score_correction_bias,
             )
-        elif e_score_correction_bias is not None:
+        elif self.e_score_correction_bias is not None:
             topk_weights, topk_ids = fused_topk_bias(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
-                e_score_correction_bias=e_score_correction_bias.data,
-                topk=top_k,
-                renormalize=renormalize,
+                e_score_correction_bias=self.e_score_correction_bias.data,
+                topk=self.top_k,
+                renormalize=self.renormalize,
             )
-            if routed_scaling_factor != 1.0:
-                topk_weights *= routed_scaling_factor
-        elif custom_routing_function is None:
+            if self.routed_scaling_factor != 1.0:
+                topk_weights *= self.routed_scaling_factor
+        elif self.custom_routing_function is None:
             topk_weights, topk_ids, token_expert_indices = fused_topk(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
+                topk=self.top_k,
+                renormalize=self.renormalize,
                 indices_type=indices_type,
             )
         else:
-            topk_weights, topk_ids = custom_routing_function(
+            topk_weights, topk_ids = self.custom_routing_function(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
+                topk=self.top_k,
+                renormalize=self.renormalize,
             )
 
-        if enable_eplb:
-            assert expert_load_view is not None
-            assert logical_to_physical_map is not None
-            assert logical_replica_count is not None
-
+        if self.enable_eplb:
             topk_ids = eplb_map_to_physical_and_record(
                 topk_ids=topk_ids,
-                expert_load_view=expert_load_view,
-                logical_to_physical_map=logical_to_physical_map,
-                logical_replica_count=logical_replica_count,
+                expert_load_view=self.expert_load_view,
+                logical_to_physical_map=self.logical_to_physical_map,
+                logical_replica_count=self.logical_replica_count,
             )
 
         if (indices_type is not None) and topk_ids.dtype != indices_type:
@@ -1633,16 +1631,16 @@ class FusedMoE(CustomOp):
 
         # Compute zero expert result if needed
         if (
-            zero_expert_num is not None
-            and zero_expert_num > 0
-            and zero_expert_type is not None
-            and global_num_experts is not None
+            self.zero_expert_num is not None
+            and self.zero_expert_num > 0
+            and self.zero_expert_type is not None
+            and self.global_num_experts is not None
         ):
             zero_expert_result = zero_experts_compute_triton(
                 expert_indices=topk_ids,
                 expert_scales=topk_weights,
-                num_experts=global_num_experts,
-                zero_expert_type=zero_expert_type,
+                num_experts=self.global_num_experts,
+                zero_expert_type=self.zero_expert_type,
                 hidden_states=hidden_states,
             )
         else:
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 63b0e6f573d6..48e5a8907f92 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -331,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
     def forward_cuda(
         self,
-        layer: torch.nn.Module,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
         x: torch.Tensor,
         use_grouped_topk: bool,
         top_k: int,
@@ -352,31 +352,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        zero_expert_num = getattr(layer, "zero_expert_num", 0)
-        zero_expert_type = getattr(layer, "zero_expert_type", None)
-
         topk_weights, topk_ids, zero_expert_result = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
-            global_num_experts=global_num_experts,
-            zero_expert_num=zero_expert_num,
-            zero_expert_type=zero_expert_type,
-            num_fused_shared_experts=layer.num_fused_shared_experts,
         )
 
         if self.rocm_aiter_moe_enabled:
@@ -415,7 +393,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             expert_map=expert_map,
         )
 
-        if zero_expert_num != 0 and zero_expert_type is not None:
+        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
             assert not isinstance(result, tuple), (
                 "Shared + zero experts are mutually exclusive not yet supported"
             )
@@ -425,7 +403,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
     def forward_cpu(
         self,
-        layer: torch.nn.Module,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
         x: torch.Tensor,
         use_grouped_topk: bool,
         top_k: int,
@@ -474,7 +452,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
     def forward_xpu(
         self,
-        layer: torch.nn.Module,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
         x: torch.Tensor,
         use_grouped_topk: bool,
         top_k: int,
@@ -515,7 +493,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
     def forward_tpu(
         self,
-        layer: torch.nn.Module,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
         x: torch.Tensor,
         use_grouped_topk: bool,
         top_k: int,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 3f6ea68072b4..66945e2d2a7c 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -597,7 +597,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -618,24 +618,11 @@ class AWQMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.")
-
         assert activation == "silu", "Only SiLU activation is supported."
 
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         return fused_marlin_moe(
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index e5a741e639ad..1e57fa218b79 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -495,7 +495,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -518,25 +518,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `BitsAndBytesMoEMethod` yet."
-            )
-
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
+
+        # TODO(bnell): Do these need to be called on the hot path?
         if self.quant_config.load_in_8bit:
             w13, w2 = self._apply_8bit_dequant(layer)
         else:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index ad547dd40982..149e4419c64a 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -511,7 +511,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
 
     def apply(
        self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -532,16 +532,17 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
-            )
         assert activation == "silu", "Only SiLU activation is supported."
 
         if (
             self.allow_flashinfer
             and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
+            if enable_eplb:
+                raise NotImplementedError(
+                    "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
+                )
+
             return flashinfer_trtllm_fp4_moe(
                 layer=layer,
                 x=x,
@@ -554,19 +555,9 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
                 e_score_correction_bias=e_score_correction_bias,
             )
 
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         if self.use_marlin:
@@ -1109,7 +1100,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -1130,31 +1121,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            assert expert_load_view is not None
-            assert logical_to_physical_map is not None
-            assert logical_replica_count is not None
-            assert isinstance(layer, FusedMoE)
-
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            num_fused_shared_experts=layer.num_fused_shared_experts,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
         )
 
         per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
@@ -1377,7 +1346,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -1398,26 +1367,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet."
-            )
-
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         return fused_experts(
@@ -1738,7 +1692,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -1759,26 +1713,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet."
-            )
-
         assert activation == "silu", f"{activation} not supported for Marlin MoE."
 
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         return fused_marlin_moe(
@@ -2001,7 +1940,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -2022,43 +1961,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            if expert_load_view is None:
-                raise ValueError("enable_eplb=True requiere expert_load_view != None")
-            if logical_to_physical_map is None:
-                raise ValueError(
-                    "enable_eplb=True requiere logical_to_physical_map != None"
-                )
-            if logical_replica_count is None:
-                raise ValueError(
-                    "enable_eplb=True requiere logical_replica_count != None"
-                )
-            if not isinstance(layer, FusedMoE):
-                raise TypeError(
-                    "EPLB is only supported when `layer` is a instance of FusedMoE."
-                )
-
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            num_fused_shared_experts=getattr(layer, "num_fused_shared_experts", 0),
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
         )
 
         return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 5241f9a2301b..7ebe40ec8468 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -137,7 +137,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -158,26 +158,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `ExpertsInt8MoEMethod` yet."
-            )
-
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 91bd45bf879c..9e2718057038 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1140,7 +1140,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -1216,31 +1216,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
 
-        zero_expert_num = getattr(layer, "zero_expert_num", 0)
-        zero_expert_type = getattr(layer, "zero_expert_type", None)
-
-        select_result = FusedMoE.select_experts(
+        select_result = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
-            global_num_experts=global_num_experts,
-            zero_expert_num=zero_expert_num,
-            zero_expert_type=zero_expert_type,
-            num_fused_shared_experts=layer.num_fused_shared_experts,
         )
 
         topk_weights, topk_ids, zero_expert_result = select_result
@@ -1322,7 +1300,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.allow_cutlass_block_scaled_grouped_gemm
             ),
         )
-        if zero_expert_num != 0 and zero_expert_type is not None:
+
+        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
             assert not isinstance(result, tuple), (
                 "Shared + zero experts are mutually exclusive not yet supported"
             )
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 42d7a67371ae..bcdfafb50fc5 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -621,7 +621,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -642,9 +642,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.")
-
         assert activation == "silu", "Only SiLU activation is supported."
         if apply_router_weight_on_input:
             raise NotImplementedError(
@@ -652,19 +649,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
                 "Apply router weight on input is not supported for "
                 "fused GGUF MoE method."
             )
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
         return fused_moe_gguf(
             x,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 68a122fd46c6..77b15db373a3 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -722,7 +722,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -743,26 +743,11 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `GPTQMarlinMoEMethod` yet."
-            )
-
         assert activation == "silu", "Only SiLU activation is supported."
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         return fused_marlin_moe(
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 01a23168bdde..816567313591 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -696,7 +696,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -717,12 +717,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `ModelOptFp8MoEMethod` yet."
-            )
-
         if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+            if layer.enable_eplb:
+                raise NotImplementedError(
+                    "EPLB not supported for `ModelOptFp8MoEMethod` yet."
+                )
             assert activation == "silu", (
                 f"Expected 'silu' activation but got {activation}"
             )
@@ -740,19 +739,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             )
 
         # Expert selection
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
@@ -1459,7 +1448,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -1480,16 +1469,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
-            )
         assert activation == "silu", "Only SiLU activation is supported."
 
         if (
             self.allow_flashinfer
             and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
+            if enable_eplb:
+                raise NotImplementedError(
+                    "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
+                )
             return flashinfer_trtllm_fp4_moe(
                 layer=layer,
                 x=x,
@@ -1502,19 +1491,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 e_score_correction_bias=e_score_correction_bias,
             )
 
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         if self.use_marlin:
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 2090c86f78dc..cf348290a271 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -359,7 +359,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -380,25 +380,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.")
-
         from vllm.model_executor.layers.fused_moe import fused_experts
 
         assert activation == "silu", "Only SiLU activation is supported."
 
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 66ae2e94c60a..255b5aad1785 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -862,7 +862,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -887,18 +887,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             raise NotImplementedError("EPLB is not supported for mxfp4")
 
         if self.mxfp4_backend == Mxfp4Backend.MARLIN:
-            topk_weights, topk_ids, _ = FusedMoE.select_experts(
+            topk_weights, topk_ids, _ = layer.select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                routed_scaling_factor=routed_scaling_factor,
-                e_score_correction_bias=e_score_correction_bias,
             )
 
             return fused_marlin_moe(
@@ -989,17 +980,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         ):
             from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
 
-            topk_weights, topk_ids, _ = FusedMoE.select_experts(
+            topk_weights, topk_ids, _ = layer.select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
             )
 
             # Backend-specific preparation
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 30772c3665b0..8be0299eaa66 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -334,7 +334,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -355,24 +355,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet."
-            )
-
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         if self.rocm_aiter_moe_enabled:
@@ -609,7 +594,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -630,24 +615,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
-            )
-
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         if not self.emulate:
diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py
index 52656263a601..7b51b828009f 100644
--- a/vllm/model_executor/layers/quantization/rtn.py
+++ b/vllm/model_executor/layers/quantization/rtn.py
@@ -356,7 +356,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -377,22 +377,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")
-
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
 
         return fused_marlin_moe(
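
Note (reviewer sketch, not part of the patch): after this refactor, FusedMoE.select_experts is an instance method, and every FusedMoEMethodBase.apply implementation reduces to the same two-argument call; the routing configuration (top_k, renormalize, grouped top-k settings, EPLB state, zero experts, indices dtype) is read from the layer itself. A minimal sketch of the new calling convention, assuming `layer` is an initialized FusedMoE and `x`/`router_logits` are appropriately shaped tensors; `my_fused_experts` is a hypothetical stand-in for whichever backend kernel the method dispatches to:

    # Inside a FusedMoEMethodBase.apply implementation:
    topk_weights, topk_ids, zero_expert_result = layer.select_experts(
        hidden_states=x,              # (num_tokens, hidden_size)
        router_logits=router_logits,  # (num_tokens, global_num_experts)
    )
    # topk_ids dtype follows layer.quant_method.topk_indices_dtype; per the new
    # return annotation the third element is Optional and only carries a tensor
    # when the layer is configured with zero experts.
    out = my_fused_experts(x, topk_weights, topk_ids)  # hypothetical kernel call

EPLB validation (the expert_load_view / logical_to_physical_map / logical_replica_count checks) and the physical-expert mapping also happen inside select_experts now, which is why the per-method "EPLB not supported" guards could be deleted from each quantization backend.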