[Bugfix] Fix topk_ids indices_type for CUTLASS w8a8 FP8 MoE (#20166)
Signed-off-by: Ming Yang <yming@meta.com>
commit c438183e99
parent baba0389f7
@@ -7,7 +7,7 @@
 constexpr uint64_t THREADS_PER_EXPERT = 512;
 
-__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
+__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
                                       int32_t* problem_sizes1,
                                       int32_t* problem_sizes2,
                                       int32_t* atomic_buffer,
@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
   }
 }
 
-__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
+__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
                                   const int32_t* __restrict__ expert_offsets,
                                   int32_t* input_permutation,
                                   int32_t* output_permutation,
@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
 
   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
   compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const uint32_t*>(topk_ids.data_ptr()),
+      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
@@ -120,7 +120,7 @@ void get_cutlass_moe_mm_data_caller(
        static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
   }
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const uint32_t*>(topk_ids.data_ptr()),
+      static_cast<const int32_t*>(topk_ids.data_ptr()),
       static_cast<const int32_t*>(expert_offsets.data_ptr()),
       static_cast<int32_t*>(input_permutation.data_ptr()),
       static_cast<int32_t*>(output_permutation.data_ptr()),
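Note on the CUDA hunks above: the kernels now take topk_ids as const int32_t*, and the launcher in get_cutlass_moe_mm_data_caller reinterprets topk_ids.data_ptr() accordingly, so the tensor handed in from the Python side has to be int32. A minimal sketch of that contract (illustrative only, not part of the patch; shapes and values are made up):

import torch

# The launcher casts topk_ids.data_ptr() to const int32_t*, so the tensor
# must already be int32; any other dtype would be reinterpreted
# byte-for-byte by the kernel.  Shapes and values below are invented.
num_tokens, topk = 4, 2
topk_ids = torch.randint(0, 8, (num_tokens, topk), dtype=torch.int32)
assert topk_ids.dtype == torch.int32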
@@ -78,7 +78,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return self.max_num_tokens
 
     def topk_indices_dtype(self) -> Optional[torch.dtype]:
-        return torch.uint32
+        return torch.int32
 
     def num_dispatchers(self) -> int:
         return self.num_dispatchers_
@@ -100,7 +100,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         hidden_dim = a1.size(-1)  # K
 
         assert topk_ids.size(0) == num_tokens
-        # assert expert_map is None, "NYI"
+        assert expert_map is None, """with expert map, -1 id is used for
+            non-local token; this causes error when casting ids to the
+            topk_indices_dtype() uint32"""
 
         # Is this always going to be a1.device?
         device = a1.device
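The new assert documents why expert_map is rejected on the pplx path: with an expert map, non-local tokens are marked with id -1, and that sentinel only round-trips through a signed index type. A small illustration of the underlying issue (assumption-level sketch, not code from the patch):

import torch

# -1 marks tokens whose expert is not local to this rank.  The sentinel is
# representable in torch.int32, but an unsigned 32-bit type has no -1 and
# would wrap it to 2**32 - 1, which is what the assert above guards against.
topk_ids = torch.tensor([[3, -1], [5, 0]], dtype=torch.int64)
as_int32 = topk_ids.to(torch.int32)  # sentinel survives the signed cast
assert bool((as_int32 == topk_ids).all())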
@@ -929,9 +929,12 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-        )
+            e_score_correction_bias=e_score_correction_bias)
+
+        a1_scale = layer.w13_input_scale
+        a2_scale = layer.w2_input_scale
+        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+            a2_scale.numel() != 1 if a2_scale is not None else False)
 
         return self.fused_experts(
             x,
@@ -939,13 +942,14 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
             layer.w2_weight,
             topk_weights,
             topk_ids,
+            per_act_token=per_act_token,
             activation=activation,
             global_num_experts=global_num_experts,
             expert_map=None if self.disable_expert_map else expert_map,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
         )
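The hoisted per_act_token expression infers per-token activation quantization from the input-scale shape: a per-tensor scale has exactly one element, anything larger is treated as per-token. The same heuristic in isolation, with invented scale tensors:

import torch

# Per-tensor quantization stores a single-element scale; per-token stores
# one scale per token.  The expression matches the one hoisted in the diff
# above; the example scales are made up.
a1_scale = torch.ones(1)  # per-tensor input scale
a2_scale = None           # no second input scale recorded
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
    a2_scale.numel() != 1 if a2_scale is not None else False)
assert per_act_token is False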