mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-20 09:55:01 +08:00
[Bugfix] Fix __syncwarp on ROCM (#25996)
Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
parent
a1825fe645
commit
febb688356
@@ -536,7 +536,9 @@ __global__ void indexer_k_quant_and_cache_kernel(
|
|||||||
for (int i = 0; i < VEC_SIZE; i++) {
|
for (int i = 0; i < VEC_SIZE; i++) {
|
||||||
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
|
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
|
||||||
}
|
}
|
||||||
|
#ifndef USE_ROCM
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
|
|
||||||
// Reduced amax
|
// Reduced amax
|
||||||
for (int mask = 16; mask > 0; mask /= 2) {
|
for (int mask = 16; mask > 0; mask /= 2) {
|
||||||
@@ -546,7 +548,9 @@ __global__ void indexer_k_quant_and_cache_kernel(
|
|||||||
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
|
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
#ifndef USE_ROCM
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
float scale = fmaxf(amax, 1e-4) / 448.0f;
|
float scale = fmaxf(amax, 1e-4) / 448.0f;
|
||||||
if (use_ue8m0) {
|
if (use_ue8m0) {
|
||||||
scale = exp2f(ceilf(log2f(scale)));
|
scale = exp2f(ceilf(log2f(scale)));
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user