Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-24 09:04:29 +08:00)
[Fix] enable swap_ab for pplx problem size computation (#22991)
Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
commit 7f89ed248f (parent 8a87cd27d9)
@@ -161,6 +161,7 @@ void get_cutlass_moe_mm_data_caller(
                       topk_ids.size(1));
 }
 
+template <bool SWAP_AB>
 __global__ void compute_pplx_data(int32_t* expert_offsets,
                                   int32_t* problem_sizes1,
                                   int32_t* problem_sizes2,
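The hunk above templates compute_pplx_data on a compile-time boolean, so the SWAP_AB branch added in the next hunk is resolved during compilation and each instantiation carries exactly one code path. A minimal, self-contained sketch of the same pattern (illustrative only, not vLLM code; fill_shape and its arguments are made up):

// Minimal sketch of the pattern introduced above (illustrative only, not
// vLLM code): a kernel templated on a compile-time bool, where
// `if constexpr` makes each instantiation carry exactly one branch.
#include <cstdio>
#include <cuda_runtime.h>

template <bool SWAP_AB>
__global__ void fill_shape(int* shape, int m, int n, int k) {
  if constexpr (!SWAP_AB) {
    shape[0] = m;  // regular layout: (M, N, K)
    shape[1] = n;
    shape[2] = k;
  } else {
    shape[0] = n;  // swapped layout: M and N trade places
    shape[1] = m;
    shape[2] = k;
  }
}

int main() {
  int* shape = nullptr;
  cudaMallocManaged(&shape, 3 * sizeof(int));
  fill_shape<true><<<1, 1>>>(shape, 4, 128, 64);  // instantiation chosen at launch
  cudaDeviceSynchronize();
  std::printf("%d %d %d\n", shape[0], shape[1], shape[2]);  // prints: 128 4 64
  cudaFree(shape);
  return 0;
}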
@@ -168,14 +169,23 @@ __global__ void compute_pplx_data(int32_t* expert_offsets,
                                   const int padded_m, const int n,
                                   const int k) {
   int expert_idx = threadIdx.x;
 
   expert_offsets[expert_idx] = expert_idx * padded_m;
-  problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
-  problem_sizes1[expert_idx * 3 + 1] = 2 * n;
-  problem_sizes1[expert_idx * 3 + 2] = k;
-  problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
-  problem_sizes2[expert_idx * 3 + 1] = k;
-  problem_sizes2[expert_idx * 3 + 2] = n;
+  if constexpr (!SWAP_AB) {
+    problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
+    problem_sizes1[expert_idx * 3 + 1] = 2 * n;
+    problem_sizes1[expert_idx * 3 + 2] = k;
+    problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
+    problem_sizes2[expert_idx * 3 + 1] = k;
+    problem_sizes2[expert_idx * 3 + 2] = n;
+  } else {
+    problem_sizes1[expert_idx * 3] = 2 * n;
+    problem_sizes1[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
+    problem_sizes1[expert_idx * 3 + 2] = k;
+    problem_sizes2[expert_idx * 3] = k;
+    problem_sizes2[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
+    problem_sizes2[expert_idx * 3 + 2] = n;
+  }
 }
 
 void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
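In this kernel each thread handles one expert and writes an (M, N, K) triple per grouped GEMM: problem_sizes1 describes the first projection (M = tokens routed to that expert, N = 2n, K = k) and problem_sizes2 the second (M = tokens, N = k, K = n). With SWAP_AB the M and N entries trade places, describing the transposed problem. A host-side reference of the same layout can serve as a CPU cross-check; this helper is hypothetical, not part of the commit:

// Hypothetical host-side reference for the kernel above (not part of this
// commit): recomputes the per-expert offsets and (M, N, K) triples on the
// CPU, e.g. to cross-check the GPU output in a unit test.
#include <cstdint>
#include <vector>

void reference_pplx_data(std::vector<int32_t>& expert_offsets,
                         std::vector<int32_t>& problem_sizes1,
                         std::vector<int32_t>& problem_sizes2,
                         const std::vector<int32_t>& expert_num_tokens,
                         int padded_m, int n, int k, bool swap_ab) {
  const int num_experts = static_cast<int>(expert_num_tokens.size());
  expert_offsets.assign(num_experts, 0);
  problem_sizes1.assign(3 * num_experts, 0);
  problem_sizes2.assign(3 * num_experts, 0);
  for (int e = 0; e < num_experts; ++e) {
    expert_offsets[e] = e * padded_m;        // fixed stride between experts
    const int32_t m = expert_num_tokens[e];  // rows actually routed here
    if (!swap_ab) {
      // GEMM1: M x 2N x K; GEMM2: M x K x N, matching the kernel above
      problem_sizes1[3 * e + 0] = m;
      problem_sizes1[3 * e + 1] = 2 * n;
      problem_sizes1[3 * e + 2] = k;
      problem_sizes2[3 * e + 0] = m;
      problem_sizes2[3 * e + 1] = k;
      problem_sizes2[3 * e + 2] = n;
    } else {
      // Transposed problems: the M and N entries trade places, K stays
      problem_sizes1[3 * e + 0] = 2 * n;
      problem_sizes1[3 * e + 1] = m;
      problem_sizes1[3 * e + 2] = k;
      problem_sizes2[3 * e + 0] = k;
      problem_sizes2[3 * e + 1] = m;
      problem_sizes2[3 * e + 2] = n;
    }
  }
}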
@@ -187,10 +197,19 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
                                          const int64_t n, const int64_t k) {
   auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
 
-  compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
-      static_cast<int32_t*>(expert_offsets.data_ptr()),
-      static_cast<int32_t*>(problem_sizes1.data_ptr()),
-      static_cast<int32_t*>(problem_sizes2.data_ptr()),
-      static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
-      k);
+  if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
+    compute_pplx_data<false><<<1, num_local_experts, 0, stream>>>(
+        static_cast<int32_t*>(expert_offsets.data_ptr()),
+        static_cast<int32_t*>(problem_sizes1.data_ptr()),
+        static_cast<int32_t*>(problem_sizes2.data_ptr()),
+        static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
+        k);
+  } else {
+    compute_pplx_data<true><<<1, num_local_experts, 0, stream>>>(
+        static_cast<int32_t*>(expert_offsets.data_ptr()),
+        static_cast<int32_t*>(problem_sizes1.data_ptr()),
+        static_cast<int32_t*>(problem_sizes2.data_ptr()),
+        static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
+        k);
+  }
 }
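The launcher now picks the instantiation by problem size: when num_local_experts * padded_m exceeds SWAP_AB_THRESHOLD (a constant defined elsewhere in this file, not shown in the diff) it keeps the regular layout, otherwise it launches the swapped variant. The usual motivation: since C = A * B implies C^T = B^T * A^T, a problem with very small M can be run as the transposed problem instead, which presumably maps better onto the grouped GEMM kernels when few tokens land on each expert. A standalone sketch of the selection rule, with a stand-in threshold value:

// Standalone sketch of the dispatch rule above. SWAP_AB_THRESHOLD is a
// constant defined elsewhere in the file; the value below is a stand-in
// for illustration, not the real constant.
#include <cstdio>

constexpr int kSwapAbThresholdStandIn = 64;  // assumption, not vLLM's value

bool use_swapped_layout(int num_local_experts, int padded_m) {
  // Mirrors the `>` comparison in get_cutlass_pplx_moe_mm_data_caller:
  // small total row counts take the SWAP_AB=true path.
  return num_local_experts * padded_m <= kSwapAbThresholdStandIn;
}

int main() {
  std::printf("%d\n", use_swapped_layout(8, 4));    // 1 -> swapped variant
  std::printf("%d\n", use_swapped_layout(32, 64));  // 0 -> regular variant
  return 0;
}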