[DeepSeek v3.2] Make top-k work for any logit values. (#27568)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Daniel Cámpora 2025-12-08 15:55:58 +01:00 committed by GitHub
parent eb1051fb95
commit 184076c3fe
5 changed files with 643 additions and 224 deletions


@@ -102,13 +102,16 @@ void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties);
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1);
void top_k_per_row_prefill(const torch::Tensor& logits,
const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seq_lens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1);
const torch::Tensor& seqLens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,


@@ -44,41 +44,300 @@ __global__ void apply_repetition_penalties_kernel(
}
}
static inline __device__ uint16_t extractBinIdx(float x) {
union {
__half h;
uint16_t u16;
} tmp;
tmp.h = __float2half_rn(x);
tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000);
return 511 - (tmp.u16 >> 7);
__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
uint32_t bits = __float_as_uint(x);
return (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
}
template <int kNumThreadsPerBlock = 512, int kNumBins = 512, int kTopK = 2048>
__device__ void topKPerRowJob(const float* logits, const int rowStart,
const int rowEnd, const int rowIdx,
int* outIndices, int stride0, int stride1) {
// The number of elements per thread for the final top-k sort.
static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
// The class to sort the elements during the final top-k sort.
using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
kNumTopKItemsPerThread, int>;
template <int step>
static inline __device__ uint32_t extractBinIdx(float x) {
if constexpr (step == 0) {
__half hx = __float2half(x);
uint16_t bits = __half_as_ushort(hx);
bits = (bits & 0x8000) ? bits : ~bits & 0x7fff;
return bits >> 5;
} else {
uint32_t bits = __float_as_uint(x);
bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
if constexpr (step == 1) {
return bits >> 21;
} else if constexpr (step == 2) {
return (bits >> 10) & 0x7ff;
} else if constexpr (step == 3) {
return bits & 0x3ff;
}
}
}
template <int shift>
static inline __device__ bool isPartialMatch(float x, uint32_t pattern) {
if constexpr (shift == 0) {
return true;
}
uint32_t bits = __float_as_uint(x);
bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
return (bits ^ pattern) >> shift == 0;
}
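Taken together, convert_to_uint32, extractBinIdx<step>, and isPartialMatch implement an order-preserving map from IEEE-754 floats to unsigned keys: after the transform, descending logit order becomes ascending key order, so smaller bin indices always correspond to larger logits, for any sign or magnitude. A minimal host-side sketch (not part of the patch) that sanity-checks this ordering property:

#include <cassert>
#include <cstdint>
#include <cstring>

// Reference copy of the transform: larger logits map to smaller keys for any
// finite float, positive or negative.
static uint32_t convert_to_uint32_ref(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return (bits & 0x80000000u) ? bits : ~bits & 0x7fffffffu;
}

int main() {
  const float v[] = {5.0f, 1.0f, 0.0f, -1.0f, -5.0f};  // descending logits
  for (int i = 0; i + 1 < 5; ++i) {
    assert(convert_to_uint32_ref(v[i]) < convert_to_uint32_ref(v[i + 1]));
  }
  return 0;
}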
/**
* Map a Func over the input data, using vectorized load instructions if
* possible.
*
* @tparam T element type
* @tparam IdxT indexing type
* @tparam Func void (T x, IdxT idx)
*
* @param thread_rank rank of the calling thread among all participating threads
* @param num_threads number of the threads that participate in processing
* @param in the input data
* @param len the number of elements to read
* @param f the lambda taking two arguments (T x, IdxT idx)
*/
template <typename T, typename idxT, typename Func>
__device__ void vectorized_process(size_t thread_rank, size_t num_threads,
const T* in, idxT len, Func f) {
constexpr int WARP_SIZE = 32;
using WideT = float4;
if constexpr (sizeof(T) >= sizeof(WideT)) {
for (idxT i = thread_rank; i < len; i += num_threads) {
f(in[i], i);
}
} else {
static_assert(sizeof(WideT) % sizeof(T) == 0);
constexpr int items_per_scalar = sizeof(WideT) / sizeof(T);
// TODO: it's UB
union {
WideT scalar;
T array[items_per_scalar];
} wide;
int skip_cnt =
(reinterpret_cast<size_t>(in) % sizeof(WideT))
? ((sizeof(WideT) - reinterpret_cast<size_t>(in) % sizeof(WideT)) /
sizeof(T))
: 0;
if (skip_cnt > len) {
skip_cnt = len;
}
const WideT* in_cast = reinterpret_cast<decltype(in_cast)>(in + skip_cnt);
const idxT len_cast = (len - skip_cnt) / items_per_scalar;
for (idxT i = thread_rank; i < len_cast; i += num_threads) {
wide.scalar = in_cast[i];
const idxT real_i = skip_cnt + i * items_per_scalar;
#pragma unroll
for (int j = 0; j < items_per_scalar; ++j) {
f(wide.array[j], real_i + j);
}
}
static_assert(WARP_SIZE >= items_per_scalar);
// and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt
// no need to use loop
if (thread_rank < skip_cnt) {
f(in[thread_rank], thread_rank);
}
// because len_cast = (len - skip_cnt) / items_per_scalar,
// len_cast * items_per_scalar + items_per_scalar > len - skip_cnt;
// and so
// len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <=
// WARP_SIZE no need to use loop
const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
if (remain_i < len) {
f(in[remain_i], remain_i);
}
}
}
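Ignoring the float4 fast path, vectorized_process is semantically just a strided element-wise map over the row; the wide loads only change how memory is read. A scalar-equivalent sketch (assumed semantics, not part of the patch):

// Same callback contract as vectorized_process: f(value, index within the row).
template <typename T, typename IdxT, typename Func>
__device__ void vectorized_process_scalar(size_t thread_rank,
                                          size_t num_threads, const T* in,
                                          IdxT len, Func f) {
  for (IdxT i = thread_rank; i < len; i += num_threads) {
    f(in[i], i);
  }
}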
template <int step, int kNumThreadsPerBlock, int kNumBins, int kNumFinalItems,
bool multipleBlocksPerRow, bool mergeBlocks, typename SmemFinalType,
typename SmemOutputType>
__device__ bool processHistogramStep(
const int* indices, const float* logits, int rowEnd, uint32_t& logitPattern,
int& thresholdBinIdx, SmemOutputType& smemOutput, int* smemThresholdBinIdx,
int* smemFinalDstIdx, int* smemFinalBinSize, int* smemFoundTopKValues,
SmemFinalType& smemFinal, int stride1, int rowStart, int topK) {
// Clear the histogram.
#pragma unroll
for (int idx = threadIdx.x; idx < kNumBins; idx += kNumThreadsPerBlock) {
smemFinal.histo.data[idx] = 0;
}
// Make sure the histogram is ready.
__syncthreads();
// Update pattern
constexpr auto patternShift = step < 2 ? 0 : step == 2 ? 21 : 10;
if constexpr (step == 2) {
logitPattern = static_cast<uint32_t>(thresholdBinIdx & 0x7ff)
<< patternShift;
} else if constexpr (step == 3) {
logitPattern |= static_cast<uint32_t>(thresholdBinIdx & 0x7ff)
<< patternShift;
}
auto distributeToBins = [&](float logit, int /* idx */ = 0) {
if (isPartialMatch<patternShift>(logit, logitPattern)) {
uint32_t binIdx = extractBinIdx<step>(logit);
atomicAdd(&smemFinal.histo.data[binIdx], 1);
}
};
// Distribute the elements to the histogram bins.
if (stride1 == 1) {
vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart,
rowEnd - rowStart, distributeToBins);
} else {
for (int idx = rowStart + threadIdx.x; idx < rowEnd;
idx += kNumThreadsPerBlock) {
float logit = logits[idx * stride1];
distributeToBins(logit, idx);
}
}
// Make sure the histogram is ready.
__syncthreads();
// Reads the value of the starting position in the smemOutput array
int lastValue = smemFoundTopKValues[0];
for (int round = 0; round < kNumBins / kNumThreadsPerBlock; round++) {
// Read the values from SMEM.
int idx = threadIdx.x + kNumThreadsPerBlock * round;
int binCount{0};
binCount = smemFinal.histo.data[idx];
// Make sure each thread has read its value.
__syncthreads();
// Compute the prefix sum.
int prefixSum{0}, totalSum{0};
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
Scan(smemFinal.histo.scan).ExclusiveSum(binCount, prefixSum, totalSum);
// Update the histogram with the prefix sums.
prefixSum += lastValue;
totalSum += lastValue;
smemFinal.histo.data[idx] = prefixSum;
// Make sure the data is in shared memory.
__syncthreads();
// Find the last valid bin.
bool foundThreshold = false;
if (prefixSum < topK) {
int nextPrefixSum = threadIdx.x == kNumThreadsPerBlock - 1
? totalSum
: smemFinal.histo.data[idx + 1];
if (nextPrefixSum >= topK) {
smemThresholdBinIdx[0] = idx;
smemFinalBinSize[0] = nextPrefixSum - prefixSum;
foundThreshold = true;
}
}
// Early exit: if any thread found the threshold, we can skip remaining
// rounds
if (__syncthreads_or(foundThreshold)) {
break;
}
lastValue = totalSum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The threshold bin.
thresholdBinIdx = smemThresholdBinIdx[0];
auto processBins = [&](float logit, int idx) {
if (isPartialMatch<patternShift>(logit, logitPattern)) {
uint32_t binIdx = extractBinIdx<step>(logit);
if (binIdx < thresholdBinIdx) {
// The element is part of the top-k selection
int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1);
if constexpr (mergeBlocks) {
smemOutput[dstIdx] = indices[idx];
} else if constexpr (multipleBlocksPerRow) {
smemOutput[dstIdx] = idx + rowStart;
reinterpret_cast<float*>(smemOutput + topK)[dstIdx] = logit;
} else {
smemOutput[dstIdx] = idx;
}
}
if constexpr (step < 3) {
// Only fill the final items for sorting if the threshold bin fits
if (binIdx == thresholdBinIdx &&
smemFinalBinSize[0] <= kNumFinalItems) {
int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
smemFinal.items.logits[dstIdx] = logit;
if constexpr (mergeBlocks) {
smemFinal.items.indices[dstIdx] = indices[idx];
} else if constexpr (multipleBlocksPerRow) {
smemFinal.items.indices[dstIdx] = idx + rowStart;
} else {
smemFinal.items.indices[dstIdx] = idx;
}
}
} else {
if (binIdx == thresholdBinIdx) {
// The elements in the threshold bin share the same 32 bits at step 3
int dstIdx = atomicAdd(&smemFinal.histo.data[binIdx], 1);
if (dstIdx < topK) {
if constexpr (mergeBlocks) {
smemOutput[dstIdx] = indices[idx];
} else if constexpr (multipleBlocksPerRow) {
smemOutput[dstIdx] = idx + rowStart;
reinterpret_cast<float*>(smemOutput + topK)[dstIdx] = logit;
} else {
smemOutput[dstIdx] = idx;
}
}
}
}
}
};
if (stride1 == 1) {
vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart,
rowEnd - rowStart, processBins);
} else {
for (int idx = rowStart + threadIdx.x; idx < rowEnd;
idx += kNumThreadsPerBlock) {
float logit = logits[idx * stride1];
processBins(logit, idx);
}
}
// Make sure the elements are in shared memory.
__syncthreads();
// Check if we should continue to next step
return smemFinalBinSize[0] > kNumFinalItems;
}
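Each call to processHistogramStep performs one radix-style refinement round: histogram the current bin indices, prefix-sum the counts, and pick the first bin whose cumulative count reaches topK. Everything in earlier bins is committed to the output; only the threshold bin is carried into the next, finer step. A host-side sketch of that threshold search (illustration only, assuming the same exclusive-prefix-sum convention):

#include <cstdint>
#include <vector>

// Returns the threshold bin: bins strictly below it lie entirely inside the
// top-k, and the threshold bin itself is refined by the next step.
int find_threshold_bin(const std::vector<uint32_t>& bin_ids, int num_bins,
                       int top_k) {
  std::vector<int> histo(num_bins, 0);
  for (uint32_t b : bin_ids) ++histo[b];
  int prefix = 0;  // exclusive prefix sum, mirroring the CUB BlockScan above
  for (int bin = 0; bin < num_bins; ++bin) {
    if (prefix < top_k && prefix + histo[bin] >= top_k) return bin;
    prefix += histo[bin];
  }
  return num_bins - 1;
}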
// Refinement order: 11 bits of the half representation, then 11 / 11 / 10 bit chunks of the float bits
template <int kNumThreadsPerBlock, int kNumBins, bool useRadixSort,
bool multipleBlocksPerRow = false, bool mergeBlocks = false>
static __device__ void topKPerRowJob(const int* indices, const float* logits,
int rowStart, int rowEnd, int* outIndices,
float* outLogits, int stride1, int topK) {
// The number of slots for the final pass.
static constexpr int kNumFinalItems = 3072;
static constexpr int kNumFinalItems = 2048;
// The number of elements per thread for the final sort.
static constexpr int kNumFinalItemsPerThread =
kNumFinalItems / kNumThreadsPerBlock;
// The class to sort the elements during the final pass.
using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
kNumFinalItemsPerThread, int>;
using FinalSortTempStorage =
std::conditional_t<useRadixSort, typename FinalSort::TempStorage, int>;
// The class to compute the inclusive prefix-sum over the histogram.
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
// Shared memory to compute the block scan.
__shared__ typename Scan::TempStorage smemScan;
// The structure to store the final items (for the final pass).
struct FinalItems {
// Shared memory to store the indices for the final pass.
@@ -87,200 +346,225 @@ __device__ void topKPerRowJob(const float* logits, const int rowStart,
float logits[kNumFinalItems];
};
struct Histogram {
typename Scan::TempStorage scan;
int data[kNumBins];
};
// Shared memory to compute the block sort.
__shared__ union {
FinalItems items;
typename FinalSort::TempStorage finalSort;
typename TopKSort::TempStorage topKSort;
FinalSortTempStorage finalSort;
Histogram histo;
} smemFinal;
// Shared memory to store the histogram.
__shared__ int smemHistogram[kNumBins];
// Shared memory to store the selected indices.
__shared__ int smemIndices[kTopK];
// If we are processing using multiple blocks, we need to store the logits and
// indices.
extern __shared__ int32_t smemOutput[];
// Shared memory to store the threshold bin.
__shared__ int smemThresholdBinIdx[1];
// Shared memory counter to register the candidates for the final phase.
__shared__ int smemFinalDstIdx[1];
// Shared memory to determine if the threshold bin fits in the final items.
__shared__ int smemFinalBinSize[1];
// Shared memory to keep track of the top-k values found so far by the
// previous iterations
__shared__ int smemFoundTopKValues[1];
// The length of the row.
int rowLen = rowEnd - rowStart;
// Shortcut if the length of the row is smaller than Top-K. Indices are not
// sorted by their corresponding logit.
if (rowLen <= kTopK) {
if (rowLen <= topK) {
for (int rowIt = threadIdx.x; rowIt < rowLen;
rowIt += kNumThreadsPerBlock) {
int idx = rowStart + rowIt;
outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
if constexpr (multipleBlocksPerRow) {
outIndices[rowIt] = rowIt + rowStart;
outLogits[rowIt] = logits[rowIt + rowStart];
} else {
outIndices[rowIt] = rowIt;
}
}
for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
for (int rowIt = rowLen + threadIdx.x; rowIt < topK;
rowIt += kNumThreadsPerBlock) {
outIndices[rowIdx * kTopK + rowIt] = -1;
outIndices[rowIt] = -1;
if constexpr (multipleBlocksPerRow) {
outLogits[rowIt] = -FLT_MAX;
}
}
return;
}
// Clear the histogram.
if (threadIdx.x < kNumBins) {
smemHistogram[threadIdx.x] = 0;
}
// Make sure the histogram is ready.
__syncthreads();
// Fetch elements one-by-one.
for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
rowIt += kNumThreadsPerBlock) {
uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]);
atomicAdd(&smemHistogram[idx], 1);
}
// Make sure the histogram is ready.
__syncthreads();
// Read the values from SMEM.
int binCount{0};
if (threadIdx.x < kNumBins) {
binCount = smemHistogram[threadIdx.x];
}
// Make sure each thread has read its value.
__syncthreads();
// Compute the prefix sum.
int prefixSum{0}, totalSum{0};
Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);
// Update the histogram with the prefix sums.
if (threadIdx.x < kNumBins) {
smemHistogram[threadIdx.x] = prefixSum;
}
// Make sure the data is in shared memory.
__syncthreads();
// Find the last valid bin.
if (threadIdx.x < kNumBins) {
int nextPrefixSum =
threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
smemThresholdBinIdx[0] = threadIdx.x;
}
}
// Clear the counter to store the items for the final phase.
// Initialize values
if (threadIdx.x == 0) {
smemFinalDstIdx[0] = 0;
smemFoundTopKValues[0] = 0;
}
__syncthreads();
int thresholdBinIdx = -1;
uint32_t logitPattern = 0;
// Step 0: Process first 11 bits of half representation
bool continueToNextStep =
processHistogramStep<0, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
multipleBlocksPerRow, mergeBlocks>(
indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
if (continueToNextStep) {
// Step 1: Process next 11 bits
continueToNextStep =
processHistogramStep<1, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
multipleBlocksPerRow, mergeBlocks>(
indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
}
// Make sure the data is in shared memory.
__syncthreads();
if (continueToNextStep) {
// Step 2: Process next 11 bits
continueToNextStep =
processHistogramStep<2, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
multipleBlocksPerRow, mergeBlocks>(
indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
}
// The threshold bin.
int thresholdBinIdx = smemThresholdBinIdx[0];
if (continueToNextStep) {
// Step 3: Process last 10 bits
processHistogramStep<3, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
multipleBlocksPerRow, mergeBlocks>(
indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
}
// Fetch elements one-by-one and populate the shared memory buffers.
for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
rowIt += kNumThreadsPerBlock) {
float logit = logits[rowIdx * stride0 + rowIt * stride1];
uint16_t idx = extractBinIdx(logit);
if (idx < thresholdBinIdx) {
int dstIdx = atomicAdd(&smemHistogram[idx], 1);
smemIndices[dstIdx] = rowIt;
} else if (idx == thresholdBinIdx) {
int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
if (dstIdx < kNumFinalItems) {
smemFinal.items.logits[dstIdx] = logit;
smemFinal.items.indices[dstIdx] = rowIt;
if (!continueToNextStep) {
// The histogram refinement stopped before the final 10-bit step, so the
// candidates collected from the threshold bin still need to be sorted here.
if constexpr (useRadixSort) {
// Sorting with radix sort
float finalLogits[kNumFinalItemsPerThread];
// The indices of the elements to be sorted in the final pass.
int finalIndices[kNumFinalItemsPerThread];
#pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
finalLogits[ii] = -FLT_MAX;
}
// Read the elements from SMEM.
#pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
if (srcIdx < smemFinalDstIdx[0]) {
finalLogits[ii] = smemFinal.items.logits[srcIdx];
finalIndices[ii] = smemFinal.items.indices[srcIdx];
}
}
// Make sure the shared memory has been read.
__syncthreads();
// Sort the elements.
FinalSort(smemFinal.finalSort)
.SortDescendingBlockedToStriped(finalLogits, finalIndices);
// Copy the data back to the shared memory storage.
int baseIdx = smemFoundTopKValues[0];
#pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
int dstIdx = baseIdx + srcIdx;
if (dstIdx < topK) {
smemOutput[dstIdx] = finalIndices[ii];
if constexpr (multipleBlocksPerRow) {
reinterpret_cast<float*>(smemOutput + topK)[dstIdx] =
finalLogits[ii];
}
}
}
} else {
// Sorting with insertion sort
auto baseIdx = smemFoundTopKValues[0];
for (int i = threadIdx.x; i < smemFinalDstIdx[0];
i += kNumThreadsPerBlock) {
int outIndex = 0;
auto logit = smemFinal.items.logits[i];
for (int j = 0; j < smemFinalDstIdx[0]; j++) {
auto otherLogit = smemFinal.items.logits[j];
if (logit < otherLogit || (logit == otherLogit && i < j)) {
outIndex++;
}
}
// Store if outIndex is in bounds
if (outIndex + baseIdx < topK) {
smemOutput[outIndex + baseIdx] = smemFinal.items.indices[i];
if constexpr (multipleBlocksPerRow) {
reinterpret_cast<float*>(smemOutput + topK)[outIndex + baseIdx] =
smemFinal.items.logits[i];
}
}
}
}
__syncthreads();
}
// Store to global memory.
for (int i = threadIdx.x; i < topK; i += kNumThreadsPerBlock) {
if constexpr (multipleBlocksPerRow) {
outIndices[i] = smemOutput[i];
outLogits[i] = reinterpret_cast<float*>(smemOutput + topK)[i];
} else {
if (stride1 == 1) {
// stride1 == 1 uses vectorized_process, whose indices already exclude
// rowStart.
outIndices[i] = smemOutput[i];
} else {
outIndices[i] = smemOutput[i] - rowStart;
}
}
}
// Make sure the elements are in shared memory.
__syncthreads();
// The logits of the elements to be sorted in the final pass.
float finalLogits[kNumFinalItemsPerThread];
// The indices of the elements to be sorted in the final pass.
int finalIndices[kNumFinalItemsPerThread];
// Init.
#pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
finalLogits[ii] = -FLT_MAX;
}
// Read the elements from SMEM.
#pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
if (srcIdx < smemFinalDstIdx[0]) {
finalLogits[ii] = smemFinal.items.logits[srcIdx];
finalIndices[ii] = smemFinal.items.indices[srcIdx];
}
}
// Make sure the shared memory has been read.
__syncthreads();
// Sort the elements.
FinalSort(smemFinal.finalSort)
.SortDescendingBlockedToStriped(finalLogits, finalIndices);
// Copy the data back to the shared memory storage.
int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
#pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
int dstIdx = baseIdx + srcIdx;
if (dstIdx < kTopK) {
smemIndices[dstIdx] = finalIndices[ii];
}
}
// Make sure the data is in shared memory.
__syncthreads();
// Store to global memory.
#pragma unroll
for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
outIndices[offset] =
smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart;
}
}
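When the launcher picks useRadixSort = false, the fallback branch in topKPerRowJob ranks the threshold-bin candidates by counting rather than with CUB's block radix sort: each candidate's output slot is the number of other candidates that sort ahead of it. A host-side sketch of that rank computation (illustrative, copying the kernel's tie-break):

// A candidate sorts ahead of element i if its logit is strictly larger, or if
// the logits are equal and its index is larger.
static int stable_descending_rank(const float* logits, int n, int i) {
  int rank = 0;
  for (int j = 0; j < n; ++j) {
    if (logits[i] < logits[j] || (logits[i] == logits[j] && i < j)) {
      ++rank;
    }
  }
  return rank;
}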
template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRow(const float* logits, const int* rowStarts,
const int* rowEnds, int* outIndices,
int stride0, int stride1) {
template <int kNumThreadsPerBlock, bool useRadixSort>
static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
const float* logits, const int* rowStarts, const int* rowEnds,
int* outIndices, int stride0, int stride1, const int topK,
const int offsetIndex) {
// The number of bins in the histogram.
static constexpr int kNumBins = 512;
// The top-k width.
static constexpr int kTopK = 2048;
static constexpr int kNumBins = 2048;
// The row computed by this block.
int rowIdx = blockIdx.x;
int rowIdx = blockIdx.x + offsetIndex;
// The range of logits within the row.
int rowStart = rowStarts[rowIdx];
int rowEnd = rowEnds[rowIdx];
topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
// Local pointers to this block
outIndices += rowIdx * topK;
logits += rowIdx * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
}
template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
int* outIndices, int stride0,
int stride1, int next_n) {
template <int kNumThreadsPerBlock, bool useRadixSort,
bool multipleBlocksPerRow = false, bool mergeBlocks = false>
static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
const float* logits, const int* seqLens, int* outIndices, int stride0,
int stride1, const int topK, int next_n, float* outLogits = nullptr,
const int numBlocksToMerge = 0, const int* indices = nullptr) {
// The number of bins in the histogram.
static constexpr int kNumBins = 512;
// The top-k width.
static constexpr int kTopK = 2048;
static constexpr int kNumBins = 2048;
// The row computed by this block.
int rowIdx = blockIdx.x;
@@ -290,8 +574,25 @@ static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
int seq_len = seqLens[rowIdx / next_n];
int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
// Local pointers to this block
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
outIndices += rowIdx * topK;
} else if constexpr (multipleBlocksPerRow) {
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
} else if constexpr (mergeBlocks) {
rowEnd = numBlocksToMerge * topK;
indices += rowIdx * numBlocksToMerge * topK;
outIndices += rowIdx * topK;
}
logits += rowIdx * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
multipleBlocksPerRow, mergeBlocks>(
indices, logits, rowStart, rowEnd, outIndices, outLogits, stride1, topK);
}
} // namespace vllm
@@ -339,28 +640,84 @@ void apply_repetition_penalties_(
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seqLens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1) {
// Compute the results on the device.
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK) {
constexpr int kSortingAlgorithmThreshold = 12288;
constexpr int kSplitWorkThreshold = 200 * 1000;
constexpr int kNumThreadsPerBlock = 512;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto numColumns = logits.size(1);
if (numColumns < kSortingAlgorithmThreshold) {
// Use insertion sort
vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n));
} else if (numColumns < kSplitWorkThreshold) {
// From this threshold, use radix sort instead
vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n));
} else {
// Long sequences are run in two steps
constexpr auto multipleBlocksPerRowConfig = 10;
const auto outIndicesAux =
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
torch::dtype(torch::kInt32).device(logits.device()));
const auto outLogitsAux =
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
torch::dtype(torch::kFloat).device(logits.device()));
vllm::topKPerRowDecode<kNumThreadsPerBlock, true, true>
<<<dim3(numRows, multipleBlocksPerRowConfig), kNumThreadsPerBlock,
2 * topK * sizeof(int32_t), stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n), outLogitsAux.data_ptr<float>());
constexpr int kNumThreadsPerBlockMerge = 1024;
vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
<<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
static_cast<int>(topK), static_cast<int>(next_n), nullptr,
multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
}
}
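For rows longer than kSplitWorkThreshold columns, the decode path runs in two launches: a grid with ten blocks per row first computes a partial top-k (indices and logits) over each slice of the row, then a merge launch reruns top-k over the ten times topK surviving candidates, translating positions back through the auxiliary index tensor. A host-side sketch of the merge semantics (illustration only, assuming flattened candidate buffers):

#include <algorithm>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Keep the top_k candidates with the largest logits and return their original
// column indices; cands holds {logit, original index} pairs from all chunks.
std::vector<int> merge_partial_topk(std::vector<std::pair<float, int>> cands,
                                    std::size_t top_k) {
  top_k = std::min(top_k, cands.size());
  std::partial_sort(cands.begin(), cands.begin() + top_k, cands.end(),
                    std::greater<>());
  std::vector<int> out(top_k);
  for (std::size_t i = 0; i < top_k; ++i) out[i] = cands[i].second;
  return out;
}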
void top_k_per_row_prefill(const torch::Tensor& logits,
const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK) {
constexpr int kSortingAlgorithmThreshold = 12288;
constexpr int kNumThreadsPerBlock = 512;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::topKPerRowDecode<kNumThreadsPerBlock>
<<<numRows, kNumThreadsPerBlock, 0, stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(next_n));
}
int numInsertionBlocks =
std::min(static_cast<int>(numRows), kSortingAlgorithmThreshold);
vllm::topKPerRowPrefill<kNumThreadsPerBlock, false>
<<<numInsertionBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
static_cast<int>(stride0), static_cast<int>(stride1),
static_cast<int>(topK), 0);
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1) {
// Compute the results on the device.
constexpr int kNumThreadsPerBlock = 512;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::topKPerRow<kNumThreadsPerBlock>
<<<numRows, kNumThreadsPerBlock, 0, stream>>>(
logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
static_cast<int>(stride0), static_cast<int>(stride1));
if (numRows > kSortingAlgorithmThreshold) {
int numRadixBlocks = numRows - kSortingAlgorithmThreshold;
vllm::topKPerRowPrefill<kNumThreadsPerBlock, true>
<<<numRadixBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
static_cast<int>(stride0), static_cast<int>(stride1),
static_cast<int>(topK), kSortingAlgorithmThreshold);
}
}


@@ -179,15 +179,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Optimized top-k per row operation
ops.def(
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
"top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
"Tensor! indices, int numRows, int stride0, "
"int stride1) -> ()");
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
"int stride1, int topK) -> ()");
ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);
ops.def(
"top_k_per_row_decode(Tensor logits, int next_n, "
"Tensor seq_lens, Tensor! indices, int numRows, "
"int stride0, int stride1) -> ()");
"Tensor seq_lens, Tensor! indices, "
"int numRows, int stride0, int stride1, int topK) -> ()");
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
// Layernorm-quant


@@ -9,23 +9,45 @@ from vllm.platforms import current_platform
# Test parameters
NUM_ROWS = [1, 32, 2050]
TOP_K_VALUES = [2048]
BATCH_SIZE = [1, 2, 4, 2048, 4096]
NEXT_N = [1, 2, 4, 8]
TOP_K_VALUES = [2048, 3000]
BATCH_SIZE = [1, 2, 2048]
NEXT_N = [1, 8]
DATA_GENERATION = ["random", "10LSBits"]
def create_random_logits(
row_starts: torch.Tensor,
row_ends: torch.Tensor,
vocab_size: int,
dtype: torch.dtype,
seed: int,
data_generation: str,
) -> torch.Tensor:
"""Create random logits tensor for testing."""
torch.manual_seed(seed)
np.random.seed(seed)
# Generate logits with some structure to make testing more meaningful
logits = torch.randn(row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda")
if data_generation == "random":
logits = torch.randn(
row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda"
)
elif data_generation == "10LSBits":
top_22_bits_mask = 0xFFFFFC00
last_10_bits_mask = 0x000003FF
fixed_top_22_bits = 0x3F900000
# Generate random bits for the last 10 bits
random_bottom_bits = torch.randint(
0,
2**10,
(row_starts.shape[0], max(row_ends)),
dtype=torch.int32,
device="cuda",
)
# Combine: fixed top 22 bits with random last 10 bits
logits_bits = (fixed_top_22_bits & top_22_bits_mask) | (
random_bottom_bits & last_10_bits_mask
)
logits = logits_bits.view(dtype)
for i, end in enumerate(row_ends):
logits[i, end:] = float("-inf")
return logits
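In the "10LSBits" mode every generated logit (before the -inf padding) shares the same top 22 bits, so the histogram cannot separate values until the final 10-bit refinement; this appears intended to exercise step 3 of the kernel. A small host-side check of that property (illustration only, using two hand-picked bit patterns):

#include <cassert>
#include <cstdint>

int main() {
  // Two floats that differ only in their 10 least-significant mantissa bits.
  const uint32_t a_bits = 0x3F900000u | 0x155u;
  const uint32_t b_bits = 0x3F900000u | 0x2AAu;
  // Apply the kernel's order-flip transform for positive floats.
  const uint32_t ka = ~a_bits & 0x7fffffffu;
  const uint32_t kb = ~b_bits & 0x7fffffffu;
  // Steps 1 and 2 of extractBinIdx see identical bins; only the final
  // 10-bit step (step 3) can tell the two values apart.
  assert((ka >> 10) == (kb >> 10));
  assert((ka & 0x3ffu) != (kb & 0x3ffu));
  return 0;
}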
@@ -113,13 +135,13 @@ def test_top_k_per_row(
# Create test data
vocab_size = 20000
row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
logits = create_random_logits(row_starts, row_ends, torch.float32, 42, "random")
# Create output tensors
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
# Run CUDA implementation
torch.ops._C.top_k_per_row(
torch.ops._C.top_k_per_row_prefill(
logits,
row_starts,
row_ends,
@@ -127,6 +149,7 @@ def test_top_k_per_row(
num_rows,
logits.stride(0),
logits.stride(1),
top_k,
)
# Run reference implementation
@@ -139,27 +162,23 @@ def test_top_k_per_row(
# Compare results
assert compare_top_k_results(
logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk"
), "CUDA top_k_per_row_prefill results don't match torch.topk"
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
def _run_top_k_per_row_decode_test(
top_k: int,
batch_size: int,
next_n: int,
vocab_size: int,
data_generation: str,
) -> None:
"""
Test top_k_per_row with seq_lens tensor.
Helper function to run top_k_per_row_decode test with given parameters.
"""
torch.set_default_device("cuda:0")
# Create test data
num_rows = batch_size * next_n
vocab_size = 20000
seq_lens = torch.randint(
vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
)
@@ -167,7 +186,9 @@ def test_top_k_per_row_decode(
row_indices = torch.arange(num_rows, device="cuda") // next_n
next_n_offset = torch.arange(num_rows, device="cuda") % next_n
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
logits = create_random_logits(
row_starts, row_ends, torch.float32, 42, data_generation
)
# Create output tensors
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
@@ -181,6 +202,7 @@ def test_top_k_per_row_decode(
num_rows,
logits.stride(0),
logits.stride(1),
top_k,
)
torch.cuda.synchronize()
@@ -195,4 +217,41 @@ def test_top_k_per_row_decode(
# Compare results
assert compare_top_k_results(
logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk"
), "CUDA top_k_per_row_decode results don't match torch.topk"
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.parametrize("data_generation", DATA_GENERATION)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
top_k: int,
batch_size: int,
next_n: int,
data_generation: str,
) -> None:
"""
Test top_k_per_row_decode with seq_lens tensor.
"""
vocab_size = 20000
_run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, data_generation
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode_large_vocab_size() -> None:
"""
Test top_k_per_row_decode with large vocabulary size.
"""
top_k = 2048
batch_size = 2
next_n = 2
vocab_size = 300000
data_generation = "random"
_run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, data_generation
)


@@ -684,11 +684,10 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ke,
)
num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens
]
torch.ops._C.top_k_per_row(
torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
@@ -696,6 +695,7 @@ def sparse_attn_indexer(
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if has_decode:
@@ -738,7 +738,6 @@ def sparse_attn_indexer(
max_model_len=max_model_len,
)
num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode(
@@ -749,6 +748,7 @@ def sparse_attn_indexer(
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if decode_metadata.requires_padding:
# if padded, we need to unpack