[Bugfix] fix moe marlin topk_weight loading (#18080)

Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Jinzhen Lin 2025-05-14 14:31:57 +08:00 committed by GitHub
parent 6685890d11
commit d4154c35a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -473,15 +473,15 @@ __global__ void Marlin(
if (mul_topk_weights) {
#pragma unroll
for (int i = 0; i < 4; i++) {
int idx = tid4 * 4 + i;
idx = idx < block_num_valid_tokens ? idx : 0;
if constexpr (w_type == vllm::kFE2M1f) {
sh_block_topk_weights[tid4 * 4 + i] = __hmul2(
global_scale,
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
sh_block_topk_weights[idx] = __hmul2(
global_scale, Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[idx]])));
} else {
sh_block_topk_weights[tid4 * 4 + i] =
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
sh_block_topk_weights[idx] = Dtype::num2num2(
Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
}
}
}