[Bugfix][DSV32] Fix 32-bit integer overflow in top-k per-row pointer offsets by widening rowIdx to int64_t. (#30754)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Daniel Cámpora 2025-12-16 23:21:17 +01:00 committed by GitHub
parent f5f51e5931
commit eaa82a709a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
int rowEnd = rowEnds[rowIdx]; int rowEnd = rowEnds[rowIdx];
// Local pointers to this block // Local pointers to this block
outIndices += rowIdx * topK; outIndices += static_cast<int64_t>(rowIdx) * topK;
logits += rowIdx * stride0; logits += static_cast<int64_t>(rowIdx) * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>( topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK); nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
// Local pointers to this block // Local pointers to this block
if constexpr (!multipleBlocksPerRow && !mergeBlocks) { if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
outIndices += rowIdx * topK; outIndices += static_cast<int64_t>(rowIdx) * topK;
} else if constexpr (multipleBlocksPerRow) { } else if constexpr (multipleBlocksPerRow) {
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192 const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192 rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize; rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK; outIndices +=
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK; static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
outLogits +=
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
} else if constexpr (mergeBlocks) { } else if constexpr (mergeBlocks) {
rowEnd = numBlocksToMerge * topK; rowEnd = numBlocksToMerge * topK;
indices += rowIdx * numBlocksToMerge * topK; indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
outIndices += rowIdx * topK; outIndices += static_cast<int64_t>(rowIdx) * topK;
} }
logits += rowIdx * stride0; logits += static_cast<int64_t>(rowIdx) * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort, topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
multipleBlocksPerRow, mergeBlocks>( multipleBlocksPerRow, mergeBlocks>(