[Bugfix][DSV32] Fix overflow in topk. (#30754)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Daniel Cámpora 2025-12-16 23:21:17 +01:00 committed by GitHub
parent f5f51e5931
commit eaa82a709a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
int rowEnd = rowEnds[rowIdx];
// Local pointers to this block
outIndices += rowIdx * topK;
logits += rowIdx * stride0;
outIndices += static_cast<int64_t>(rowIdx) * topK;
logits += static_cast<int64_t>(rowIdx) * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
// Local pointers to this block
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
outIndices += rowIdx * topK;
outIndices += static_cast<int64_t>(rowIdx) * topK;
} else if constexpr (multipleBlocksPerRow) {
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
outIndices +=
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
outLogits +=
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
} else if constexpr (mergeBlocks) {
rowEnd = numBlocksToMerge * topK;
indices += rowIdx * numBlocksToMerge * topK;
outIndices += rowIdx * topK;
indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
outIndices += static_cast<int64_t>(rowIdx) * topK;
}
logits += rowIdx * stride0;
logits += static_cast<int64_t>(rowIdx) * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
multipleBlocksPerRow, mergeBlocks>(