mirror of https://git.datalinker.icu/vllm-project/vllm.git
[xpu]support moe models on XPU platform (#21643)
Signed-off-by: yan <yan.ma@intel.com>
Signed-off-by: Yan Ma <yan.ma@intel.com>
parent 4abfd8796f
commit 73e1b9b1d4
@@ -327,7 +327,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2
 
-        if current_platform.is_cpu():
+        if current_platform.is_xpu():
+            import intel_extension_for_pytorch as ipex
+            layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+                layer.w13_weight,
+                layer.w2_weight,
+                use_prepack=True,
+            )
+        elif current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                 from vllm.model_executor.layers.fused_moe import cpu_fused_moe
                 dtype = layer.w13_weight.dtype
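For context: the hunk above hands the loaded expert weights to IPEX's fused GatedMLPMOE module when running on XPU, instead of falling through to the CPU path. Below is a minimal standalone sketch of the same pattern, assuming an XPU build of intel_extension_for_pytorch is installed; the expert count and tensor shapes are illustrative, not vLLM's actual configuration.

import torch

# Illustrative shapes only: 8 experts, hidden size 128, intermediate size 256.
num_experts, hidden_size, intermediate_size = 8, 128, 256

# w13 packs each expert's gate and up projections; w2 is the down projection.
w13_weight = torch.nn.Parameter(
    torch.randn(num_experts, 2 * intermediate_size, hidden_size))
w2_weight = torch.nn.Parameter(
    torch.randn(num_experts, hidden_size, intermediate_size))

try:
    import intel_extension_for_pytorch as ipex
    # Same call as in the diff: use_prepack=True lets IPEX repack the expert
    # weights into its preferred layout once, ahead of inference.
    ipex_fusion = ipex.llm.modules.GatedMLPMOE(
        w13_weight, w2_weight, use_prepack=True)
except ImportError:
    ipex_fusion = None  # No IPEX install; the CPU/CUDA code paths apply.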
@@ -509,6 +516,44 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             activation,
         )
 
+    def forward_xpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ):
+        if enable_eplb is not False or expert_load_view is not None or \
+                logical_to_physical_map is not None or \
+                logical_replica_count is not None:
+            raise NotImplementedError("Expert load balancing is not supported "
+                                      "for XPU.")
+        assert custom_routing_function is None
+        return layer.ipex_fusion(
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+        )
+
     def forward_tpu(
         self,
         layer: torch.nn.Module,
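Why defining forward_xpu is enough: UnquantizedFusedMoEMethod inherits from CustomOp, which routes forward() to a platform-specific forward_* method at runtime, so adding this method activates the IPEX path on XPU without touching the callers. The toy sketch below illustrates that dispatch idea only; the ToyCustomOp name and its fallback logic are illustrative assumptions, not vLLM's actual CustomOp implementation.

import torch

class ToyCustomOp(torch.nn.Module):
    """Toy stand-in for a CustomOp-style platform dispatch."""

    def forward(self, *args, **kwargs):
        # Route to the platform-specific implementation.
        return self._dispatch()(*args, **kwargs)

    def _dispatch(self):
        # PyTorch exposes torch.xpu on builds with Intel GPU support.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return self.forward_xpu
        return self.forward_native

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # In the diff above, this is where layer.ipex_fusion(...) runs.
        return x

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x

op = ToyCustomOp()
out = op(torch.randn(4, 16))  # reaches forward_xpu only on an XPU build

Note the design choice in forward_xpu itself: it passes GatedMLPMOE only the routing arguments it consumes (top_k, router_logits, renormalize, topk_group, num_expert_group), rejects all EPLB arguments up front with NotImplementedError, and asserts that no custom_routing_function is supplied.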