Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 18:06:03 +08:00)
[Bug] Fix Compressed Tensor NVFP4 cutlass_fp4_group_mm illegal memory access (#21465)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Parent: 684174115d
Commit: e8cb0d0495
@@ -47,13 +47,12 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
|
|||||||
|
|
||||||
__global__ void compute_expert_offsets(
|
__global__ void compute_expert_offsets(
|
||||||
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
||||||
int32_t* atomic_buffer, const int num_experts, const int topk_length) {
|
int32_t* atomic_buffer, const int num_experts, const bool swap_ab) {
|
||||||
int32_t tot_offset = 0;
|
int32_t tot_offset = 0;
|
||||||
expert_offsets[0] = 0;
|
expert_offsets[0] = 0;
|
||||||
for (int i = 0; i < num_experts; ++i) {
|
for (int i = 0; i < num_experts; ++i) {
|
||||||
atomic_buffer[i] = tot_offset;
|
atomic_buffer[i] = tot_offset;
|
||||||
tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3]
|
tot_offset += swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
|
||||||
: problem_sizes1[i * 3 + 1];
|
|
||||||
expert_offsets[i + 1] = tot_offset;
|
expert_offsets[i + 1] = tot_offset;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -61,15 +60,14 @@ __global__ void compute_expert_offsets(
|
|||||||
__global__ void compute_expert_blockscale_offsets(
|
__global__ void compute_expert_blockscale_offsets(
|
||||||
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
||||||
int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
|
int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
|
||||||
const int topk_length) {
|
const bool swap_ab) {
|
||||||
int32_t tot_offset = 0;
|
int32_t tot_offset = 0;
|
||||||
int32_t tot_offset_round = 0;
|
int32_t tot_offset_round = 0;
|
||||||
expert_offsets[0] = 0;
|
expert_offsets[0] = 0;
|
||||||
blockscale_offsets[0] = 0;
|
blockscale_offsets[0] = 0;
|
||||||
for (int i = 0; i < num_experts; ++i) {
|
for (int i = 0; i < num_experts; ++i) {
|
||||||
int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD
|
int32_t cur_offset =
|
||||||
? problem_sizes1[i * 3]
|
swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
|
||||||
: problem_sizes1[i * 3 + 1];
|
|
||||||
atomic_buffer[i] = tot_offset;
|
atomic_buffer[i] = tot_offset;
|
||||||
tot_offset += cur_offset;
|
tot_offset += cur_offset;
|
||||||
expert_offsets[i + 1] = tot_offset;
|
expert_offsets[i + 1] = tot_offset;
|
||||||
@@ -119,15 +117,19 @@ void get_cutlass_moe_mm_data_caller(
|
|||||||
|
|
||||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||||
|
|
||||||
if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
|
// Swap-AB should be disabled for FP4 path
|
||||||
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
|
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
|
||||||
|
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
|
||||||
|
|
||||||
|
if (may_swap_ab) {
|
||||||
|
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
|
||||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
|
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
|
||||||
k);
|
k);
|
||||||
} else {
|
} else {
|
||||||
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
|
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
|
||||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||||
@@ -136,18 +138,19 @@ void get_cutlass_moe_mm_data_caller(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (blockscale_offsets.has_value()) {
|
if (blockscale_offsets.has_value()) {
|
||||||
|
// fp4 path
|
||||||
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
|
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
|
||||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||||
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
|
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
|
||||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
|
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
|
||||||
topk_ids.numel());
|
may_swap_ab);
|
||||||
} else {
|
} else {
|
||||||
compute_expert_offsets<<<1, 1, 0, stream>>>(
|
compute_expert_offsets<<<1, 1, 0, stream>>>(
|
||||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
|
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
|
||||||
topk_ids.numel());
|
may_swap_ab);
|
||||||
}
|
}
|
||||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user