mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 22:15:01 +08:00
[CPU]Improve cpu fused moe perf (#27244)
Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
This commit is contained in:
parent
59a50afa08
commit
c757a15f0f
@ -5,6 +5,7 @@ from collections.abc import Callable
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
|
|
||||||
|
|
||||||
@ -237,7 +238,43 @@ class SGLFusedMOE:
|
|||||||
|
|
||||||
class CPUFusedMOE:
|
class CPUFusedMOE:
|
||||||
def __init__(self, layer: torch.nn.Module) -> None:
    """Pre-build one gate/up and one down projection callable per expert.

    For every expert index ``i`` along dim 0 of ``layer.w13_weight`` two
    callables taking a single tensor ``x`` are created and stored on the
    layer as ``layer.gate_up_linear[i]`` and ``layer.down_linear[i]``, so
    the forward pass does not have to slice weights/biases on each call.

    When ``ops._supports_onednn and ops.is_onednn_acl_supported()`` is
    truthy the callables dispatch to ``ops.onednn_mm`` through handles
    created by ``ops.create_onednn_mm``; otherwise they call ``F.linear``
    on the per-expert weight slices.

    Args:
        layer: Module exposing ``w13_weight`` and ``w2_weight`` (indexed by
            expert along dim 0) and, optionally, ``w13_bias`` / ``w2_bias``.

    Side effects:
        Sets ``layer.gate_up_linear`` and ``layer.down_linear``. On the
        oneDNN path ``layer.w13_weight`` and ``layer.w2_weight`` are
        replaced by empty, non-trainable parameters.
    """
    use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()

    num_experts = layer.w13_weight.size(0)

    # An attribute that exists but holds None (e.g. registered through
    # ``register_parameter(name, None)``) must count as "no bias";
    # ``hasattr`` alone would report it as present and ``None[i]`` would
    # then raise TypeError.
    w13_bias = getattr(layer, "w13_bias", None)
    w2_bias = getattr(layer, "w2_bias", None)

    gate_up_linear = []
    down_linear = []

    for i in range(num_experts):
        layer_w13_weight = layer.w13_weight[i]
        layer_w13_bias = w13_bias[i] if w13_bias is not None else None
        layer_w2_weight = layer.w2_weight[i]
        layer_w2_bias = w2_bias[i] if w2_bias is not None else None

        if use_onednn_mm:
            # NOTE(review): the second argument (32) is passed through to
            # ``ops.create_onednn_mm`` unchanged; presumably a primitive
            # cache size -- confirm against the op's signature.
            gate_up_handle = ops.create_onednn_mm(layer_w13_weight.t(), 32)
            down_handle = ops.create_onednn_mm(layer_w2_weight.t(), 32)
            # Default arguments bind this iteration's handle/bias; a plain
            # closure would see only the values from the last iteration.
            gate_up_linear.append(
                lambda x, handle=gate_up_handle, bias=layer_w13_bias: ops.onednn_mm(
                    handle, x, bias
                )
            )
            down_linear.append(
                lambda x, handle=down_handle, bias=layer_w2_bias: ops.onednn_mm(
                    handle, x, bias
                )
            )
        else:
            gate_up_linear.append(
                lambda x, w=layer_w13_weight, b=layer_w13_bias: F.linear(x, w, b)
            )
            down_linear.append(
                lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b)
            )

    layer.gate_up_linear = gate_up_linear
    layer.down_linear = down_linear

    if use_onednn_mm:
        # The per-expert callables now go through the oneDNN handles, so
        # the stacked weight parameters are swapped for empty placeholders.
        layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -287,8 +324,6 @@ class CPUFusedMOE:
|
|||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
has_w13_bias = hasattr(layer, "w13_bias")
|
|
||||||
has_w2_bias = hasattr(layer, "w2_bias")
|
|
||||||
|
|
||||||
for i, num_tokens in enumerate(tokens_per_expert):
|
for i, num_tokens in enumerate(tokens_per_expert):
|
||||||
end_idx = start_idx + num_tokens
|
end_idx = start_idx + num_tokens
|
||||||
@ -296,19 +331,12 @@ class CPUFusedMOE:
|
|||||||
continue
|
continue
|
||||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||||
|
|
||||||
layer_w13_weight = layer.w13_weight[i]
|
gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
|
||||||
layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
|
|
||||||
layer_w2_weight = layer.w2_weight[i]
|
|
||||||
layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
|
|
||||||
|
|
||||||
gate_up = F.linear(
|
|
||||||
tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias
|
|
||||||
)
|
|
||||||
if activation == "swigluoai":
|
if activation == "swigluoai":
|
||||||
gate_up = swigluoai_and_mul(gate_up)
|
gate_up = swigluoai_and_mul(gate_up)
|
||||||
else:
|
else:
|
||||||
gate_up = silu_and_mul(gate_up)
|
gate_up = silu_and_mul(gate_up)
|
||||||
expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias)
|
expert_out = layer.down_linear[i](gate_up)
|
||||||
outputs.append(expert_out)
|
outputs.append(expert_out)
|
||||||
start_idx = end_idx
|
start_idx = end_idx
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user