mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-18 22:17:03 +08:00
[Fix] enable swap_ab for pplx problem size computation (#22991)
Signed-off-by: Shixian Cui <shixian@amazon.com> Co-authored-by: Shixian Cui <shixian@amazon.com>
This commit is contained in:
parent
8a87cd27d9
commit
7f89ed248f
@ -161,6 +161,7 @@ void get_cutlass_moe_mm_data_caller(
|
|||||||
topk_ids.size(1));
|
topk_ids.size(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <bool SWAP_AB>
|
||||||
__global__ void compute_pplx_data(int32_t* expert_offsets,
|
__global__ void compute_pplx_data(int32_t* expert_offsets,
|
||||||
int32_t* problem_sizes1,
|
int32_t* problem_sizes1,
|
||||||
int32_t* problem_sizes2,
|
int32_t* problem_sizes2,
|
||||||
@ -168,14 +169,23 @@ __global__ void compute_pplx_data(int32_t* expert_offsets,
|
|||||||
const int padded_m, const int n,
|
const int padded_m, const int n,
|
||||||
const int k) {
|
const int k) {
|
||||||
int expert_idx = threadIdx.x;
|
int expert_idx = threadIdx.x;
|
||||||
|
|
||||||
expert_offsets[expert_idx] = expert_idx * padded_m;
|
expert_offsets[expert_idx] = expert_idx * padded_m;
|
||||||
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
|
|
||||||
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
|
if constexpr (!SWAP_AB) {
|
||||||
problem_sizes1[expert_idx * 3 + 2] = k;
|
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
|
||||||
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
|
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
|
||||||
problem_sizes2[expert_idx * 3 + 1] = k;
|
problem_sizes1[expert_idx * 3 + 2] = k;
|
||||||
problem_sizes2[expert_idx * 3 + 2] = n;
|
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
|
||||||
|
problem_sizes2[expert_idx * 3 + 1] = k;
|
||||||
|
problem_sizes2[expert_idx * 3 + 2] = n;
|
||||||
|
} else {
|
||||||
|
problem_sizes1[expert_idx * 3] = 2 * n;
|
||||||
|
problem_sizes1[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
|
||||||
|
problem_sizes1[expert_idx * 3 + 2] = k;
|
||||||
|
problem_sizes2[expert_idx * 3] = k;
|
||||||
|
problem_sizes2[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
|
||||||
|
problem_sizes2[expert_idx * 3 + 2] = n;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||||
@ -187,10 +197,19 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
|||||||
const int64_t n, const int64_t k) {
|
const int64_t n, const int64_t k) {
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
||||||
|
|
||||||
compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
|
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
|
||||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
compute_pplx_data<false><<<1, num_local_experts, 0, stream>>>(
|
||||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||||
k);
|
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||||
|
k);
|
||||||
|
} else {
|
||||||
|
compute_pplx_data<true><<<1, num_local_experts, 0, stream>>>(
|
||||||
|
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||||
|
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||||
|
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||||
|
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||||
|
k);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user