diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index dedbe1b792f71..fdf0f51cd4a26 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -473,15 +473,15 @@ __global__ void Marlin( if (mul_topk_weights) { #pragma unroll for (int i = 0; i < 4; i++) { + int idx = tid4 * 4 + i; + idx = idx < block_num_valid_tokens ? idx : 0; if constexpr (w_type == vllm::kFE2M1f) { - sh_block_topk_weights[tid4 * 4 + i] = __hmul2( - global_scale, - Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]))); + sh_block_topk_weights[idx] = __hmul2( + global_scale, Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[idx]]))); } else { - sh_block_topk_weights[tid4 * 4 + i] = - Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + sh_block_topk_weights[idx] = Dtype::num2num2( + Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); } } }