[Bugfix] Fix topk_ids indices_type for CUTLASS w8a8 FP8 MoE (#20166)

Signed-off-by: Ming Yang <yming@meta.com>
This commit is contained in:
Ming Yang 2025-07-08 16:10:57 -07:00 committed by GitHub
parent baba0389f7
commit c438183e99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 11 deletions

View File

@ -7,7 +7,7 @@
constexpr uint64_t THREADS_PER_EXPERT = 512;
__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
}
}
__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation,
int32_t* output_permutation,
@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
@ -120,7 +120,7 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
}
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()),

View File

@ -78,7 +78,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.max_num_tokens
def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.uint32
return torch.int32
def num_dispatchers(self) -> int:
return self.num_dispatchers_
@ -100,7 +100,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
hidden_dim = a1.size(-1) # K
assert topk_ids.size(0) == num_tokens
# assert expert_map is None, "NYI"
assert expert_map is None, """with expert map, -1 id is used for
non-local token; this causes error when casting ids to the
topk_indices_dtype() uint32"""
# Is this always going to be a1.device?
device = a1.device

View File

@ -929,9 +929,12 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
e_score_correction_bias=e_score_correction_bias)
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
return self.fused_experts(
x,
@ -939,13 +942,14 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
layer.w2_weight,
topk_weights,
topk_ids,
per_act_token=per_act_token,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)