diff --git a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp index df47bb8dd1d7d..58dc402016881 100644 --- a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp +++ b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp @@ -93,16 +93,16 @@ torch::Tensor dynamic_4bit_int_moe_cpu( } auto Y_all = at::empty({offsets[E], H}, x_c.options()); - at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) { + at::parallel_for(0, offsets[E], 0, [&](int64_t idx_begin, int64_t idx_end) { c10::InferenceMode guard; - for (int64_t e = e_begin; e < e_end; ++e) { - const int64_t te = counts[e]; - if (te == 0) { + for (int64_t e = 0; e < E; ++e) { + int64_t start = std::max(offsets[e], idx_begin); + int64_t end = std::min(offsets[e + 1], idx_end); + int64_t te = end - start; + if (te <= 0) { continue; } - const int64_t start = offsets[e]; - auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); auto w13_e = w13_packed.select(/*dim=*/0, e);