diff --git a/csrc/sampler.cu b/csrc/sampler.cu index fc2154beff9e0..d458f8e4c1d02 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill( int rowEnd = rowEnds[rowIdx]; // Local pointers to this block - outIndices += rowIdx * topK; - logits += rowIdx * stride0; + outIndices += static_cast(rowIdx) * topK; + logits += static_cast(rowIdx) * stride0; topKPerRowJob( nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK); @@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode( // Local pointers to this block if constexpr (!multipleBlocksPerRow && !mergeBlocks) { - outIndices += rowIdx * topK; + outIndices += static_cast(rowIdx) * topK; } else if constexpr (multipleBlocksPerRow) { const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192 rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192 rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize; - outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK; - outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK; + outIndices += + static_cast(rowIdx) * gridDim.y * topK + blockIdx.y * topK; + outLogits += + static_cast(rowIdx) * gridDim.y * topK + blockIdx.y * topK; } else if constexpr (mergeBlocks) { rowEnd = numBlocksToMerge * topK; - indices += rowIdx * numBlocksToMerge * topK; - outIndices += rowIdx * topK; + indices += static_cast(rowIdx) * numBlocksToMerge * topK; + outIndices += static_cast(rowIdx) * topK; } - logits += rowIdx * stride0; + logits += static_cast(rowIdx) * stride0; topKPerRowJob(