mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 07:54:26 +08:00
[DeepSeek v3.2] Make top-k work for any logit values. (#27568)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
eb1051fb95
commit
184076c3fe
13
csrc/ops.h
13
csrc/ops.h
@ -102,13 +102,16 @@ void apply_repetition_penalties_(torch::Tensor& logits,
|
|||||||
const torch::Tensor& output_mask,
|
const torch::Tensor& output_mask,
|
||||||
const torch::Tensor& repetition_penalties);
|
const torch::Tensor& repetition_penalties);
|
||||||
|
|
||||||
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
|
void top_k_per_row_prefill(const torch::Tensor& logits,
|
||||||
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
const torch::Tensor& rowStarts,
|
||||||
int64_t numRows, int64_t stride0, int64_t stride1);
|
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
||||||
|
int64_t numRows, int64_t stride0, int64_t stride1,
|
||||||
|
int64_t topK);
|
||||||
|
|
||||||
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||||
const torch::Tensor& seq_lens, torch::Tensor& indices,
|
const torch::Tensor& seqLens, torch::Tensor& indices,
|
||||||
int64_t numRows, int64_t stride0, int64_t stride1);
|
int64_t numRows, int64_t stride0, int64_t stride1,
|
||||||
|
int64_t topK);
|
||||||
|
|
||||||
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||||
torch::Tensor& weight, torch::Tensor& scale,
|
torch::Tensor& weight, torch::Tensor& scale,
|
||||||
|
|||||||
743
csrc/sampler.cu
743
csrc/sampler.cu
@ -44,41 +44,300 @@ __global__ void apply_repetition_penalties_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline __device__ uint16_t extractBinIdx(float x) {
|
__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
|
||||||
union {
|
uint32_t bits = __float_as_uint(x);
|
||||||
__half h;
|
return (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
|
||||||
uint16_t u16;
|
|
||||||
} tmp;
|
|
||||||
tmp.h = __float2half_rn(x);
|
|
||||||
tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000);
|
|
||||||
return 511 - (tmp.u16 >> 7);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int kNumThreadsPerBlock = 512, int kNumBins = 512, int kTopK = 2048>
|
template <int step>
|
||||||
__device__ void topKPerRowJob(const float* logits, const int rowStart,
|
static inline __device__ uint32_t extractBinIdx(float x) {
|
||||||
const int rowEnd, const int rowIdx,
|
if constexpr (step == 0) {
|
||||||
int* outIndices, int stride0, int stride1) {
|
__half hx = __float2half(x);
|
||||||
// The number of elements per thread for the final top-k sort.
|
uint16_t bits = __half_as_ushort(hx);
|
||||||
static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
|
bits = (bits & 0x8000) ? bits : ~bits & 0x7fff;
|
||||||
// The class to sort the elements during the final top-k sort.
|
return bits >> 5;
|
||||||
using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
|
} else {
|
||||||
kNumTopKItemsPerThread, int>;
|
uint32_t bits = __float_as_uint(x);
|
||||||
|
bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
|
||||||
|
|
||||||
|
if constexpr (step == 1) {
|
||||||
|
return bits >> 21;
|
||||||
|
} else if constexpr (step == 2) {
|
||||||
|
return (bits >> 10) & 0x7ff;
|
||||||
|
} else if constexpr (step == 3) {
|
||||||
|
return bits & 0x3ff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int shift>
|
||||||
|
static inline __device__ bool isPartialMatch(float x, uint32_t pattern) {
|
||||||
|
if constexpr (shift == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
uint32_t bits = __float_as_uint(x);
|
||||||
|
bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
|
||||||
|
return (bits ^ pattern) >> shift == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map a Func over the input data, using vectorized load instructions if
|
||||||
|
* possible.
|
||||||
|
*
|
||||||
|
* @tparam T element type
|
||||||
|
* @tparam IdxT indexing type
|
||||||
|
* @tparam Func void (T x, IdxT idx)
|
||||||
|
*
|
||||||
|
* @param thread_rank rank of the calling thread among all participating threads
|
||||||
|
* @param num_threads number of the threads that participate in processing
|
||||||
|
* @param in the input data
|
||||||
|
* @param len the number of elements to read
|
||||||
|
* @param f the lambda taking two arguments (T x, IdxT idx)
|
||||||
|
*/
|
||||||
|
template <typename T, typename idxT, typename Func>
|
||||||
|
__device__ void vectorized_process(size_t thread_rank, size_t num_threads,
|
||||||
|
const T* in, idxT len, Func f) {
|
||||||
|
constexpr int WARP_SIZE = 32;
|
||||||
|
using WideT = float4;
|
||||||
|
if constexpr (sizeof(T) >= sizeof(WideT)) {
|
||||||
|
for (idxT i = thread_rank; i < len; i += num_threads) {
|
||||||
|
f(in[i], i);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static_assert(sizeof(WideT) % sizeof(T) == 0);
|
||||||
|
constexpr int items_per_scalar = sizeof(WideT) / sizeof(T);
|
||||||
|
// TODO: it's UB
|
||||||
|
union {
|
||||||
|
WideT scalar;
|
||||||
|
T array[items_per_scalar];
|
||||||
|
} wide;
|
||||||
|
|
||||||
|
int skip_cnt =
|
||||||
|
(reinterpret_cast<size_t>(in) % sizeof(WideT))
|
||||||
|
? ((sizeof(WideT) - reinterpret_cast<size_t>(in) % sizeof(WideT)) /
|
||||||
|
sizeof(T))
|
||||||
|
: 0;
|
||||||
|
if (skip_cnt > len) {
|
||||||
|
skip_cnt = len;
|
||||||
|
}
|
||||||
|
const WideT* in_cast = reinterpret_cast<decltype(in_cast)>(in + skip_cnt);
|
||||||
|
const idxT len_cast = (len - skip_cnt) / items_per_scalar;
|
||||||
|
|
||||||
|
for (idxT i = thread_rank; i < len_cast; i += num_threads) {
|
||||||
|
wide.scalar = in_cast[i];
|
||||||
|
const idxT real_i = skip_cnt + i * items_per_scalar;
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < items_per_scalar; ++j) {
|
||||||
|
f(wide.array[j], real_i + j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static_assert(WARP_SIZE >= items_per_scalar);
|
||||||
|
// and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt
|
||||||
|
// no need to use loop
|
||||||
|
if (thread_rank < skip_cnt) {
|
||||||
|
f(in[thread_rank], thread_rank);
|
||||||
|
}
|
||||||
|
// because len_cast = (len - skip_cnt) / items_per_scalar,
|
||||||
|
// len_cast * items_per_scalar + items_per_scalar > len - skip_cnt;
|
||||||
|
// and so
|
||||||
|
// len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <=
|
||||||
|
// WARP_SIZE no need to use loop
|
||||||
|
const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
|
||||||
|
if (remain_i < len) {
|
||||||
|
f(in[remain_i], remain_i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int step, int kNumThreadsPerBlock, int kNumBins, int kNumFinalItems,
|
||||||
|
bool multipleBlocksPerRow, bool mergeBlocks, typename SmemFinalType,
|
||||||
|
typename SmemOutputType>
|
||||||
|
__device__ bool processHistogramStep(
|
||||||
|
const int* indices, const float* logits, int rowEnd, uint32_t& logitPattern,
|
||||||
|
int& thresholdBinIdx, SmemOutputType& smemOutput, int* smemThresholdBinIdx,
|
||||||
|
int* smemFinalDstIdx, int* smemFinalBinSize, int* smemFoundTopKValues,
|
||||||
|
SmemFinalType& smemFinal, int stride1, int rowStart, int topK) {
|
||||||
|
// Clear the histogram.
|
||||||
|
#pragma unroll
|
||||||
|
for (int idx = threadIdx.x; idx < kNumBins; idx += kNumThreadsPerBlock) {
|
||||||
|
smemFinal.histo.data[idx] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the histogram is ready.
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Update pattern
|
||||||
|
constexpr auto patternShift = step < 2 ? 0 : step == 2 ? 21 : 10;
|
||||||
|
if constexpr (step == 2) {
|
||||||
|
logitPattern = static_cast<uint32_t>(thresholdBinIdx & 0x7ff)
|
||||||
|
<< patternShift;
|
||||||
|
} else if constexpr (step == 3) {
|
||||||
|
logitPattern |= static_cast<uint32_t>(thresholdBinIdx & 0x7ff)
|
||||||
|
<< patternShift;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto distributeToBins = [&](float logit, int /* idx */ = 0) {
|
||||||
|
if (isPartialMatch<patternShift>(logit, logitPattern)) {
|
||||||
|
uint32_t binIdx = extractBinIdx<step>(logit);
|
||||||
|
atomicAdd(&smemFinal.histo.data[binIdx], 1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Distribute the elements to the histogram bins.
|
||||||
|
if (stride1 == 1) {
|
||||||
|
vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart,
|
||||||
|
rowEnd - rowStart, distributeToBins);
|
||||||
|
} else {
|
||||||
|
for (int idx = rowStart + threadIdx.x; idx < rowEnd;
|
||||||
|
idx += kNumThreadsPerBlock) {
|
||||||
|
float logit = logits[idx * stride1];
|
||||||
|
distributeToBins(logit, idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Make sure the histogram is ready.
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Reads the value of the starting position in the smemOutput array
|
||||||
|
int lastValue = smemFoundTopKValues[0];
|
||||||
|
|
||||||
|
for (int round = 0; round < kNumBins / kNumThreadsPerBlock; round++) {
|
||||||
|
// Read the values from SMEM.
|
||||||
|
int idx = threadIdx.x + kNumThreadsPerBlock * round;
|
||||||
|
int binCount{0};
|
||||||
|
binCount = smemFinal.histo.data[idx];
|
||||||
|
|
||||||
|
// Make sure each thread has read its value.
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Compute the prefix sum.
|
||||||
|
int prefixSum{0}, totalSum{0};
|
||||||
|
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
|
||||||
|
Scan(smemFinal.histo.scan).ExclusiveSum(binCount, prefixSum, totalSum);
|
||||||
|
|
||||||
|
// Update the histogram with the prefix sums.
|
||||||
|
prefixSum += lastValue;
|
||||||
|
totalSum += lastValue;
|
||||||
|
smemFinal.histo.data[idx] = prefixSum;
|
||||||
|
|
||||||
|
// Make sure the data is in shared memory.
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Find the last valid bin.
|
||||||
|
bool foundThreshold = false;
|
||||||
|
if (prefixSum < topK) {
|
||||||
|
int nextPrefixSum = threadIdx.x == kNumThreadsPerBlock - 1
|
||||||
|
? totalSum
|
||||||
|
: smemFinal.histo.data[idx + 1];
|
||||||
|
|
||||||
|
if (nextPrefixSum >= topK) {
|
||||||
|
smemThresholdBinIdx[0] = idx;
|
||||||
|
smemFinalBinSize[0] = nextPrefixSum - prefixSum;
|
||||||
|
foundThreshold = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Early exit: if any thread found the threshold, we can skip remaining
|
||||||
|
// rounds
|
||||||
|
if (__syncthreads_or(foundThreshold)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
lastValue = totalSum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the data is in shared memory.
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// The threshold bin.
|
||||||
|
thresholdBinIdx = smemThresholdBinIdx[0];
|
||||||
|
|
||||||
|
auto processBins = [&](float logit, int idx) {
|
||||||
|
if (isPartialMatch<patternShift>(logit, logitPattern)) {
|
||||||
|
uint32_t binIdx = extractBinIdx<step>(logit);
|
||||||
|
if (binIdx < thresholdBinIdx) {
|
||||||
|
// The element is part of the top-k selection
|
||||||
|
int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1);
|
||||||
|
|
||||||
|
if constexpr (mergeBlocks) {
|
||||||
|
smemOutput[dstIdx] = indices[idx];
|
||||||
|
} else if constexpr (multipleBlocksPerRow) {
|
||||||
|
smemOutput[dstIdx] = idx + rowStart;
|
||||||
|
reinterpret_cast<float*>(smemOutput + topK)[dstIdx] = logit;
|
||||||
|
} else {
|
||||||
|
smemOutput[dstIdx] = idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if constexpr (step < 3) {
|
||||||
|
// Only fill the final items for sorting if the threshold bin fits
|
||||||
|
if (binIdx == thresholdBinIdx &&
|
||||||
|
smemFinalBinSize[0] <= kNumFinalItems) {
|
||||||
|
int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
|
||||||
|
smemFinal.items.logits[dstIdx] = logit;
|
||||||
|
if constexpr (mergeBlocks) {
|
||||||
|
smemFinal.items.indices[dstIdx] = indices[idx];
|
||||||
|
} else if constexpr (multipleBlocksPerRow) {
|
||||||
|
smemFinal.items.indices[dstIdx] = idx + rowStart;
|
||||||
|
} else {
|
||||||
|
smemFinal.items.indices[dstIdx] = idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (binIdx == thresholdBinIdx) {
|
||||||
|
// The elements in the threshold bin share the same 32 bits at step 3
|
||||||
|
int dstIdx = atomicAdd(&smemFinal.histo.data[binIdx], 1);
|
||||||
|
if (dstIdx < topK) {
|
||||||
|
if constexpr (mergeBlocks) {
|
||||||
|
smemOutput[dstIdx] = indices[idx];
|
||||||
|
} else if constexpr (multipleBlocksPerRow) {
|
||||||
|
smemOutput[dstIdx] = idx + rowStart;
|
||||||
|
reinterpret_cast<float*>(smemOutput + topK)[dstIdx] = logit;
|
||||||
|
} else {
|
||||||
|
smemOutput[dstIdx] = idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (stride1 == 1) {
|
||||||
|
vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart,
|
||||||
|
rowEnd - rowStart, processBins);
|
||||||
|
} else {
|
||||||
|
for (int idx = rowStart + threadIdx.x; idx < rowEnd;
|
||||||
|
idx += kNumThreadsPerBlock) {
|
||||||
|
float logit = logits[idx * stride1];
|
||||||
|
processBins(logit, idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the elements are in shared memory.
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Check if we should continue to next step
|
||||||
|
return smemFinalBinSize[0] > kNumFinalItems;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Follows half - 11 - 11 - 10 bit iterations
|
||||||
|
template <int kNumThreadsPerBlock, int kNumBins, bool useRadixSort,
|
||||||
|
bool multipleBlocksPerRow = false, bool mergeBlocks = false>
|
||||||
|
static __device__ void topKPerRowJob(const int* indices, const float* logits,
|
||||||
|
int rowStart, int rowEnd, int* outIndices,
|
||||||
|
float* outLogits, int stride1, int topK) {
|
||||||
// The number of slots for the final pass.
|
// The number of slots for the final pass.
|
||||||
static constexpr int kNumFinalItems = 3072;
|
static constexpr int kNumFinalItems = 2048;
|
||||||
// The number of elements per thread for the final sort.
|
// The number of elements per thread for the final sort.
|
||||||
static constexpr int kNumFinalItemsPerThread =
|
static constexpr int kNumFinalItemsPerThread =
|
||||||
kNumFinalItems / kNumThreadsPerBlock;
|
kNumFinalItems / kNumThreadsPerBlock;
|
||||||
// The class to sort the elements during the final pass.
|
// The class to sort the elements during the final pass.
|
||||||
using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
|
using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
|
||||||
kNumFinalItemsPerThread, int>;
|
kNumFinalItemsPerThread, int>;
|
||||||
|
using FinalSortTempStorage =
|
||||||
|
std::conditional_t<useRadixSort, typename FinalSort::TempStorage, int>;
|
||||||
// The class to compute the inclusive prefix-sum over the histogram.
|
// The class to compute the inclusive prefix-sum over the histogram.
|
||||||
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
|
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
|
||||||
|
|
||||||
// Shared memory to compute the block scan.
|
|
||||||
__shared__ typename Scan::TempStorage smemScan;
|
|
||||||
|
|
||||||
// The structure to store the final items (for the final pass).
|
// The structure to store the final items (for the final pass).
|
||||||
struct FinalItems {
|
struct FinalItems {
|
||||||
// Shared memory to store the indices for the final pass.
|
// Shared memory to store the indices for the final pass.
|
||||||
@ -87,200 +346,225 @@ __device__ void topKPerRowJob(const float* logits, const int rowStart,
|
|||||||
float logits[kNumFinalItems];
|
float logits[kNumFinalItems];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct Histogram {
|
||||||
|
typename Scan::TempStorage scan;
|
||||||
|
int data[kNumBins];
|
||||||
|
};
|
||||||
|
|
||||||
// Shared memory to compute the block sort.
|
// Shared memory to compute the block sort.
|
||||||
__shared__ union {
|
__shared__ union {
|
||||||
FinalItems items;
|
FinalItems items;
|
||||||
typename FinalSort::TempStorage finalSort;
|
FinalSortTempStorage finalSort;
|
||||||
typename TopKSort::TempStorage topKSort;
|
Histogram histo;
|
||||||
} smemFinal;
|
} smemFinal;
|
||||||
|
|
||||||
// Shared memory to store the histogram.
|
|
||||||
__shared__ int smemHistogram[kNumBins];
|
|
||||||
// Shared memory to store the selected indices.
|
// Shared memory to store the selected indices.
|
||||||
__shared__ int smemIndices[kTopK];
|
// If we are processing using multiple blocks, we need to store the logits and
|
||||||
|
// indices.
|
||||||
|
extern __shared__ int32_t smemOutput[];
|
||||||
|
|
||||||
// Shared memory to store the threshold bin.
|
// Shared memory to store the threshold bin.
|
||||||
__shared__ int smemThresholdBinIdx[1];
|
__shared__ int smemThresholdBinIdx[1];
|
||||||
// Shared memory counter to register the candidates for the final phase.
|
// Shared memory counter to register the candidates for the final phase.
|
||||||
__shared__ int smemFinalDstIdx[1];
|
__shared__ int smemFinalDstIdx[1];
|
||||||
|
// Shared memory to determine if the threshold bin fits in the final items.
|
||||||
|
__shared__ int smemFinalBinSize[1];
|
||||||
|
// Shared memory to keep track of the top-k values found so far by the
|
||||||
|
// previous iterations
|
||||||
|
__shared__ int smemFoundTopKValues[1];
|
||||||
|
|
||||||
// The length of the row.
|
// The length of the row.
|
||||||
int rowLen = rowEnd - rowStart;
|
int rowLen = rowEnd - rowStart;
|
||||||
|
|
||||||
// Shortcut if the length of the row is smaller than Top-K. Indices are not
|
// Shortcut if the length of the row is smaller than Top-K. Indices are not
|
||||||
// sorted by their corresponding logit.
|
// sorted by their corresponding logit.
|
||||||
if (rowLen <= kTopK) {
|
if (rowLen <= topK) {
|
||||||
for (int rowIt = threadIdx.x; rowIt < rowLen;
|
for (int rowIt = threadIdx.x; rowIt < rowLen;
|
||||||
rowIt += kNumThreadsPerBlock) {
|
rowIt += kNumThreadsPerBlock) {
|
||||||
int idx = rowStart + rowIt;
|
if constexpr (multipleBlocksPerRow) {
|
||||||
outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
|
outIndices[rowIt] = rowIt + rowStart;
|
||||||
|
outLogits[rowIt] = logits[rowIt + rowStart];
|
||||||
|
} else {
|
||||||
|
outIndices[rowIt] = rowIt;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
|
for (int rowIt = rowLen + threadIdx.x; rowIt < topK;
|
||||||
rowIt += kNumThreadsPerBlock) {
|
rowIt += kNumThreadsPerBlock) {
|
||||||
outIndices[rowIdx * kTopK + rowIt] = -1;
|
outIndices[rowIt] = -1;
|
||||||
|
if constexpr (multipleBlocksPerRow) {
|
||||||
|
outLogits[rowIt] = -FLT_MAX;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// Initialize values
|
||||||
// Clear the histogram.
|
|
||||||
if (threadIdx.x < kNumBins) {
|
|
||||||
smemHistogram[threadIdx.x] = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure the histogram is ready.
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// Fetch elements one-by-one.
|
|
||||||
for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
|
|
||||||
rowIt += kNumThreadsPerBlock) {
|
|
||||||
uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]);
|
|
||||||
atomicAdd(&smemHistogram[idx], 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure the histogram is ready.
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// Read the values from SMEM.
|
|
||||||
int binCount{0};
|
|
||||||
if (threadIdx.x < kNumBins) {
|
|
||||||
binCount = smemHistogram[threadIdx.x];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure each thread has read its value.
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// Compute the prefix sum.
|
|
||||||
int prefixSum{0}, totalSum{0};
|
|
||||||
Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);
|
|
||||||
|
|
||||||
// Update the histogram with the prefix sums.
|
|
||||||
if (threadIdx.x < kNumBins) {
|
|
||||||
smemHistogram[threadIdx.x] = prefixSum;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure the data is in shared memory.
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// Find the last valid bin.
|
|
||||||
if (threadIdx.x < kNumBins) {
|
|
||||||
int nextPrefixSum =
|
|
||||||
threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
|
|
||||||
if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
|
|
||||||
smemThresholdBinIdx[0] = threadIdx.x;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear the counter to store the items for the final phase.
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
smemFinalDstIdx[0] = 0;
|
smemFinalDstIdx[0] = 0;
|
||||||
|
smemFoundTopKValues[0] = 0;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
int thresholdBinIdx = -1;
|
||||||
|
uint32_t logitPattern = 0;
|
||||||
|
|
||||||
|
// Step 0: Process first 11 bits of half representation
|
||||||
|
bool continueToNextStep =
|
||||||
|
processHistogramStep<0, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
|
||||||
|
multipleBlocksPerRow, mergeBlocks>(
|
||||||
|
indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
|
||||||
|
smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
|
||||||
|
smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
|
||||||
|
|
||||||
|
if (continueToNextStep) {
|
||||||
|
// Step 1: Process next 11 bits
|
||||||
|
continueToNextStep =
|
||||||
|
processHistogramStep<1, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
|
||||||
|
multipleBlocksPerRow, mergeBlocks>(
|
||||||
|
indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
|
||||||
|
smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
|
||||||
|
smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure the data is in shared memory.
|
if (continueToNextStep) {
|
||||||
__syncthreads();
|
// Step 2: Process next 11 bits
|
||||||
|
continueToNextStep =
|
||||||
|
processHistogramStep<2, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
|
||||||
|
multipleBlocksPerRow, mergeBlocks>(
|
||||||
|
indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
|
||||||
|
smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
|
||||||
|
smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
|
||||||
|
}
|
||||||
|
|
||||||
// The threshold bin.
|
if (continueToNextStep) {
|
||||||
int thresholdBinIdx = smemThresholdBinIdx[0];
|
// Step 3: Process last 10 bits
|
||||||
|
processHistogramStep<3, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
|
||||||
|
multipleBlocksPerRow, mergeBlocks>(
|
||||||
|
indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
|
||||||
|
smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
|
||||||
|
smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
|
||||||
|
}
|
||||||
|
|
||||||
// Fetch elements one-by-one and populate the shared memory buffers.
|
if (!continueToNextStep) {
|
||||||
for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
|
// The histogram did not proceed to the final 10 bits, therefore we need to
|
||||||
rowIt += kNumThreadsPerBlock) {
|
// sort the final items The logits of the elements to be sorted in the final
|
||||||
float logit = logits[rowIdx * stride0 + rowIt * stride1];
|
// pass.
|
||||||
uint16_t idx = extractBinIdx(logit);
|
if constexpr (useRadixSort) {
|
||||||
if (idx < thresholdBinIdx) {
|
// Sorting with radix sort
|
||||||
int dstIdx = atomicAdd(&smemHistogram[idx], 1);
|
float finalLogits[kNumFinalItemsPerThread];
|
||||||
smemIndices[dstIdx] = rowIt;
|
// The indices of the elements to be sorted in the final pass.
|
||||||
} else if (idx == thresholdBinIdx) {
|
int finalIndices[kNumFinalItemsPerThread];
|
||||||
int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
|
|
||||||
if (dstIdx < kNumFinalItems) {
|
#pragma unroll
|
||||||
smemFinal.items.logits[dstIdx] = logit;
|
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
|
||||||
smemFinal.items.indices[dstIdx] = rowIt;
|
finalLogits[ii] = -FLT_MAX;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the elements from SMEM.
|
||||||
|
#pragma unroll
|
||||||
|
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
|
||||||
|
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
|
||||||
|
if (srcIdx < smemFinalDstIdx[0]) {
|
||||||
|
finalLogits[ii] = smemFinal.items.logits[srcIdx];
|
||||||
|
finalIndices[ii] = smemFinal.items.indices[srcIdx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Make sure the shared memory has been read.
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Sort the elements.
|
||||||
|
FinalSort(smemFinal.finalSort)
|
||||||
|
.SortDescendingBlockedToStriped(finalLogits, finalIndices);
|
||||||
|
|
||||||
|
// Copy the data back to the shared memory storage.
|
||||||
|
int baseIdx = smemFoundTopKValues[0];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
|
||||||
|
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
|
||||||
|
int dstIdx = baseIdx + srcIdx;
|
||||||
|
|
||||||
|
if (dstIdx < topK) {
|
||||||
|
smemOutput[dstIdx] = finalIndices[ii];
|
||||||
|
if constexpr (multipleBlocksPerRow) {
|
||||||
|
reinterpret_cast<float*>(smemOutput + topK)[dstIdx] =
|
||||||
|
finalLogits[ii];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Sorting with insertion sort
|
||||||
|
auto baseIdx = smemFoundTopKValues[0];
|
||||||
|
for (int i = threadIdx.x; i < smemFinalDstIdx[0];
|
||||||
|
i += kNumThreadsPerBlock) {
|
||||||
|
int outIndex = 0;
|
||||||
|
auto logit = smemFinal.items.logits[i];
|
||||||
|
for (int j = 0; j < smemFinalDstIdx[0]; j++) {
|
||||||
|
auto otherLogit = smemFinal.items.logits[j];
|
||||||
|
if (logit < otherLogit || (logit == otherLogit && i < j)) {
|
||||||
|
outIndex++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Store if outIndex is in bounds
|
||||||
|
if (outIndex + baseIdx < topK) {
|
||||||
|
smemOutput[outIndex + baseIdx] = smemFinal.items.indices[i];
|
||||||
|
if constexpr (multipleBlocksPerRow) {
|
||||||
|
reinterpret_cast<float*>(smemOutput + topK)[outIndex + baseIdx] =
|
||||||
|
smemFinal.items.logits[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store to global memory.
|
||||||
|
for (int i = threadIdx.x; i < topK; i += kNumThreadsPerBlock) {
|
||||||
|
if constexpr (multipleBlocksPerRow) {
|
||||||
|
outIndices[i] = smemOutput[i];
|
||||||
|
outLogits[i] = reinterpret_cast<float*>(smemOutput + topK)[i];
|
||||||
|
} else {
|
||||||
|
if (stride1 == 1) {
|
||||||
|
// stride1 == 1 will use vectorized_process, which indexes already skip
|
||||||
|
// the rowStart.
|
||||||
|
outIndices[i] = smemOutput[i];
|
||||||
|
} else {
|
||||||
|
outIndices[i] = smemOutput[i] - rowStart;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure the elements are in shared memory.
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// The logits of the elements to be sorted in the final pass.
|
|
||||||
float finalLogits[kNumFinalItemsPerThread];
|
|
||||||
// The indices of the elements to be sorted in the final pass.
|
|
||||||
int finalIndices[kNumFinalItemsPerThread];
|
|
||||||
|
|
||||||
// Init.
|
|
||||||
#pragma unroll
|
|
||||||
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
|
|
||||||
finalLogits[ii] = -FLT_MAX;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the elements from SMEM.
|
|
||||||
#pragma unroll
|
|
||||||
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
|
|
||||||
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
|
|
||||||
if (srcIdx < smemFinalDstIdx[0]) {
|
|
||||||
finalLogits[ii] = smemFinal.items.logits[srcIdx];
|
|
||||||
finalIndices[ii] = smemFinal.items.indices[srcIdx];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure the shared memory has been read.
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// Sort the elements.
|
|
||||||
FinalSort(smemFinal.finalSort)
|
|
||||||
.SortDescendingBlockedToStriped(finalLogits, finalIndices);
|
|
||||||
|
|
||||||
// Copy the data back to the shared memory storage.
|
|
||||||
int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
|
|
||||||
#pragma unroll
|
|
||||||
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
|
|
||||||
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
|
|
||||||
int dstIdx = baseIdx + srcIdx;
|
|
||||||
if (dstIdx < kTopK) {
|
|
||||||
smemIndices[dstIdx] = finalIndices[ii];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure the data is in shared memory.
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// Store to global memory.
|
|
||||||
#pragma unroll
|
|
||||||
for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
|
|
||||||
int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
|
|
||||||
outIndices[offset] =
|
|
||||||
smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int kNumThreadsPerBlock = 512>
|
template <int kNumThreadsPerBlock, bool useRadixSort>
|
||||||
static __global__ void topKPerRow(const float* logits, const int* rowStarts,
|
static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
|
||||||
const int* rowEnds, int* outIndices,
|
const float* logits, const int* rowStarts, const int* rowEnds,
|
||||||
int stride0, int stride1) {
|
int* outIndices, int stride0, int stride1, const int topK,
|
||||||
|
const int offsetIndex) {
|
||||||
// The number of bins in the histogram.
|
// The number of bins in the histogram.
|
||||||
static constexpr int kNumBins = 512;
|
static constexpr int kNumBins = 2048;
|
||||||
|
|
||||||
// The top-k width.
|
|
||||||
static constexpr int kTopK = 2048;
|
|
||||||
|
|
||||||
// The row computed by this block.
|
// The row computed by this block.
|
||||||
int rowIdx = blockIdx.x;
|
int rowIdx = blockIdx.x + offsetIndex;
|
||||||
|
|
||||||
// The range of logits within the row.
|
// The range of logits within the row.
|
||||||
int rowStart = rowStarts[rowIdx];
|
int rowStart = rowStarts[rowIdx];
|
||||||
int rowEnd = rowEnds[rowIdx];
|
int rowEnd = rowEnds[rowIdx];
|
||||||
|
|
||||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
|
// Local pointers to this block
|
||||||
logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
|
outIndices += rowIdx * topK;
|
||||||
|
logits += rowIdx * stride0;
|
||||||
|
|
||||||
|
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
|
||||||
|
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int kNumThreadsPerBlock = 512>
|
template <int kNumThreadsPerBlock, bool useRadixSort,
|
||||||
static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
|
bool multipleBlocksPerRow = false, bool mergeBlocks = false>
|
||||||
int* outIndices, int stride0,
|
static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
||||||
int stride1, int next_n) {
|
const float* logits, const int* seqLens, int* outIndices, int stride0,
|
||||||
|
int stride1, const int topK, int next_n, float* outLogits = nullptr,
|
||||||
|
const int numBlocksToMerge = 0, const int* indices = nullptr) {
|
||||||
// The number of bins in the histogram.
|
// The number of bins in the histogram.
|
||||||
static constexpr int kNumBins = 512;
|
static constexpr int kNumBins = 2048;
|
||||||
|
|
||||||
// The top-k width.
|
|
||||||
static constexpr int kTopK = 2048;
|
|
||||||
|
|
||||||
// The row computed by this block.
|
// The row computed by this block.
|
||||||
int rowIdx = blockIdx.x;
|
int rowIdx = blockIdx.x;
|
||||||
@ -290,8 +574,25 @@ static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
|
|||||||
int seq_len = seqLens[rowIdx / next_n];
|
int seq_len = seqLens[rowIdx / next_n];
|
||||||
int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
|
int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
|
||||||
|
|
||||||
topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
|
// Local pointers to this block
|
||||||
logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
|
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
|
||||||
|
outIndices += rowIdx * topK;
|
||||||
|
} else if constexpr (multipleBlocksPerRow) {
|
||||||
|
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
|
||||||
|
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
|
||||||
|
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
|
||||||
|
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||||
|
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
|
||||||
|
} else if constexpr (mergeBlocks) {
|
||||||
|
rowEnd = numBlocksToMerge * topK;
|
||||||
|
indices += rowIdx * numBlocksToMerge * topK;
|
||||||
|
outIndices += rowIdx * topK;
|
||||||
|
}
|
||||||
|
logits += rowIdx * stride0;
|
||||||
|
|
||||||
|
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
|
||||||
|
multipleBlocksPerRow, mergeBlocks>(
|
||||||
|
indices, logits, rowStart, rowEnd, outIndices, outLogits, stride1, topK);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
@ -339,28 +640,84 @@ void apply_repetition_penalties_(
|
|||||||
|
|
||||||
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||||
const torch::Tensor& seqLens, torch::Tensor& indices,
|
const torch::Tensor& seqLens, torch::Tensor& indices,
|
||||||
int64_t numRows, int64_t stride0, int64_t stride1) {
|
int64_t numRows, int64_t stride0, int64_t stride1,
|
||||||
// Compute the results on the device.
|
int64_t topK) {
|
||||||
|
constexpr int kSortingAlgorithmThreshold = 12288;
|
||||||
|
constexpr int kSplitWorkThreshold = 200 * 1000;
|
||||||
|
constexpr int kNumThreadsPerBlock = 512;
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
const auto numColumns = logits.size(1);
|
||||||
|
|
||||||
|
if (numColumns < kSortingAlgorithmThreshold) {
|
||||||
|
// Use insertion sort
|
||||||
|
vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
|
||||||
|
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
|
||||||
|
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||||
|
indices.data_ptr<int>(), static_cast<int>(stride0),
|
||||||
|
static_cast<int>(stride1), static_cast<int>(topK),
|
||||||
|
static_cast<int>(next_n));
|
||||||
|
} else if (numColumns < kSplitWorkThreshold) {
|
||||||
|
// From this threshold, use radix sort instead
|
||||||
|
vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
|
||||||
|
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
|
||||||
|
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||||
|
indices.data_ptr<int>(), static_cast<int>(stride0),
|
||||||
|
static_cast<int>(stride1), static_cast<int>(topK),
|
||||||
|
static_cast<int>(next_n));
|
||||||
|
} else {
|
||||||
|
// Long sequences are run in two steps
|
||||||
|
constexpr auto multipleBlocksPerRowConfig = 10;
|
||||||
|
|
||||||
|
const auto outIndicesAux =
|
||||||
|
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
|
||||||
|
torch::dtype(torch::kInt32).device(logits.device()));
|
||||||
|
const auto outLogitsAux =
|
||||||
|
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
|
||||||
|
torch::dtype(torch::kFloat).device(logits.device()));
|
||||||
|
|
||||||
|
vllm::topKPerRowDecode<kNumThreadsPerBlock, true, true>
|
||||||
|
<<<dim3(numRows, multipleBlocksPerRowConfig), kNumThreadsPerBlock,
|
||||||
|
2 * topK * sizeof(int32_t), stream>>>(
|
||||||
|
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||||
|
outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
|
||||||
|
static_cast<int>(stride1), static_cast<int>(topK),
|
||||||
|
static_cast<int>(next_n), outLogitsAux.data_ptr<float>());
|
||||||
|
|
||||||
|
constexpr int kNumThreadsPerBlockMerge = 1024;
|
||||||
|
vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
|
||||||
|
<<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
|
||||||
|
outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||||
|
indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
|
||||||
|
static_cast<int>(topK), static_cast<int>(next_n), nullptr,
|
||||||
|
multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void top_k_per_row_prefill(const torch::Tensor& logits,
|
||||||
|
const torch::Tensor& rowStarts,
|
||||||
|
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
||||||
|
int64_t numRows, int64_t stride0, int64_t stride1,
|
||||||
|
int64_t topK) {
|
||||||
|
constexpr int kSortingAlgorithmThreshold = 12288;
|
||||||
constexpr int kNumThreadsPerBlock = 512;
|
constexpr int kNumThreadsPerBlock = 512;
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
vllm::topKPerRowDecode<kNumThreadsPerBlock>
|
int numInsertionBlocks =
|
||||||
<<<numRows, kNumThreadsPerBlock, 0, stream>>>(
|
std::min(static_cast<int>(numRows), kSortingAlgorithmThreshold);
|
||||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
vllm::topKPerRowPrefill<kNumThreadsPerBlock, false>
|
||||||
indices.data_ptr<int>(), static_cast<int>(stride0),
|
<<<numInsertionBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
|
||||||
static_cast<int>(stride1), static_cast<int>(next_n));
|
stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
|
||||||
}
|
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
|
||||||
|
static_cast<int>(stride0), static_cast<int>(stride1),
|
||||||
|
static_cast<int>(topK), 0);
|
||||||
|
|
||||||
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
|
if (numRows > kSortingAlgorithmThreshold) {
|
||||||
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
int numRadixBlocks = numRows - kSortingAlgorithmThreshold;
|
||||||
int64_t numRows, int64_t stride0, int64_t stride1) {
|
vllm::topKPerRowPrefill<kNumThreadsPerBlock, true>
|
||||||
// Compute the results on the device.
|
<<<numRadixBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
|
||||||
constexpr int kNumThreadsPerBlock = 512;
|
stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
|
||||||
|
static_cast<int>(stride0), static_cast<int>(stride1),
|
||||||
vllm::topKPerRow<kNumThreadsPerBlock>
|
static_cast<int>(topK), kSortingAlgorithmThreshold);
|
||||||
<<<numRows, kNumThreadsPerBlock, 0, stream>>>(
|
}
|
||||||
logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
|
|
||||||
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
|
|
||||||
static_cast<int>(stride0), static_cast<int>(stride1));
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -179,15 +179,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
|
|
||||||
// Optimized top-k per row operation
|
// Optimized top-k per row operation
|
||||||
ops.def(
|
ops.def(
|
||||||
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
"top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
||||||
"Tensor! indices, int numRows, int stride0, "
|
"Tensor! indices, int numRows, int stride0, "
|
||||||
"int stride1) -> ()");
|
"int stride1, int topK) -> ()");
|
||||||
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
|
ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);
|
||||||
|
|
||||||
ops.def(
|
ops.def(
|
||||||
"top_k_per_row_decode(Tensor logits, int next_n, "
|
"top_k_per_row_decode(Tensor logits, int next_n, "
|
||||||
"Tensor seq_lens, Tensor! indices, int numRows, "
|
"Tensor seq_lens, Tensor! indices, "
|
||||||
"int stride0, int stride1) -> ()");
|
"int numRows, int stride0, int stride1, int topK) -> ()");
|
||||||
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
|
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
|
||||||
|
|
||||||
// Layernorm-quant
|
// Layernorm-quant
|
||||||
|
|||||||
@ -9,23 +9,45 @@ from vllm.platforms import current_platform
|
|||||||
|
|
||||||
# Test parameters
|
# Test parameters
|
||||||
NUM_ROWS = [1, 32, 2050]
|
NUM_ROWS = [1, 32, 2050]
|
||||||
TOP_K_VALUES = [2048]
|
TOP_K_VALUES = [2048, 3000]
|
||||||
BATCH_SIZE = [1, 2, 4, 2048, 4096]
|
BATCH_SIZE = [1, 2, 2048]
|
||||||
NEXT_N = [1, 2, 4, 8]
|
NEXT_N = [1, 8]
|
||||||
|
DATA_GENERATION = ["random", "10LSBits"]
|
||||||
|
|
||||||
|
|
||||||
def create_random_logits(
|
def create_random_logits(
|
||||||
row_starts: torch.Tensor,
|
row_starts: torch.Tensor,
|
||||||
row_ends: torch.Tensor,
|
row_ends: torch.Tensor,
|
||||||
vocab_size: int,
|
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
data_generation: str,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Create random logits tensor for testing."""
|
"""Create random logits tensor for testing."""
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
# Generate logits with some structure to make testing more meaningful
|
# Generate logits with some structure to make testing more meaningful
|
||||||
logits = torch.randn(row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda")
|
if data_generation == "random":
|
||||||
|
logits = torch.randn(
|
||||||
|
row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda"
|
||||||
|
)
|
||||||
|
elif data_generation == "10LSBits":
|
||||||
|
top_22_bits_mask = 0xFFFFFC00
|
||||||
|
last_10_bits_mask = 0x000003FF
|
||||||
|
fixed_top_22_bits = 0x3F900000
|
||||||
|
# Generate random bits for the last 10 bits
|
||||||
|
random_bottom_bits = torch.randint(
|
||||||
|
0,
|
||||||
|
2**10,
|
||||||
|
(row_starts.shape[0], max(row_ends)),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
# Combine: fixed top 22 bits with random last 10 bits
|
||||||
|
logits_bits = (fixed_top_22_bits & top_22_bits_mask) | (
|
||||||
|
random_bottom_bits & last_10_bits_mask
|
||||||
|
)
|
||||||
|
logits = logits_bits.view(dtype)
|
||||||
|
|
||||||
for i, end in enumerate(row_ends):
|
for i, end in enumerate(row_ends):
|
||||||
logits[i, end:] = float("-inf")
|
logits[i, end:] = float("-inf")
|
||||||
return logits
|
return logits
|
||||||
@ -113,13 +135,13 @@ def test_top_k_per_row(
|
|||||||
# Create test data
|
# Create test data
|
||||||
vocab_size = 20000
|
vocab_size = 20000
|
||||||
row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
|
row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
|
||||||
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
|
logits = create_random_logits(row_starts, row_ends, torch.float32, 42, "random")
|
||||||
|
|
||||||
# Create output tensors
|
# Create output tensors
|
||||||
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
# Run CUDA implementation
|
# Run CUDA implementation
|
||||||
torch.ops._C.top_k_per_row(
|
torch.ops._C.top_k_per_row_prefill(
|
||||||
logits,
|
logits,
|
||||||
row_starts,
|
row_starts,
|
||||||
row_ends,
|
row_ends,
|
||||||
@ -127,6 +149,7 @@ def test_top_k_per_row(
|
|||||||
num_rows,
|
num_rows,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
|
top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run reference implementation
|
# Run reference implementation
|
||||||
@ -139,27 +162,23 @@ def test_top_k_per_row(
|
|||||||
# Compare results
|
# Compare results
|
||||||
assert compare_top_k_results(
|
assert compare_top_k_results(
|
||||||
logits, indices, torch_indices, row_starts, row_ends, top_k
|
logits, indices, torch_indices, row_starts, row_ends, top_k
|
||||||
), "CUDA top_k_per_row results don't match torch.topk"
|
), "CUDA top_k_per_row_prefill results don't match torch.topk"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
|
def _run_top_k_per_row_decode_test(
|
||||||
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
|
|
||||||
@pytest.mark.parametrize("next_n", NEXT_N)
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
|
||||||
@torch.inference_mode()
|
|
||||||
def test_top_k_per_row_decode(
|
|
||||||
top_k: int,
|
top_k: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
next_n: int,
|
next_n: int,
|
||||||
|
vocab_size: int,
|
||||||
|
data_generation: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Test top_k_per_row with seq_lens tensor.
|
Helper function to run top_k_per_row_decode test with given parameters.
|
||||||
"""
|
"""
|
||||||
torch.set_default_device("cuda:0")
|
torch.set_default_device("cuda:0")
|
||||||
|
|
||||||
# Create test data
|
# Create test data
|
||||||
num_rows = batch_size * next_n
|
num_rows = batch_size * next_n
|
||||||
vocab_size = 20000
|
|
||||||
seq_lens = torch.randint(
|
seq_lens = torch.randint(
|
||||||
vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
|
vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
@ -167,7 +186,9 @@ def test_top_k_per_row_decode(
|
|||||||
row_indices = torch.arange(num_rows, device="cuda") // next_n
|
row_indices = torch.arange(num_rows, device="cuda") // next_n
|
||||||
next_n_offset = torch.arange(num_rows, device="cuda") % next_n
|
next_n_offset = torch.arange(num_rows, device="cuda") % next_n
|
||||||
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
|
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
|
||||||
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
|
logits = create_random_logits(
|
||||||
|
row_starts, row_ends, torch.float32, 42, data_generation
|
||||||
|
)
|
||||||
|
|
||||||
# Create output tensors
|
# Create output tensors
|
||||||
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||||
@ -181,6 +202,7 @@ def test_top_k_per_row_decode(
|
|||||||
num_rows,
|
num_rows,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
|
top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -195,4 +217,41 @@ def test_top_k_per_row_decode(
|
|||||||
# Compare results
|
# Compare results
|
||||||
assert compare_top_k_results(
|
assert compare_top_k_results(
|
||||||
logits, indices, torch_indices, row_starts, row_ends, top_k
|
logits, indices, torch_indices, row_starts, row_ends, top_k
|
||||||
), "CUDA top_k_per_row results don't match torch.topk"
|
), "CUDA top_k_per_row_decode results don't match torch.topk"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
|
||||||
|
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
|
||||||
|
@pytest.mark.parametrize("next_n", NEXT_N)
|
||||||
|
@pytest.mark.parametrize("data_generation", DATA_GENERATION)
|
||||||
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_top_k_per_row_decode(
|
||||||
|
top_k: int,
|
||||||
|
batch_size: int,
|
||||||
|
next_n: int,
|
||||||
|
data_generation: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test top_k_per_row with seq_lens tensor.
|
||||||
|
"""
|
||||||
|
vocab_size = 20000
|
||||||
|
_run_top_k_per_row_decode_test(
|
||||||
|
top_k, batch_size, next_n, vocab_size, data_generation
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_top_k_per_row_decode_large_vocab_size() -> None:
|
||||||
|
"""
|
||||||
|
Test top_k_per_row_decode with large vocabulary size.
|
||||||
|
"""
|
||||||
|
top_k = 2048
|
||||||
|
batch_size = 2
|
||||||
|
next_n = 2
|
||||||
|
vocab_size = 300000
|
||||||
|
data_generation = "random"
|
||||||
|
_run_top_k_per_row_decode_test(
|
||||||
|
top_k, batch_size, next_n, vocab_size, data_generation
|
||||||
|
)
|
||||||
|
|||||||
@ -684,11 +684,10 @@ def sparse_attn_indexer(
|
|||||||
chunk.cu_seqlen_ke,
|
chunk.cu_seqlen_ke,
|
||||||
)
|
)
|
||||||
num_rows = logits.shape[0]
|
num_rows = logits.shape[0]
|
||||||
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
|
|
||||||
topk_indices = topk_indices_buffer[
|
topk_indices = topk_indices_buffer[
|
||||||
chunk.token_start : chunk.token_end, :topk_tokens
|
chunk.token_start : chunk.token_end, :topk_tokens
|
||||||
]
|
]
|
||||||
torch.ops._C.top_k_per_row(
|
torch.ops._C.top_k_per_row_prefill(
|
||||||
logits,
|
logits,
|
||||||
chunk.cu_seqlen_ks,
|
chunk.cu_seqlen_ks,
|
||||||
chunk.cu_seqlen_ke,
|
chunk.cu_seqlen_ke,
|
||||||
@ -696,6 +695,7 @@ def sparse_attn_indexer(
|
|||||||
num_rows,
|
num_rows,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
|
topk_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_decode:
|
if has_decode:
|
||||||
@ -738,7 +738,6 @@ def sparse_attn_indexer(
|
|||||||
max_model_len=max_model_len,
|
max_model_len=max_model_len,
|
||||||
)
|
)
|
||||||
num_rows = logits.shape[0]
|
num_rows = logits.shape[0]
|
||||||
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
|
|
||||||
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
|
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
|
||||||
|
|
||||||
torch.ops._C.top_k_per_row_decode(
|
torch.ops._C.top_k_per_row_decode(
|
||||||
@ -749,6 +748,7 @@ def sparse_attn_indexer(
|
|||||||
num_rows,
|
num_rows,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
|
topk_tokens,
|
||||||
)
|
)
|
||||||
if decode_metadata.requires_padding:
|
if decode_metadata.requires_padding:
|
||||||
# if padded, we need to unpack
|
# if padded, we need to unpack
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user