mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 02:31:20 +08:00
Merge 348de41b5253318b4a578d7d972511aebb50f7dc into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
ed2509b7b0
@ -1839,6 +1839,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
|||||||
reinterpret_cast<const _B16x16*>(q_fetch_ptr);
|
reinterpret_cast<const _B16x16*>(q_fetch_ptr);
|
||||||
Qlocal[qkhe_depth] = *q_fetch_ptr_32B;
|
Qlocal[qkhe_depth] = *q_fetch_ptr_32B;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// Zero out Qlocal for lanes that don't load Q data to prevent
|
||||||
|
// uninitialized register values from contaminating wmma results
|
||||||
|
#pragma unroll
|
||||||
|
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
|
||||||
|
Qlocal[qkhe_depth] = {};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// fetch Q in shared across warps and then write to registers
|
// fetch Q in shared across warps and then write to registers
|
||||||
@ -2608,6 +2615,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
|||||||
reinterpret_cast<const _B16x8*>(q_fetch_ptr);
|
reinterpret_cast<const _B16x8*>(q_fetch_ptr);
|
||||||
Qlocal[qkhe_depth] = *q_fetch_ptr_16B;
|
Qlocal[qkhe_depth] = *q_fetch_ptr_16B;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// Zero out Qlocal for lanes that don't load Q data to prevent
|
||||||
|
// uninitialized register values from contaminating wmma results
|
||||||
|
#pragma unroll
|
||||||
|
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
|
||||||
|
Qlocal[qkhe_depth] = {};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// fetch Q in shared across warps and then write to registers
|
// fetch Q in shared across warps and then write to registers
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user