mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 16:49:08 +08:00
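[Bugfix][DSV32] Fix overflow in topk. (#30754)

The per-row pointer offsets in `topKPerRowPrefill` and `topKPerRowDecode` now cast `rowIdx` to `int64_t` before multiplying by `topK`, `stride0`, `gridDim.y`, or `numBlocksToMerge`, so each product is computed in 64-bit arithmetic.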
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
f5f51e5931
commit
eaa82a709a
@@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
|
|||||||
int rowEnd = rowEnds[rowIdx];
|
int rowEnd = rowEnds[rowIdx];
|
||||||
|
|
||||||
// Local pointers to this block
|
// Local pointers to this block
|
||||||
outIndices += rowIdx * topK;
|
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||||
logits += rowIdx * stride0;
|
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||||
|
|
||||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
|
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
|
||||||
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
|
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
|
||||||
@@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
|||||||
|
|
||||||
// Local pointers to this block
|
// Local pointers to this block
|
||||||
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
|
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
|
||||||
outIndices += rowIdx * topK;
|
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||||
} else if constexpr (multipleBlocksPerRow) {
|
} else if constexpr (multipleBlocksPerRow) {
|
||||||
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
|
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
|
||||||
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
|
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
|
||||||
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
|
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
|
||||||
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
outIndices +=
|
||||||
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||||
|
outLogits +=
|
||||||
|
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
|
||||||
} else if constexpr (mergeBlocks) {
|
} else if constexpr (mergeBlocks) {
|
||||||
rowEnd = numBlocksToMerge * topK;
|
rowEnd = numBlocksToMerge * topK;
|
||||||
indices += rowIdx * numBlocksToMerge * topK;
|
indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
|
||||||
outIndices += rowIdx * topK;
|
outIndices += static_cast<int64_t>(rowIdx) * topK;
|
||||||
}
|
}
|
||||||
logits += rowIdx * stride0;
|
logits += static_cast<int64_t>(rowIdx) * stride0;
|
||||||
|
|
||||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
|
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
|
||||||
multipleBlocksPerRow, mergeBlocks>(
|
multipleBlocksPerRow, mergeBlocks>(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user