mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:54:56 +08:00
[Bugfix] fix moe marlin topk_weight loading (#18080)
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
6685890d11
commit
d4154c35a2
@@ -473,15 +473,15 @@ __global__ void Marlin(
     if (mul_topk_weights) {
   #pragma unroll
       for (int i = 0; i < 4; i++) {
+        int idx = tid4 * 4 + i;
+        idx = idx < block_num_valid_tokens ? idx : 0;
         if constexpr (w_type == vllm::kFE2M1f) {
-          sh_block_topk_weights[tid4 * 4 + i] = __hmul2(
-              global_scale,
-              Dtype::num2num2(Dtype::float2num(
-                  topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
+          sh_block_topk_weights[idx] = __hmul2(
+              global_scale, Dtype::num2num2(Dtype::float2num(
+                                topk_weights_ptr[sh_block_sorted_ids[idx]])));
         } else {
-          sh_block_topk_weights[tid4 * 4 + i] =
-              Dtype::num2num2(Dtype::float2num(
-                  topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
+          sh_block_topk_weights[idx] = Dtype::num2num2(
+              Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
         }
       }
     }
Loading…
x
Reference in New Issue
Block a user