mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:04:58 +08:00
[CPU]Improve dynamic 4bit moe performance (#27240)
Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
This commit is contained in:
parent
7e4be74104
commit
f32cbc9a0c
@ -87,30 +87,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||
const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
|
||||
const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
|
||||
|
||||
// Per-expert outputs filled in parallel
|
||||
std::vector<torch::Tensor> y_list(E);
|
||||
y_list.resize(E);
|
||||
auto X_all = x_c.index_select(/*dim=*/0, expert_tokens);
|
||||
if (apply_router_weight_on_input) {
|
||||
X_all = X_all.mul(expert_gates.unsqueeze(1));
|
||||
}
|
||||
auto Y_all = at::empty({offsets[E], H}, x_c.options());
|
||||
|
||||
at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
|
||||
c10::InferenceMode guard;
|
||||
for (int64_t e = e_begin; e < e_end; ++e) {
|
||||
const int64_t te = counts[e];
|
||||
if (te == 0) {
|
||||
y_list[e] = at::empty({0, H}, x_c.options());
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t start = offsets[e];
|
||||
|
||||
auto sel_tokens =
|
||||
expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
|
||||
auto gates_e =
|
||||
expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
|
||||
|
||||
auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);
|
||||
|
||||
if (apply_router_weight_on_input) {
|
||||
x_e = x_e.mul(gates_e.unsqueeze(1));
|
||||
}
|
||||
auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
|
||||
|
||||
auto w13_e = w13_packed.select(/*dim=*/0, e);
|
||||
auto w2_e = w2_packed.select(/*dim=*/0, e);
|
||||
@ -137,17 +130,15 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||
// W2
|
||||
auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
|
||||
|
||||
if (!apply_router_weight_on_input) {
|
||||
y = y.mul(gates_e.unsqueeze(1));
|
||||
}
|
||||
|
||||
// Store per-expert result
|
||||
y_list[e] = y;
|
||||
Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y);
|
||||
}
|
||||
});
|
||||
|
||||
// Concatenate all expert outputs to match expert_tokens order
|
||||
auto Y_all = at::cat(y_list, /*dim=*/0);
|
||||
if (!apply_router_weight_on_input) {
|
||||
Y_all = Y_all.mul(expert_gates.unsqueeze(1));
|
||||
}
|
||||
|
||||
auto out = at::zeros({T, H}, x.options());
|
||||
out =
|
||||
at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user