diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 4c85e6c5daa20..545eaf5bb40cf 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -536,7 +536,9 @@ __global__ void indexer_k_quant_and_cache_kernel( for (int i = 0; i < VEC_SIZE; i++) { amax = fmaxf(amax, fabsf(float(k_val_ptr[i]))); } +#ifndef USE_ROCM __syncwarp(); +#endif // Reduced amax for (int mask = 16; mask > 0; mask /= 2) { @@ -546,7 +548,9 @@ __global__ void indexer_k_quant_and_cache_kernel( amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask)); #endif } +#ifndef USE_ROCM __syncwarp(); +#endif float scale = fmaxf(amax, 1e-4) / 448.0f; if (use_ue8m0) { scale = exp2f(ceilf(log2f(scale))); @@ -1167,4 +1171,4 @@ void indexer_k_quant_and_cache( DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3", CALL_INDEXER_K_QUANT_AND_CACHE); -} \ No newline at end of file +}