mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 06:34:58 +08:00
[Bugfix] fix moe marlin topk_weight loading (#18080)
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
6685890d11
commit
d4154c35a2
@@ -473,15 +473,15 @@ __global__ void Marlin(
     if (mul_topk_weights) {
   #pragma unroll
       for (int i = 0; i < 4; i++) {
+        int idx = tid4 * 4 + i;
+        idx = idx < block_num_valid_tokens ? idx : 0;
         if constexpr (w_type == vllm::kFE2M1f) {
-          sh_block_topk_weights[tid4 * 4 + i] = __hmul2(
-              global_scale,
-              Dtype::num2num2(Dtype::float2num(
-                  topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
+          sh_block_topk_weights[idx] = __hmul2(
+              global_scale, Dtype::num2num2(Dtype::float2num(
+                                topk_weights_ptr[sh_block_sorted_ids[idx]])));
         } else {
-          sh_block_topk_weights[tid4 * 4 + i] =
-              Dtype::num2num2(Dtype::float2num(
-                  topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
+          sh_block_topk_weights[idx] = Dtype::num2num2(
+              Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
         }
       }
     }
Loading…
x
Reference in New Issue
Block a user