mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-30 04:25:17 +08:00
[Bugfix][DSV32] Fix overflow in topk. (#30754)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
f5f51e5931
commit
eaa82a709a
@@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
|
||||
int rowEnd = rowEnds[rowIdx];
|
||||
|
||||
// Local pointers to this block
|
||||
outIndices += rowIdx * topK;
|
||||
logits += rowIdx * stride0;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||
|
||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
|
||||
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
|
||||
@@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
||||
|
||||
// Local pointers to this block
|
||||
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
|
||||
outIndices += rowIdx * topK;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
} else if constexpr (multipleBlocksPerRow) {
|
||||
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
|
||||
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
|
||||
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
|
||||
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||
outIndices +=
|
||||
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||
outLogits +=
|
||||
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||
} else if constexpr (mergeBlocks) {
|
||||
rowEnd = numBlocksToMerge * topK;
|
||||
indices += rowIdx * numBlocksToMerge * topK;
|
||||
outIndices += rowIdx * topK;
|
||||
indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
|
||||
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||
}
|
||||
logits += rowIdx * stride0;
|
||||
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||
|
||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
|
||||
multipleBlocksPerRow, mergeBlocks>(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user