[MoE][Refactor] Make select_experts a non-static method (#29067)

Signed-off-by: Bill Nell <bnell@redhat.com>
Authored by bnellnm on 2025-11-24 13:38:04 -05:00; committed by GitHub
parent cec418b5df
commit 8f066146c3
18 changed files with 163 additions and 472 deletions
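
In short: select_experts was a static method on FusedMoE that had the full routing configuration passed in as arguments; after this change it is an instance method that reads that configuration from the layer itself. A condensed before/after sketch of a typical call site (argument lists abbreviated from the hunks below):

Before (static helper, routing parameters threaded through explicitly):

    topk_weights, topk_ids, _ = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
    )

After (instance method; top_k, renormalization, grouping, scoring, and EPLB state come from the layer's own attributes):

    topk_weights, topk_ids, _ = layer.select_experts(
        hidden_states=x,
        router_logits=router_logits,
    )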

View File

@ -11,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8,
flashinfer_cutlass_moe_fp8,
@ -151,14 +150,11 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids = Llama4MoE.custom_routing_function(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
top_k=topk,
gating_output=score,
topk=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
@ -219,14 +215,11 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids = Llama4MoE.custom_routing_function(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
top_k=topk,
gating_output=score,
topk=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
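
With the static routing helper gone, the two flashinfer FP8 tests above no longer go through FusedMoE for routing at all; condensed, they call the model's routing function directly (names as in the hunks above):

    topk_weights, topk_ids = Llama4MoE.custom_routing_function(
        hidden_states=td.hidden_states,
        gating_output=score,
        topk=topk,
        renormalize=False,
    )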

View File

@ -9,9 +9,16 @@ different routing strategies and analyze their performance, including
integration tests with the FusedMoE layer.
"""
import tempfile
import pytest
import torch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import (
DistributionBasedRouting,
RoutingSimulator,
@ -89,6 +96,28 @@ def test_routing_strategy_integration(monkeypatch, device):
# Test different routing strategies
strategies = RoutingSimulator.get_available_strategies()
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
)
initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
fused_moe = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=0,
use_grouped_topk=False,
renormalize=True,
)
for strategy in strategies:
# Set environment variable
env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
@ -98,13 +127,9 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s
# Test the select_experts method
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = fused_moe.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=False,
renormalize=True,
indices_type=torch.long,
)
# Verify output shapes
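
Because select_experts is now an instance method, the routing-simulator test needs a real FusedMoE layer, hence the distributed-environment and VllmConfig setup added above. Condensed, the per-strategy loop body becomes (a sketch; fixture values are as in the hunk above):

    env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
    envs.environment_variables[env_name] = lambda s=strategy: s
    # top_k, use_grouped_topk, and renormalize are now read from the layer's own config
    topk_weights, topk_ids, _ = fused_moe.select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
    )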

View File

@ -90,10 +90,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def allow_inplace(self) -> bool:
return False
@property
def method_name(self) -> str:
return self.__class__.__name__
@abstractmethod
def apply(
self,
layer: torch.nn.Module,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,

View File

@ -66,6 +66,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def allow_inplace(self) -> bool:
return self.old_quant_method.allow_inplace
@property
def method_name(self) -> str:
return self.old_quant_method.method_name
def create_weights(
self,
layer: torch.nn.Module,
@ -84,7 +88,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def apply(
self,
layer: torch.nn.Module,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -105,42 +109,9 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# Is getattr needed?
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
if enable_eplb:
if self.supports_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
else:
raise NotImplementedError(
"EPLB is not supported for "
f"{self.old_quant_method.__class__.__name__}."
)
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
)
result = self.fused_experts(
@ -156,7 +127,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
expert_map=None if self.disable_expert_map else expert_map,
)
if zero_expert_num != 0 and zero_expert_type is not None:
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)

View File

@ -1510,30 +1510,11 @@ class FusedMoE(CustomOp):
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
@staticmethod
def select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
indices_type: torch.dtype | None = None,
enable_eplb: bool = False,
expert_map: torch.Tensor | None = None,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
global_num_experts: int | None = None,
zero_expert_num: int | None = None,
zero_expert_type: str | None = None,
num_fused_shared_experts: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
@ -1552,6 +1533,27 @@ class FusedMoE(CustomOp):
fused_topk_bias,
)
if self.enable_eplb:
if self.quant_method.supports_eplb:
if self.expert_load_view is None:
raise ValueError(
"enable_eplb=True requiere expert_load_view != None"
)
if self.logical_to_physical_map is None:
raise ValueError(
"enable_eplb=True requiere logical_to_physical_map != None"
)
if self.logical_replica_count is None:
raise ValueError(
"enable_eplb=True requiere logical_replica_count != None"
)
else:
raise NotImplementedError(
f"EPLB is not supported for {self.quant_method.method_name}."
)
indices_type = self.quant_method.topk_indices_dtype
# Check if we should use a routing simulation strategy
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
@ -1559,20 +1561,20 @@ class FusedMoE(CustomOp):
hidden_states=hidden_states,
router_logits=router_logits,
strategy_name=routing_strategy,
top_k=top_k,
top_k=self.top_k,
indices_type=indices_type,
)
# DeepSeekv2 uses grouped_top_k
elif use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
elif self.use_grouped_topk:
assert self.topk_group is not None
assert self.num_expert_group is not None
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert num_fused_shared_experts == 0
assert self.num_fused_shared_experts == 0
grouped_topk_impl = partial(
rocm_aiter_grouped_topk,
num_fused_shared_experts=num_fused_shared_experts,
num_fused_shared_experts=self.num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
@ -1580,50 +1582,46 @@ class FusedMoE(CustomOp):
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
)
elif e_score_correction_bias is not None:
elif self.e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=e_score_correction_bias.data,
topk=top_k,
renormalize=renormalize,
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
)
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
elif custom_routing_function is None:
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
elif self.custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
topk=self.top_k,
renormalize=self.renormalize,
indices_type=indices_type,
)
else:
topk_weights, topk_ids = custom_routing_function(
topk_weights, topk_ids = self.custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
topk=self.top_k,
renormalize=self.renormalize,
)
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
if self.enable_eplb:
topk_ids = eplb_map_to_physical_and_record(
topk_ids=topk_ids,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
if (indices_type is not None) and topk_ids.dtype != indices_type:
@ -1633,16 +1631,16 @@ class FusedMoE(CustomOp):
# Compute zero expert result if needed
if (
zero_expert_num is not None
and zero_expert_num > 0
and zero_expert_type is not None
and global_num_experts is not None
self.zero_expert_num is not None
and self.zero_expert_num > 0
and self.zero_expert_type is not None
and self.global_num_experts is not None
):
zero_expert_result = zero_experts_compute_triton(
expert_indices=topk_ids,
expert_scales=topk_weights,
num_experts=global_num_experts,
zero_expert_type=zero_expert_type,
num_experts=self.global_num_experts,
zero_expert_type=self.zero_expert_type,
hidden_states=hidden_states,
)
else:
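
The MoE method implementations below (unquantized, AWQ, BitsAndBytes, compressed-tensors, ExpertsInt8, FP8, GGUF, GPTQ-Marlin, ModelOpt, MoeWNA16, MXFP4, Quark, RTN) all change the same way: apply() stops re-threading the routing arguments through the old static helper and simply asks the layer to route. Condensed (a sketch, not any one method's full signature):

    def apply(self, layer, x, router_logits, top_k, renormalize, **kwargs):
        # routing config (top_k, grouping, scoring, EPLB state) now lives on the layer
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )
        # topk_weights / topk_ids then feed the backend-specific fused-experts kernel
        ...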

View File

@ -331,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cuda(
self,
layer: torch.nn.Module,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@ -352,31 +352,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
num_fused_shared_experts=layer.num_fused_shared_experts,
)
if self.rocm_aiter_moe_enabled:
@ -415,7 +393,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
)
if zero_expert_num != 0 and zero_expert_type is not None:
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
@ -425,7 +403,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cpu(
self,
layer: torch.nn.Module,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@ -474,7 +452,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_xpu(
self,
layer: torch.nn.Module,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@ -515,7 +493,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_tpu(
self,
layer: torch.nn.Module,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,

View File

@ -597,7 +597,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -618,24 +618,11 @@ class AWQMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(

View File

@ -495,7 +495,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -518,25 +518,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `BitsAndBytesMoEMethod` yet."
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
# TODO(bnell): Do these need to be called on the hot path?
if self.quant_config.load_in_8bit:
w13, w2 = self._apply_8bit_dequant(layer)
else:

View File

@ -511,7 +511,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -532,16 +532,17 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
)
assert activation == "silu", "Only SiLU activation is supported."
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
@ -554,19 +555,9 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
if self.use_marlin:
@ -1109,7 +1100,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -1130,31 +1121,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
num_fused_shared_experts=layer.num_fused_shared_experts,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
@ -1377,7 +1346,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -1398,26 +1367,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet."
)
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
return fused_experts(
@ -1738,7 +1692,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -1759,26 +1713,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet."
)
assert activation == "silu", f"{activation} not supported for Marlin MoE."
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
@ -2001,7 +1940,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -2022,43 +1961,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
if expert_load_view is None:
raise ValueError("enable_eplb=True requiere expert_load_view != None")
if logical_to_physical_map is None:
raise ValueError(
"enable_eplb=True requiere logical_to_physical_map != None"
)
if logical_replica_count is None:
raise ValueError(
"enable_eplb=True requiere logical_replica_count != None"
)
if not isinstance(layer, FusedMoE):
raise TypeError(
"EPLB is only supported when `layer` is a instance of FusedMoE."
)
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
num_fused_shared_experts=getattr(layer, "num_fused_shared_experts", 0),
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
return fused_experts(

View File

@ -137,7 +137,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -158,26 +158,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ExpertsInt8MoEMethod` yet."
)
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
return fused_experts(

View File

@ -1140,7 +1140,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -1216,31 +1216,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
)
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
select_result = FusedMoE.select_experts(
select_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
num_fused_shared_experts=layer.num_fused_shared_experts,
)
topk_weights, topk_ids, zero_expert_result = select_result
@ -1322,7 +1300,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.allow_cutlass_block_scaled_grouped_gemm
),
)
if zero_expert_num != 0 and zero_expert_type is not None:
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)

View File

@ -621,7 +621,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -642,9 +642,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
@ -652,19 +649,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"fused GGUF MoE method."
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
return fused_moe_gguf(
x,

View File

@ -722,7 +722,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -743,26 +743,11 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `GPTQMarlinMoEMethod` yet."
)
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(

View File

@ -696,7 +696,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -717,12 +717,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet."
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet."
)
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)
@ -740,19 +739,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
# Expert selection
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
@ -1459,7 +1448,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -1480,16 +1469,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
assert activation == "silu", "Only SiLU activation is supported."
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
@ -1502,19 +1491,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
if self.use_marlin:

View File

@ -359,7 +359,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -380,25 +380,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.")
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
return fused_experts(

View File

@ -862,7 +862,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -887,18 +887,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
return fused_marlin_moe(
@ -989,17 +980,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
# Backend-specific preparation

View File

@ -334,7 +334,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -355,24 +355,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet."
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
if self.rocm_aiter_moe_enabled:
@ -609,7 +594,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -630,24 +615,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
if not self.emulate:

View File

@ -356,7 +356,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -377,22 +377,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")
topk_weights, topk_ids, _ = FusedMoE.select_experts(
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(