[MoE][Refactor] Remove most arguments to FusedMoEMethodBase.apply (#29066)

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
bnellnm 2025-12-09 16:48:25 -05:00 committed by GitHub
parent 7618dc973d
commit 00e5cbb967
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 318 additions and 872 deletions

View File

@ -4,7 +4,10 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from typing import Any
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
@ -49,6 +52,7 @@ __all__ = [
"FusedMoEPermuteExpertsUnpermute", "FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat", "FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize", "FusedMoEPrepareAndFinalize",
"RoutingMethodType",
"SharedFusedMoE", "SharedFusedMoE",
"activation_without_mul", "activation_without_mul",
"override_config", "override_config",

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable
import torch import torch
@ -100,22 +99,5 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
@ -97,23 +96,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts( topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -127,10 +109,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=self.allow_inplace, inplace=self.allow_inplace,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else expert_map, expert_map=None if self.disable_expert_map else layer.expert_map,
) )
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:

View File

@ -33,10 +33,6 @@ from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data, init_aiter_topK_meta_data,
) )
@ -57,11 +53,8 @@ from vllm.utils.torch_utils import (
from vllm.v1.worker.ubatching import dbo_current_ubatch_id from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_moe import eplb_map_to_physical_and_record, fused_experts from .fused_moe import eplb_map_to_physical_and_record
else: else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = object # type: ignore
FusedMoEPrepareAndFinalize = object # type: ignore
def _eplb_map_to_physical_and_record( def _eplb_map_to_physical_and_record(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
@ -483,7 +476,7 @@ class FusedMoE(CustomOp):
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
) )
self.expert_map: torch.Tensor | None self._expert_map: torch.Tensor | None
local_num_experts, expert_map, expert_mask = determine_expert_map( local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size, ep_size=self.ep_size,
ep_rank=self.ep_rank, ep_rank=self.ep_rank,
@ -493,7 +486,7 @@ class FusedMoE(CustomOp):
return_expert_mask=self.rocm_aiter_fmoe_enabled, return_expert_mask=self.rocm_aiter_fmoe_enabled,
) )
self.local_num_experts = local_num_experts self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map) self.register_buffer("_expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask) self.register_buffer("expert_mask", expert_mask)
self._maybe_init_expert_routing_tables() self._maybe_init_expert_routing_tables()
logger.info_once( logger.info_once(
@ -506,10 +499,10 @@ class FusedMoE(CustomOp):
self.expert_placement_strategy, self.expert_placement_strategy,
self.local_num_experts, self.local_num_experts,
self.global_num_experts, self.global_num_experts,
get_compressed_expert_map(self.expert_map), get_compressed_expert_map(self._expert_map),
) )
else: else:
self.local_num_experts, self.expert_map, self.expert_mask = ( self.local_num_experts, self._expert_map, self.expert_mask = (
self.global_num_experts, self.global_num_experts,
None, None,
None, None,
@ -781,7 +774,7 @@ class FusedMoE(CustomOp):
), ),
) )
if self.expert_map is None: if self._expert_map is None:
return None return None
routing_tables = self.ensure_round_robin_expert_routing_tables( routing_tables = self.ensure_round_robin_expert_routing_tables(
@ -789,7 +782,7 @@ class FusedMoE(CustomOp):
ep_size=self.ep_size, ep_size=self.ep_size,
ep_rank=self.ep_rank, ep_rank=self.ep_rank,
local_num_experts=self.local_num_experts, local_num_experts=self.local_num_experts,
device=self.expert_map.device, device=self._expert_map.device,
) )
global_to_physical, physical_to_global, local_global = routing_tables global_to_physical, physical_to_global, local_global = routing_tables
@ -840,8 +833,8 @@ class FusedMoE(CustomOp):
def update_expert_map(self): def update_expert_map(self):
# ep_size and ep_rank should already be updated # ep_size and ep_rank should already be updated
assert self.expert_map is not None assert self._expert_map is not None
with self.expert_map.device: with self._expert_map.device:
local_num_experts, expert_map, expert_mask = determine_expert_map( local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size, ep_size=self.ep_size,
ep_rank=self.ep_rank, ep_rank=self.ep_rank,
@ -851,7 +844,7 @@ class FusedMoE(CustomOp):
return_expert_mask=self.rocm_aiter_fmoe_enabled, return_expert_mask=self.rocm_aiter_fmoe_enabled,
) )
self.local_num_experts = local_num_experts self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map) self.register_buffer("_expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask) self.register_buffer("expert_mask", expert_mask)
self._maybe_init_expert_routing_tables() self._maybe_init_expert_routing_tables()
if self.aiter_fmoe_shared_expert_enabled: if self.aiter_fmoe_shared_expert_enabled:
@ -1068,9 +1061,9 @@ class FusedMoE(CustomOp):
expert_data.copy_(loaded_weight) expert_data.copy_(loaded_weight)
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
if self.expert_map is None: if self._expert_map is None:
return expert_id return expert_id
return self.expert_map[expert_id].item() return self._expert_map[expert_id].item()
def _init_aiter_shared_experts_topK_buffer( def _init_aiter_shared_experts_topK_buffer(
self, vllm_config: VllmConfig, dp_size: int self, vllm_config: VllmConfig, dp_size: int
@ -1744,6 +1737,12 @@ class FusedMoE(CustomOp):
reduce_output(fused_output)[..., :og_hidden_states], reduce_output(fused_output)[..., :og_hidden_states],
) )
@property
def expert_map(self) -> torch.Tensor | None:
return (
self._expert_map if not self.rocm_aiter_fmoe_enabled else self.expert_mask
)
def forward_cuda( def forward_cuda(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -1805,24 +1804,6 @@ class FusedMoE(CustomOp):
layer=self, layer=self,
x=staged_hidden_states, x=staged_hidden_states,
router_logits=staged_router_logits, router_logits=staged_router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
) )
if has_separate_shared_experts: if has_separate_shared_experts:
@ -1968,25 +1949,6 @@ class FusedMoE(CustomOp):
if do_naive_dispatch_combine if do_naive_dispatch_combine
else hidden_states, else hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
) )
if has_separate_shared_experts: if has_separate_shared_experts:

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -269,53 +268,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
return self.forward( return self.forward(
x=x,
layer=layer, layer=layer,
x=x,
router_logits=router_logits, router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
global_num_experts=global_num_experts,
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
enable_eplb=enable_eplb,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
@ -333,24 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts( topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -364,9 +307,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
expert_map=expert_map, expert_map=layer.expert_map,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
elif self.flashinfer_cutlass_moe_enabled: elif self.flashinfer_cutlass_moe_enabled:
return self.flashinfer_cutlass_moe( return self.flashinfer_cutlass_moe(
@ -375,8 +318,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
result = fused_experts( result = fused_experts(
@ -386,11 +329,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
) )
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
@ -405,148 +348,101 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if ( if (
enable_eplb is not False layer.enable_eplb is not False
or expert_load_view is not None or layer.expert_load_view is not None
or logical_to_physical_map is not None or layer.logical_to_physical_map is not None
or logical_replica_count is not None or layer.logical_replica_count is not None
): ):
raise NotImplementedError("Expert load balancing is not supported for CPU.") raise NotImplementedError("Expert load balancing is not supported for CPU.")
return layer.cpu_fused_moe( return layer.cpu_fused_moe(
layer, layer,
x, x,
use_grouped_topk, layer.use_grouped_topk,
top_k, layer.top_k,
router_logits, router_logits,
renormalize, layer.renormalize,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
global_num_experts, layer.global_num_experts,
expert_map, layer.expert_map,
custom_routing_function, layer.custom_routing_function,
scoring_func, layer.scoring_func,
routed_scaling_factor, layer.routed_scaling_factor,
e_score_correction_bias, layer.e_score_correction_bias,
apply_router_weight_on_input, layer.apply_router_weight_on_input,
activation, layer.activation,
) )
def forward_xpu( def forward_xpu(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if ( if (
enable_eplb is not False layer.enable_eplb is not False
or expert_load_view is not None or layer.expert_load_view is not None
or logical_to_physical_map is not None or layer.logical_to_physical_map is not None
or logical_replica_count is not None or layer.logical_replica_count is not None
): ):
raise NotImplementedError("Expert load balancing is not supported for XPU.") raise NotImplementedError("Expert load balancing is not supported for XPU.")
return layer.ipex_fusion( return layer.ipex_fusion(
x, x,
use_grouped_topk, layer.use_grouped_topk,
top_k, layer.top_k,
router_logits, router_logits,
renormalize, layer.renormalize,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=layer.custom_routing_function,
) )
def forward_tpu( def forward_tpu(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not use_grouped_topk assert not layer.use_grouped_topk
assert num_expert_group is None assert layer.num_expert_group is None
assert topk_group is None assert layer.topk_group is None
assert custom_routing_function is None assert layer.custom_routing_function is None
assert apply_router_weight_on_input is False assert layer.apply_router_weight_on_input is False
if scoring_func != "softmax": if layer.scoring_func != "softmax":
raise NotImplementedError( raise NotImplementedError(
"Only softmax scoring function is supported for TPU." "Only softmax scoring function is supported for TPU."
) )
if e_score_correction_bias is not None: if layer.e_score_correction_bias is not None:
raise NotImplementedError( raise NotImplementedError(
"Expert score correction bias is not supported for TPU." "Expert score correction bias is not supported for TPU."
) )
assert activation == "silu", f"{activation} is not supported for TPU." assert layer.activation == "silu", (
assert routed_scaling_factor == 1.0, ( f"{layer.activation} is not supported for TPU."
f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU." )
assert layer.routed_scaling_factor == 1.0, (
f"routed_scaling_factor {layer.routed_scaling_factor} is "
"not supported for TPU."
) )
if ( if (
enable_eplb is not False layer.enable_eplb is not False
or expert_load_view is not None or layer.expert_load_view is not None
or logical_to_physical_map is not None or layer.logical_to_physical_map is not None
or logical_replica_count is not None or layer.logical_replica_count is not None
): ):
raise NotImplementedError("Expert load balancing is not supported for TPU.") raise NotImplementedError("Expert load balancing is not supported for TPU.")
return fused_moe_pallas( return fused_moe_pallas(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
topk=top_k, topk=layer.top_k,
gating_output=router_logits, gating_output=router_logits,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
renormalize=renormalize, renormalize=layer.renormalize,
) )
if current_platform.is_tpu(): if current_platform.is_tpu():

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
import torch import torch
@ -669,25 +668,8 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -708,9 +690,9 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None), input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
w1_zeros=layer.w13_qzeros, w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros, w2_zeros=layer.w2_qzeros,
workspace=layer.workspace, workspace=layer.workspace,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Union from typing import Any, Union
import torch import torch
@ -498,23 +497,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -534,10 +516,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum import enum
from collections.abc import Callable
from enum import Enum from enum import Enum
import torch import torch
@ -558,31 +557,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
if ( if (
self.allow_flashinfer self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
): ):
if enable_eplb: if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet." "EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
) )
@ -591,12 +573,12 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer=layer, layer=layer,
x=x, x=x,
router_logits=router_logits, router_logits=router_logits,
top_k=top_k, top_k=layer.top_k,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
num_expert_group=num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=topk_group, topk_group=layer.topk_group,
custom_routing_function=custom_routing_function, custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=layer.e_score_correction_bias,
) )
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
@ -619,9 +601,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_scale1=layer.w13_weight_scale_2, global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2, global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id, quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
@ -646,15 +628,15 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
inplace=False, # TODO(shuw): fix later, now output is high prec inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
assert expert_map is None, ( assert layer.expert_map is None, (
"Expert Parallelism / expert_map " "Expert Parallelism / expert_map "
"is currently not supported for " "is currently not supported for "
"CompressedTensorsW4A4Nvfp4MoEMethod." "CompressedTensorsW4A4Nvfp4MoEMethod."
@ -670,7 +652,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO(bnell): derive these from arguments # TODO(bnell): derive these from arguments
m=x.shape[0], m=x.shape[0],
n=layer.w2_weight.shape[2] * 2, n=layer.w2_weight.shape[2] * 2,
@ -1188,23 +1170,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -1215,7 +1180,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
if self.use_marlin: if self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE." assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
return fused_marlin_moe( return fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
@ -1228,9 +1195,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
@ -1248,9 +1215,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
@ -1270,10 +1237,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=None if self.disable_expert_map else expert_map, expert_map=None
if self.disable_expert_map
else layer.expert_map, # ???
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
else: else:
@ -1290,9 +1259,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=None if self.disable_expert_map else expert_map, expert_map=None if self.disable_expert_map else layer.expert_map,
ab_strides1=self.ab_strides1_c_strides2, ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2, ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1, c_strides1=self.c_strides1,
@ -1314,10 +1283,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
@ -1437,23 +1406,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -1469,10 +1421,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
@ -1814,25 +1766,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", f"{activation} not supported for Marlin MoE." assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -1853,9 +1790,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None), input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
g_idx1=layer.w13_weight_g_idx, g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx, g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices, sort_indices1=layer.w13_g_idx_sort_indices,
@ -2057,23 +1994,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -2089,10 +2009,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
@ -2372,32 +2292,15 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet." assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet."
assert activation in ("silu", "swigluoai", "swiglu"), ( assert layer.activation in ("silu", "swigluoai", "swiglu"), (
"Only SiLU/SwiGLUGU/SwiGLUUG are supported." "Only SiLU/SwiGLUGU/SwiGLUUG are supported."
) )
assert expert_map is None, """expert_map/EP not implemented assert layer.expert_map is None, """expert_map/EP not implemented
for CPU dyn-4bit MoE.""" for CPU dyn-4bit MoE."""
def _act_kind(s: str) -> int: def _act_kind(s: str) -> int:
@ -2414,15 +2317,9 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights, topk_ids = select_experts( topk_weights, topk_ids = select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, top_k=layer.top_k,
top_k=top_k, use_grouped_topk=layer.use_grouped_topk,
renormalize=renormalize, renormalize=layer.renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
) )
return torch.ops._C.dynamic_4bit_int_moe( return torch.ops._C.dynamic_4bit_int_moe(
@ -2435,8 +2332,8 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_in_features, layer.w2_in_features,
layer.w13_out_features, layer.w13_out_features,
layer.group_size, layer.group_size,
apply_router_weight_on_input, layer.apply_router_weight_on_input,
int(_act_kind(activation)), int(_act_kind(layer.activation)),
) )
@ -2707,28 +2604,11 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
): ):
if enable_eplb: if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
) )
@ -2749,9 +2629,9 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=None if self.disable_expert_map else expert_map, expert_map=None if self.disable_expert_map else layer.expert_map,
a_strides1=self.a_strides1_c_strides2, a_strides1=self.a_strides1_c_strides2,
a_strides2=self.a_strides2, a_strides2=self.a_strides2,
b_strides1=self.b_strides1, b_strides1=self.b_strides1,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import torch import torch
@ -140,23 +139,6 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -172,10 +154,10 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@ -1242,41 +1241,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
assert activation == "silu", ( if layer.enable_eplb:
f"Expected 'silu' activation but got {activation}" raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
) )
if self.block_quant: if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = ( e_score_correction_bias = (
e_score_correction_bias.to(x.dtype) layer.e_score_correction_bias.to(x.dtype)
if e_score_correction_bias is not None if layer.e_score_correction_bias is not None
else None else None
) )
routing_method_type = layer.routing_method_type routing_method_type = layer.routing_method_type
@ -1290,29 +1268,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w13_weight_scale_inv=layer.w13_weight_scale_inv, w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight, w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv, w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
top_k=top_k, top_k=layer.top_k,
num_expert_group=num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=topk_group, topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition, intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts, expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
routing_method_type=routing_method_type, routing_method_type=routing_method_type,
routed_scaling=routed_scaling_factor, routed_scaling=layer.routed_scaling_factor,
) )
else: else:
assert not renormalize and custom_routing_function is not None assert (
not layer.renormalize and layer.custom_routing_function is not None
)
result = apply_flashinfer_per_tensor_scale_fp8( result = apply_flashinfer_per_tensor_scale_fp8(
layer=layer, layer=layer,
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
routing_bias=e_score_correction_bias, routing_bias=layer.e_score_correction_bias,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
top_k=top_k, top_k=layer.top_k,
num_expert_group=num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=topk_group, topk_group=layer.topk_group,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
select_result = layer.select_experts( select_result = layer.select_experts(
@ -1333,13 +1313,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
elif self.use_marlin: elif self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE." assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
result = fused_marlin_moe( result = fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
@ -1352,20 +1334,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert activation == "silu", ( assert layer.activation == "silu", (
f"Expected 'silu' activation but got {activation}" f"Expected 'silu' activation but got {layer.activation}"
) )
if not self.block_quant: if not self.block_quant:
assert not renormalize and custom_routing_function is not None assert (
assert scoring_func == "sigmoid", ( not layer.renormalize and layer.custom_routing_function is not None
f"Expected 'sigmoid' scoring func but got {scoring_func}" )
assert layer.scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
) )
# Delegate to CUTLASS FlashInfer path; function already bound with # Delegate to CUTLASS FlashInfer path; function already bound with
# use_deepseek_fp8_block_scale for block-quant when applicable # use_deepseek_fp8_block_scale for block-quant when applicable
@ -1375,10 +1359,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=False, inplace=False,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -1390,10 +1374,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm, allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=( allow_cutlass_block_scaled_grouped_gemm=(

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Mapping from collections.abc import Mapping
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Optional from typing import Any, Optional
@ -625,26 +625,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input: if layer.apply_router_weight_on_input:
raise NotImplementedError( raise NotImplementedError(
"Apply router weight on input is not supported for" "Apply router weight on input is not supported for"
"fused GGUF MoE method." "fused GGUF MoE method."
@ -662,7 +645,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
topk_ids, topk_ids,
layer.w13_qweight_type.weight_type, layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type, layer.w2_qweight_type.weight_type,
activation, layer.activation,
) )

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from typing import Any, Optional from typing import Any, Optional
@ -790,25 +789,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -829,9 +811,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None), input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
g_idx1=layer.w13_g_idx, g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx, g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices, sort_indices1=layer.w13_g_idx_sort_indices,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import torch import torch
@ -440,31 +439,14 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return layer.ipex_fusion( return layer.ipex_fusion(
x, x,
use_grouped_topk, layer.use_grouped_topk,
top_k, layer.top_k,
router_logits, router_logits,
renormalize, layer.renormalize,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=layer.custom_routing_function,
) )

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from fnmatch import fnmatch from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@ -707,43 +706,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if layer.enable_eplb: if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet." "EPLB not supported for `ModelOptFp8MoEMethod` yet."
) )
assert activation == "silu", ( assert layer.activation == "silu", (
f"Expected 'silu' activation but got {activation}" f"Expected 'silu' activation but got {layer.activation}"
) )
assert not renormalize
assert not layer.renormalize
return apply_flashinfer_per_tensor_scale_fp8( return apply_flashinfer_per_tensor_scale_fp8(
layer=layer, layer=layer,
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
routing_bias=e_score_correction_bias, routing_bias=layer.e_score_correction_bias,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
top_k=top_k, top_k=layer.top_k,
num_expert_group=num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=topk_group, topk_group=layer.topk_group,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
# Expert selection # Expert selection
@ -753,9 +736,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
) )
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert activation in ("silu", "relu2_no_mul"), ( assert layer.activation in ("silu", "relu2_no_mul"), (
"Expected activation to be in ('silu', 'relu2_no_mul')," "Expected activation to be in ('silu', 'relu2_no_mul'),"
f"but got {activation}" f"but got {layer.activation}"
) )
return flashinfer_cutlass_moe_fp8( return flashinfer_cutlass_moe_fp8(
x, x,
@ -763,10 +746,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=False, inplace=False,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
@ -780,11 +763,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
@ -1504,23 +1487,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.moe.is_act_and_mul: if not self.moe.is_act_and_mul:
assert ( assert (
@ -1535,7 +1501,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.allow_flashinfer self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
): ):
if enable_eplb: if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet." "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
) )
@ -1543,12 +1509,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer=layer, layer=layer,
x=x, x=x,
router_logits=router_logits, router_logits=router_logits,
top_k=top_k, top_k=layer.top_k,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
num_expert_group=num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=topk_group, topk_group=layer.topk_group,
custom_routing_function=custom_routing_function, custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=layer.e_score_correction_bias,
) )
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
@ -1571,9 +1537,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_scale1=layer.w13_weight_scale_2, global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2, global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id, quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
) )
@ -1604,10 +1570,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
inplace=False, inplace=False,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
@ -1622,8 +1588,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO: derive from arguments # TODO: derive from arguments
m=x.shape[0], m=x.shape[0],
n=layer.w2_weight.shape[2] * 2, n=layer.w2_weight.shape[2] * 2,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import torch import torch
@ -362,27 +361,10 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
@ -395,9 +377,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
@ -892,25 +891,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb: if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4") raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN: if self.mxfp4_backend == Mxfp4Backend.MARLIN:
@ -933,26 +915,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
global_scale1=None, global_scale1=None,
global_scale2=None, global_scale2=None,
quant_type_id=scalar_types.float4_e2m1f.id, quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
activation=activation, activation=layer.activation,
expert_map=expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
) )
assert _can_support_mxfp4( assert _can_support_mxfp4(
use_grouped_topk, layer.use_grouped_topk,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
expert_map, layer.expert_map,
custom_routing_function, layer.custom_routing_function,
e_score_correction_bias, layer.e_score_correction_bias,
apply_router_weight_on_input, layer.apply_router_weight_on_input,
scoring_func, layer.scoring_func,
activation, layer.activation,
expert_load_view, layer.expert_load_view,
logical_to_physical_map, layer.logical_to_physical_map,
logical_replica_count, layer.logical_replica_count,
), "MXFP4 are not supported with this configuration." ), "MXFP4 are not supported with this configuration."
if ( if (
@ -988,8 +970,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
None, # output1_scale_scalar None, # output1_scale_scalar
None, # output1_scale_gate_scalar None, # output1_scale_gate_scalar
None, # output2_scale_scalar None, # output2_scale_scalar
global_num_experts, layer.global_num_experts,
top_k, layer.top_k,
None, # n_group None, # n_group
None, # topk_group None, # topk_group
self.intermediate_size, # padded to multiple of 256 self.intermediate_size, # padded to multiple of 256
@ -997,7 +979,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.num_experts, # local num experts self.num_experts, # local num experts
None, None,
None, None,
1 if renormalize else 0, # routing_method_type, renormalize 1 if layer.renormalize else 0, # routing_method_type, renormalize
True, # do finalize True, # do finalize
tune_max_num_tokens=max(self.max_capture_size, 1), tune_max_num_tokens=max(self.max_capture_size, 1),
)[0] )[0]
@ -1081,12 +1063,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=layer.top_k,
renormalize=renormalize, renormalize=layer.renormalize,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
@ -1138,37 +1120,20 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "swigluoai", ( assert layer.activation == "swigluoai", (
"Only swiglu_oai activation is supported for IPEX MXFP4 MoE" "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
) )
hidden_size_pad = round_up(self.original_hidden_size, 128) hidden_size_pad = round_up(self.original_hidden_size, 128)
x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1))) x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
hidden_states = layer.ipex_fusion( hidden_states = layer.ipex_fusion(
x_pad, x_pad,
use_grouped_topk, layer.use_grouped_topk,
top_k, layer.top_k,
router_logits, router_logits,
renormalize, layer.renormalize,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
activation="swiglu_oai", activation="swiglu_oai",
) )
hidden_states = hidden_states[..., : self.original_hidden_size].contiguous() hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any from typing import Any
import torch import torch
@ -337,23 +336,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -371,13 +353,15 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
expert_map=expert_map, expert_map=layer.expert_map,
) )
elif self.use_marlin: elif self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE." assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
return fused_marlin_moe( return fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
@ -390,9 +374,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
) )
else: else:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -404,10 +388,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
@ -597,23 +581,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -631,9 +598,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
expert_map=expert_map, expert_map=layer.expert_map,
) )
else: else:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -645,10 +612,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
return out return out

View File

@ -3,7 +3,6 @@
# Copyright © 2025, Oracle and/or its affiliates. # Copyright © 2025, Oracle and/or its affiliates.
import os import os
from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import numpy as np import numpy as np
@ -359,23 +358,6 @@ class RTNMoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -394,9 +376,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=self.quant_config.quant_type.id, quant_type_id=self.quant_config.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
workspace=workspace, workspace=workspace,
) )