[CPU]Improve dynamic 4bit moe performance (#27240)

Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
This commit is contained in:
xiangze-arm 2025-11-04 14:33:23 +08:00 committed by GitHub
parent 7e4be74104
commit f32cbc9a0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -87,30 +87,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
// Per-expert outputs filled in parallel
std::vector<torch::Tensor> y_list(E);
y_list.resize(E);
auto X_all = x_c.index_select(/*dim=*/0, expert_tokens);
if (apply_router_weight_on_input) {
X_all = X_all.mul(expert_gates.unsqueeze(1));
}
auto Y_all = at::empty({offsets[E], H}, x_c.options());
at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
c10::InferenceMode guard;
for (int64_t e = e_begin; e < e_end; ++e) {
const int64_t te = counts[e];
if (te == 0) {
y_list[e] = at::empty({0, H}, x_c.options());
continue;
}
const int64_t start = offsets[e];
auto sel_tokens =
expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto gates_e =
expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);
if (apply_router_weight_on_input) {
x_e = x_e.mul(gates_e.unsqueeze(1));
}
auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto w13_e = w13_packed.select(/*dim=*/0, e);
auto w2_e = w2_packed.select(/*dim=*/0, e);
@ -137,17 +130,15 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
// W2
auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
if (!apply_router_weight_on_input) {
y = y.mul(gates_e.unsqueeze(1));
}
// Store per-expert result
y_list[e] = y;
Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y);
}
});
// Concatenate all expert outputs to match expert_tokens order
auto Y_all = at::cat(y_list, /*dim=*/0);
if (!apply_router_weight_on_input) {
Y_all = Y_all.mul(expert_gates.unsqueeze(1));
}
auto out = at::zeros({T, H}, x.options());
out =
at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);