Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 04:34:59 +08:00)
[MoE][Refactor] Make select_experts a non-static method (#29067)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in: parent cec418b5df, commit 8f066146c3
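For orientation before the diff: the refactor turns FusedMoE.select_experts from a @staticmethod that received every routing parameter on each call into an instance method that reads the routing configuration (top_k, renormalize, scoring_func, EPLB state, and so on) from the layer itself. The sketch below illustrates the resulting call shape at a typical call site; the helper name route_tokens is hypothetical, and only the layer.select_experts call with its two keyword arguments and three-element return are taken from the diff.

import torch

from vllm.model_executor.layers.fused_moe.layer import FusedMoE


def route_tokens(
    layer: FusedMoE,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    # Before this commit: FusedMoE.select_experts(hidden_states=..., router_logits=...,
    # top_k=..., use_grouped_topk=..., renormalize=..., ...) with roughly twenty arguments.
    # After this commit: the layer instance supplies that configuration itself.
    topk_weights, topk_ids, zero_expert_result = layer.select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
    )
    return topk_weights, topk_ids, zero_expert_result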
@@ -11,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
     fp8_w8a8_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     apply_flashinfer_per_tensor_scale_fp8,
     flashinfer_cutlass_moe_fp8,
@@ -151,14 +150,11 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
     td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)

     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
-    topk_weights, topk_ids, _ = FusedMoE.select_experts(
+    topk_weights, topk_ids = Llama4MoE.custom_routing_function(
         hidden_states=td.hidden_states,
-        router_logits=score,
-        use_grouped_topk=False,
-        top_k=topk,
+        gating_output=score,
+        topk=topk,
         renormalize=False,
-        custom_routing_function=Llama4MoE.custom_routing_function,
-        scoring_func="softmax",
     )

     quant_config = fp8_w8a8_moe_quant_config(
@@ -219,14 +215,11 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
     )

     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
-    topk_weights, topk_ids, _ = FusedMoE.select_experts(
+    topk_weights, topk_ids = Llama4MoE.custom_routing_function(
         hidden_states=td.hidden_states,
-        router_logits=score,
-        use_grouped_topk=False,
-        top_k=topk,
+        gating_output=score,
+        topk=topk,
         renormalize=False,
-        custom_routing_function=Llama4MoE.custom_routing_function,
-        scoring_func="softmax",
     )

     quant_config = fp8_w8a8_moe_quant_config(
@@ -9,9 +9,16 @@ different routing strategies and analyze their performance, including
 integration tests with FusedMoE layer.
 """

+import tempfile
+
 import pytest
 import torch

+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import (
+    init_distributed_environment,
+    initialize_model_parallel,
+)
 from vllm.model_executor.layers.fused_moe.routing_simulator import (
     DistributionBasedRouting,
     RoutingSimulator,
@@ -89,6 +96,28 @@ def test_routing_strategy_integration(monkeypatch, device):
     # Test different routing strategies
     strategies = RoutingSimulator.get_available_strategies()

+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            world_size=1,
+            rank=0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+        )
+        initialize_model_parallel(
+            tensor_model_parallel_size=1,
+            pipeline_model_parallel_size=1,
+        )
+        fused_moe = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=0,
+            use_grouped_topk=False,
+            renormalize=True,
+        )
+
     for strategy in strategies:
         # Set environment variable
         env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
@@ -98,13 +127,9 @@ def test_routing_strategy_integration(monkeypatch, device):
         envs.environment_variables[env_name] = lambda s=strategy: s

         # Test the select_experts method
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = fused_moe.select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
-            top_k=top_k,
-            use_grouped_topk=False,
-            renormalize=True,
-            indices_type=torch.long,
         )

         # Verify output shapes
@@ -90,10 +90,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
     def allow_inplace(self) -> bool:
         return False

+    @property
+    def method_name(self) -> str:
+        return self.__class__.__name__
+
     @abstractmethod
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -66,6 +66,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
     def allow_inplace(self) -> bool:
         return self.old_quant_method.allow_inplace

+    @property
+    def method_name(self) -> str:
+        return self.old_quant_method.method_name
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -84,7 +88,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -105,42 +109,9 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        # Is getattr needed?
-        zero_expert_num = getattr(layer, "zero_expert_num", 0)
-        zero_expert_type = getattr(layer, "zero_expert_type", None)
-
-        if enable_eplb:
-            if self.supports_eplb:
-                assert expert_load_view is not None
-                assert logical_to_physical_map is not None
-                assert logical_replica_count is not None
-            else:
-                raise NotImplementedError(
-                    "EPLB is not supported for "
-                    f"{self.old_quant_method.__class__.__name__}."
-                )
-
         topk_weights, topk_ids, zero_expert_result = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
-            global_num_experts=global_num_experts,
-            zero_expert_num=zero_expert_num,
-            zero_expert_type=zero_expert_type,
         )

         result = self.fused_experts(
@@ -156,7 +127,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
             expert_map=None if self.disable_expert_map else expert_map,
         )

-        if zero_expert_num != 0 and zero_expert_type is not None:
+        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
             assert not isinstance(result, tuple), (
                 "Shared + zero experts are mutually exclusive not yet supported"
             )
@@ -1510,30 +1510,11 @@ class FusedMoE(CustomOp):
             logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
         )

-    @staticmethod
     def select_experts(
+        self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        use_grouped_topk: bool,
-        renormalize: bool,
-        topk_group: int | None = None,
-        num_expert_group: int | None = None,
-        custom_routing_function: Callable | None = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: torch.Tensor | None = None,
-        indices_type: torch.dtype | None = None,
-        enable_eplb: bool = False,
-        expert_map: torch.Tensor | None = None,
-        expert_load_view: torch.Tensor | None = None,
-        logical_to_physical_map: torch.Tensor | None = None,
-        logical_replica_count: torch.Tensor | None = None,
-        global_num_experts: int | None = None,
-        zero_expert_num: int | None = None,
-        zero_expert_type: str | None = None,
-        num_fused_shared_experts: int = 0,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
         """
         Route the input hidden states to the top-k experts based on the
         router logits.
@@ -1552,6 +1533,27 @@ class FusedMoE(CustomOp):
             fused_topk_bias,
         )

+        if self.enable_eplb:
+            if self.quant_method.supports_eplb:
+                if self.expert_load_view is None:
+                    raise ValueError(
+                        "enable_eplb=True requiere expert_load_view != None"
+                    )
+                if self.logical_to_physical_map is None:
+                    raise ValueError(
+                        "enable_eplb=True requiere logical_to_physical_map != None"
+                    )
+                if self.logical_replica_count is None:
+                    raise ValueError(
+                        "enable_eplb=True requiere logical_replica_count != None"
+                    )
+            else:
+                raise NotImplementedError(
+                    f"EPLB is not supported for {self.quant_method.method_name}."
+                )
+
+        indices_type = self.quant_method.topk_indices_dtype
+
         # Check if we should use a routing simulation strategy
         routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
         if routing_strategy != "":
@@ -1559,20 +1561,20 @@ class FusedMoE(CustomOp):
                 hidden_states=hidden_states,
                 router_logits=router_logits,
                 strategy_name=routing_strategy,
-                top_k=top_k,
+                top_k=self.top_k,
                 indices_type=indices_type,
             )

         # DeepSeekv2 uses grouped_top_k
-        elif use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
+        elif self.use_grouped_topk:
+            assert self.topk_group is not None
+            assert self.num_expert_group is not None
             if rocm_aiter_ops.is_fused_moe_enabled():
                 if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
-                    assert num_fused_shared_experts == 0
+                    assert self.num_fused_shared_experts == 0
                 grouped_topk_impl = partial(
                     rocm_aiter_grouped_topk,
-                    num_fused_shared_experts=num_fused_shared_experts,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
                 )
             else:
                 grouped_topk_impl = grouped_topk
@@ -1580,50 +1582,46 @@ class FusedMoE(CustomOp):
             topk_weights, topk_ids = grouped_topk_impl(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-                scoring_func=scoring_func,
-                routed_scaling_factor=routed_scaling_factor,
-                e_score_correction_bias=e_score_correction_bias,
+                topk=self.top_k,
+                renormalize=self.renormalize,
+                num_expert_group=self.num_expert_group,
+                topk_group=self.topk_group,
+                scoring_func=self.scoring_func,
+                routed_scaling_factor=self.routed_scaling_factor,
+                e_score_correction_bias=self.e_score_correction_bias,
             )
-        elif e_score_correction_bias is not None:
+        elif self.e_score_correction_bias is not None:
             topk_weights, topk_ids = fused_topk_bias(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
-                e_score_correction_bias=e_score_correction_bias.data,
-                topk=top_k,
-                renormalize=renormalize,
+                e_score_correction_bias=self.e_score_correction_bias.data,
+                topk=self.top_k,
+                renormalize=self.renormalize,
             )
-            if routed_scaling_factor != 1.0:
-                topk_weights *= routed_scaling_factor
-        elif custom_routing_function is None:
+            if self.routed_scaling_factor != 1.0:
+                topk_weights *= self.routed_scaling_factor
+        elif self.custom_routing_function is None:
             topk_weights, topk_ids, token_expert_indices = fused_topk(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
+                topk=self.top_k,
+                renormalize=self.renormalize,
                 indices_type=indices_type,
             )
         else:
-            topk_weights, topk_ids = custom_routing_function(
+            topk_weights, topk_ids = self.custom_routing_function(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
+                topk=self.top_k,
+                renormalize=self.renormalize,
             )

-        if enable_eplb:
-            assert expert_load_view is not None
-            assert logical_to_physical_map is not None
-            assert logical_replica_count is not None
-
+        if self.enable_eplb:
             topk_ids = eplb_map_to_physical_and_record(
                 topk_ids=topk_ids,
-                expert_load_view=expert_load_view,
-                logical_to_physical_map=logical_to_physical_map,
-                logical_replica_count=logical_replica_count,
+                expert_load_view=self.expert_load_view,
+                logical_to_physical_map=self.logical_to_physical_map,
+                logical_replica_count=self.logical_replica_count,
             )

         if (indices_type is not None) and topk_ids.dtype != indices_type:
@@ -1633,16 +1631,16 @@ class FusedMoE(CustomOp):

         # Compute zero expert result if needed
         if (
-            zero_expert_num is not None
-            and zero_expert_num > 0
-            and zero_expert_type is not None
-            and global_num_experts is not None
+            self.zero_expert_num is not None
+            and self.zero_expert_num > 0
+            and self.zero_expert_type is not None
+            and self.global_num_experts is not None
         ):
             zero_expert_result = zero_experts_compute_triton(
                 expert_indices=topk_ids,
                 expert_scales=topk_weights,
-                num_experts=global_num_experts,
-                zero_expert_type=zero_expert_type,
+                num_experts=self.global_num_experts,
+                zero_expert_type=self.zero_expert_type,
                 hidden_states=hidden_states,
             )
         else:
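To summarize the hunks above in one place: after the change, select_experts keeps only the activation and router-logit inputs as arguments. The skeleton below is illustrative only (the body is elided); the attribute names listed in the docstring are the ones the rewritten method reads, as shown in the hunks.

import torch


class FusedMoE:  # stand-in for vllm.model_executor.layers.fused_moe.layer.FusedMoE
    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Skeleton only; the real implementation is in the diff above.

        Routing configuration now comes from instance attributes such as
        self.top_k, self.use_grouped_topk, self.renormalize, self.topk_group,
        self.num_expert_group, self.scoring_func, self.routed_scaling_factor,
        self.e_score_correction_bias, self.custom_routing_function and
        self.num_fused_shared_experts; EPLB state from self.enable_eplb,
        self.expert_load_view, self.logical_to_physical_map and
        self.logical_replica_count; the index dtype from
        self.quant_method.topk_indices_dtype; and zero-expert handling from
        self.zero_expert_num, self.zero_expert_type and self.global_num_experts.
        """
        raise NotImplementedError("illustrative skeleton only")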
@@ -331,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

     def forward_cuda(
         self,
-        layer: torch.nn.Module,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
         x: torch.Tensor,
         use_grouped_topk: bool,
         top_k: int,
@@ -352,31 +352,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        zero_expert_num = getattr(layer, "zero_expert_num", 0)
-        zero_expert_type = getattr(layer, "zero_expert_type", None)
-
         topk_weights, topk_ids, zero_expert_result = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
-            global_num_experts=global_num_experts,
-            zero_expert_num=zero_expert_num,
-            zero_expert_type=zero_expert_type,
-            num_fused_shared_experts=layer.num_fused_shared_experts,
         )

         if self.rocm_aiter_moe_enabled:
@@ -415,7 +393,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             expert_map=expert_map,
         )

-        if zero_expert_num != 0 and zero_expert_type is not None:
+        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
             assert not isinstance(result, tuple), (
                 "Shared + zero experts are mutually exclusive not yet supported"
             )
@@ -425,7 +403,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

     def forward_cpu(
         self,
-        layer: torch.nn.Module,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
         x: torch.Tensor,
         use_grouped_topk: bool,
         top_k: int,
@@ -474,7 +452,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

     def forward_xpu(
         self,
-        layer: torch.nn.Module,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
         x: torch.Tensor,
         use_grouped_topk: bool,
         top_k: int,
@@ -515,7 +493,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

     def forward_tpu(
         self,
-        layer: torch.nn.Module,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
         x: torch.Tensor,
         use_grouped_topk: bool,
         top_k: int,
@@ -597,7 +597,7 @@ class AWQMoEMethod(FusedMoEMethodBase):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -618,24 +618,11 @@ class AWQMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.")
-
         assert activation == "silu", "Only SiLU activation is supported."

-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )

         return fused_marlin_moe(
@@ -495,7 +495,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -518,25 +518,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts

-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `BitsAndBytesMoEMethod` yet."
-            )
-
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
+        # TODO(bnell): Do these need to be called on the hot path?
         if self.quant_config.load_in_8bit:
             w13, w2 = self._apply_8bit_dequant(layer)
         else:
@@ -511,7 +511,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -532,16 +532,17 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
-            )
         assert activation == "silu", "Only SiLU activation is supported."

         if (
             self.allow_flashinfer
             and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
+            if enable_eplb:
+                raise NotImplementedError(
+                    "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
+                )
+
             return flashinfer_trtllm_fp4_moe(
                 layer=layer,
                 x=x,
@@ -554,19 +555,9 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
                 e_score_correction_bias=e_score_correction_bias,
             )

-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )

         if self.use_marlin:
@@ -1109,7 +1100,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -1130,31 +1121,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            assert expert_load_view is not None
-            assert logical_to_physical_map is not None
-            assert logical_replica_count is not None
-            assert isinstance(layer, FusedMoE)
-
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            num_fused_shared_experts=layer.num_fused_shared_experts,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
         )

         per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
@@ -1377,7 +1346,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -1398,26 +1367,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet."
-            )
-
         from vllm.model_executor.layers.fused_moe import fused_experts

-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )

         return fused_experts(
@@ -1738,7 +1692,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -1759,26 +1713,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet."
-            )
-
         assert activation == "silu", f"{activation} not supported for Marlin MoE."

-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )

         return fused_marlin_moe(
@@ -2001,7 +1940,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -2022,43 +1961,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            if expert_load_view is None:
-                raise ValueError("enable_eplb=True requiere expert_load_view != None")
-            if logical_to_physical_map is None:
-                raise ValueError(
-                    "enable_eplb=True requiere logical_to_physical_map != None"
-                )
-            if logical_replica_count is None:
-                raise ValueError(
-                    "enable_eplb=True requiere logical_replica_count != None"
-                )
-            if not isinstance(layer, FusedMoE):
-                raise TypeError(
-                    "EPLB is only supported when `layer` is a instance of FusedMoE."
-                )
-
         from vllm.model_executor.layers.fused_moe import fused_experts

-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            num_fused_shared_experts=getattr(layer, "num_fused_shared_experts", 0),
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
         )

         return fused_experts(
@@ -137,7 +137,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -158,26 +158,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `ExpertsInt8MoEMethod` yet."
-            )
-
         from vllm.model_executor.layers.fused_moe import fused_experts

-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )

         return fused_experts(
@@ -1140,7 +1140,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -1216,31 +1216,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )

-        zero_expert_num = getattr(layer, "zero_expert_num", 0)
-        zero_expert_type = getattr(layer, "zero_expert_type", None)
-
-        select_result = FusedMoE.select_experts(
+        select_result = layer.select_experts(
            hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
-            global_num_experts=global_num_experts,
-            zero_expert_num=zero_expert_num,
-            zero_expert_type=zero_expert_type,
-            num_fused_shared_experts=layer.num_fused_shared_experts,
         )

         topk_weights, topk_ids, zero_expert_result = select_result
@@ -1322,7 +1300,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.allow_cutlass_block_scaled_grouped_gemm
             ),
         )
-        if zero_expert_num != 0 and zero_expert_type is not None:
+
+        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
             assert not isinstance(result, tuple), (
                 "Shared + zero experts are mutually exclusive not yet supported"
             )
@@ -621,7 +621,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -642,9 +642,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.")
-
         assert activation == "silu", "Only SiLU activation is supported."
         if apply_router_weight_on_input:
             raise NotImplementedError(
@@ -652,19 +649,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
                 "fused GGUF MoE method."
             )

-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )
         return fused_moe_gguf(
             x,
@@ -722,7 +722,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -743,26 +743,11 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `GPTQMarlinMoEMethod` yet."
-            )
-
         assert activation == "silu", "Only SiLU activation is supported."

-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )

         return fused_marlin_moe(
@@ -696,7 +696,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -717,12 +717,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `ModelOptFp8MoEMethod` yet."
-            )
-
         if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+            if layer.enable_eplb:
+                raise NotImplementedError(
+                    "EPLB not supported for `ModelOptFp8MoEMethod` yet."
+                )
             assert activation == "silu", (
                 f"Expected 'silu' activation but got {activation}"
             )
@@ -740,19 +739,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             )

         # Expert selection
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )

         if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
@@ -1459,7 +1448,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -1480,16 +1469,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
-            )
         assert activation == "silu", "Only SiLU activation is supported."

         if (
             self.allow_flashinfer
             and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
+            if enable_eplb:
+                raise NotImplementedError(
+                    "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
+                )
             return flashinfer_trtllm_fp4_moe(
                 layer=layer,
                 x=x,
@@ -1502,19 +1491,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 e_score_correction_bias=e_score_correction_bias,
             )

-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )

         if self.use_marlin:
@@ -359,7 +359,7 @@ class MoeWNA16Method(FusedMoEMethodBase):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -380,25 +380,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.")
-
         from vllm.model_executor.layers.fused_moe import fused_experts

         assert activation == "silu", "Only SiLU activation is supported."
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )

         return fused_experts(
@@ -862,7 +862,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -887,18 +887,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             raise NotImplementedError("EPLB is not supported for mxfp4")

         if self.mxfp4_backend == Mxfp4Backend.MARLIN:
-            topk_weights, topk_ids, _ = FusedMoE.select_experts(
+            topk_weights, topk_ids, _ = layer.select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                routed_scaling_factor=routed_scaling_factor,
-                e_score_correction_bias=e_score_correction_bias,
             )

             return fused_marlin_moe(
@@ -989,17 +980,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         ):
             from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

-            topk_weights, topk_ids, _ = FusedMoE.select_experts(
+            topk_weights, topk_ids, _ = layer.select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
             )

             # Backend-specific preparation
@@ -334,7 +334,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -355,24 +355,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet."
-            )
-
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )

         if self.rocm_aiter_moe_enabled:
@@ -609,7 +594,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -630,24 +615,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
-            )
-
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
         )

         if not self.emulate:
@@ -356,7 +356,7 @@ class RTNMoEMethod(FusedMoEMethodBase):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
@@ -377,22 +377,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")
-
-        topk_weights, topk_ids, _ = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = layer.select_experts(
             hidden_states=x,
             router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
        )

         return fused_marlin_moe(