[DeepSeek v3.2] Remove unnecessary syncwarps (#31047)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni 2025-12-24 00:33:30 -05:00 committed by GitHub
parent dabff12ed3
commit 369f47aa0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -451,9 +451,6 @@ __global__ void indexer_k_quant_and_cache_kernel(
for (int i = 0; i < VEC_SIZE; i++) {
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
}
#ifndef USE_ROCM
__syncwarp();
#endif
// Reduced amax
for (int mask = 16; mask > 0; mask /= 2) {
@ -463,9 +460,7 @@ __global__ void indexer_k_quant_and_cache_kernel(
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
#endif
}
#ifndef USE_ROCM
__syncwarp();
#endif
#if defined(__gfx942__)
float scale = fmaxf(amax, 1e-4) / 224.0f;
#else