mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:15:51 +08:00
[fix][cpu] Use a SwigluOAI impl which supports interleaved gate-up wei (#29273)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
This commit is contained in:
parent
64deead719
commit
98caeadd54
@ -6,22 +6,7 @@ import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
return F.silu(x[..., :d]) * x[..., d:]
|
||||
|
||||
|
||||
def swigluoai_and_mul(
|
||||
x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
|
||||
) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
gate, up = x[..., :d], x[..., d:]
|
||||
gate = gate.clamp(max=limit)
|
||||
up = up.clamp(min=-limit, max=limit)
|
||||
glu = gate * torch.sigmoid(alpha * gate)
|
||||
return (up + 1) * glu
|
||||
from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
|
||||
|
||||
|
||||
def grouped_topk(
|
||||
@ -227,6 +212,11 @@ class CPUFusedMOE:
|
||||
layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
|
||||
|
||||
self.act_to_impl = {
|
||||
"silu": SiluAndMul(),
|
||||
"swigluoai": SwigluOAIAndMul(),
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -246,7 +236,7 @@ class CPUFusedMOE:
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation in {"silu", "swigluoai"}, f"{activation} is not supported."
|
||||
assert activation in self.act_to_impl, f"{activation} is not supported."
|
||||
assert not apply_router_weight_on_input
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
@ -283,10 +273,7 @@ class CPUFusedMOE:
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
|
||||
gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
|
||||
if activation == "swigluoai":
|
||||
gate_up = swigluoai_and_mul(gate_up)
|
||||
else:
|
||||
gate_up = silu_and_mul(gate_up)
|
||||
gate_up = self.act_to_impl[activation].forward_native(gate_up)
|
||||
expert_out = layer.down_linear[i](gate_up)
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user