From d4154c35a241077d27418940f2553003c58dd903 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Wed, 14 May 2025 14:31:57 +0800 Subject: [PATCH] [Bugfix] fix moe marlin `topk_weight` loading (#18080) Co-authored-by: mgoin --- csrc/moe/marlin_moe_wna16/marlin_template.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index dedbe1b792f7..fdf0f51cd4a2 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -473,15 +473,15 @@ __global__ void Marlin( if (mul_topk_weights) { #pragma unroll for (int i = 0; i < 4; i++) { + int idx = tid4 * 4 + i; + idx = idx < block_num_valid_tokens ? idx : 0; if constexpr (w_type == vllm::kFE2M1f) { - sh_block_topk_weights[tid4 * 4 + i] = __hmul2( - global_scale, - Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]))); + sh_block_topk_weights[idx] = __hmul2( + global_scale, Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[idx]]))); } else { - sh_block_topk_weights[tid4 * 4 + i] = - Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + sh_block_topk_weights[idx] = Dtype::num2num2( + Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); } } }