#include "dispatch_utils.h" #include #include #ifndef USE_ROCM #include #else #include #endif namespace vllm { template __global__ void apply_repetition_penalties_kernel( scalar_t* __restrict__ logits, // [num_seqs, vocab_size] const bool* __restrict__ prompt_mask, // [num_seqs, vocab_size] const bool* __restrict__ output_mask, // [num_seqs, vocab_size] const scalar_t* __restrict__ repetition_penalties, // [num_seqs] const int num_seqs, const int vocab_size, const int tile_size) { // Each block handles one sequence and a tile of vocab const int seq_idx = blockIdx.x; if (seq_idx >= num_seqs) return; const int tile_start = blockIdx.y * tile_size; const int tile_end = min(tile_start + tile_size, vocab_size); // Load repetition penalty for this sequence const scalar_t penalty = repetition_penalties[seq_idx]; // Each thread processes multiple vocab items within the tile for (int vocab_idx = tile_start + threadIdx.x; vocab_idx < tile_end; vocab_idx += blockDim.x) { const int64_t idx = static_cast(seq_idx) * vocab_size + vocab_idx; const bool is_repeated = prompt_mask[idx] || output_mask[idx]; if (is_repeated) { scalar_t logit = logits[idx]; if (logit > 0) { logits[idx] = logit / penalty; } else { logits[idx] = logit * penalty; } } } } static inline __device__ uint16_t extractBinIdx(float x) { union { __half h; uint16_t u16; } tmp; tmp.h = __float2half_rn(x); tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000); return 511 - (tmp.u16 >> 7); } template static __global__ void topKPerRow(const float* logits, const int* rowStarts, const int* rowEnds, int* outIndices, float* outLogits, int stride0, int stride1) { // The number of bins in the histogram. static constexpr int kNumBins = 512; // The top-k width. static constexpr int kTopK = 2048; // The number of elements per thread for the final top-k sort. static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock; // The class to sort the elements during the final top-k sort. using TopKSort = cub::BlockRadixSort; // The number of slots for the final pass. static constexpr int kNumFinalItems = 3072; // The number of elements per thread for the final sort. static constexpr int kNumFinalItemsPerThread = kNumFinalItems / kNumThreadsPerBlock; // The class to sort the elements during the final pass. using FinalSort = cub::BlockRadixSort; // The class to compute the inclusive prefix-sum over the histogram. using Scan = cub::BlockScan; // Shared memory to compute the block scan. __shared__ typename Scan::TempStorage smemScan; // The structure to store the final items (for the final pass). struct FinalItems { // Shared memory to store the indices for the final pass. int indices[kNumFinalItems]; // Shared memory to store the logits for the final pass. float logits[kNumFinalItems]; }; // Shared memory to compute the block sort. __shared__ union { FinalItems items; typename FinalSort::TempStorage finalSort; typename TopKSort::TempStorage topKSort; } smemFinal; // Shared memory to store the histogram. __shared__ int smemHistogram[kNumBins]; // Shared memory to store the selected indices. __shared__ int smemIndices[kTopK]; // Shared memory to store the selected logits. __shared__ float smemLogits[kTopK]; // Shared memory to store the threshold bin. __shared__ int smemThresholdBinIdx[1]; // Shared memory counter to register the candidates for the final phase. __shared__ int smemFinalDstIdx[1]; // The row computed by this block. int rowIdx = blockIdx.x; // The range of logits within the row. 
template <int kNumThreadsPerBlock>
static __global__ void topKPerRow(const float* logits, const int* rowStarts,
                                  const int* rowEnds, int* outIndices,
                                  float* outLogits, int stride0, int stride1) {
  // The number of bins in the histogram.
  static constexpr int kNumBins = 512;
  // The top-k width.
  static constexpr int kTopK = 2048;

  // The number of elements per thread for the final top-k sort.
  static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
  // The class to sort the elements during the final top-k sort.
  using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
                                       kNumTopKItemsPerThread, int>;

  // The number of slots for the final pass.
  static constexpr int kNumFinalItems = 3072;
  // The number of elements per thread for the final sort.
  static constexpr int kNumFinalItemsPerThread =
      kNumFinalItems / kNumThreadsPerBlock;
  // The class to sort the elements during the final pass.
  using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
                                        kNumFinalItemsPerThread, int>;

  // The class to compute the inclusive prefix-sum over the histogram.
  using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;

  // Shared memory to compute the block scan.
  __shared__ typename Scan::TempStorage smemScan;

  // The structure to store the final items (for the final pass).
  struct FinalItems {
    // Shared memory to store the indices for the final pass.
    int indices[kNumFinalItems];
    // Shared memory to store the logits for the final pass.
    float logits[kNumFinalItems];
  };

  // Shared memory to compute the block sort.
  __shared__ union {
    FinalItems items;
    typename FinalSort::TempStorage finalSort;
    typename TopKSort::TempStorage topKSort;
  } smemFinal;

  // Shared memory to store the histogram.
  __shared__ int smemHistogram[kNumBins];
  // Shared memory to store the selected indices.
  __shared__ int smemIndices[kTopK];
  // Shared memory to store the selected logits.
  __shared__ float smemLogits[kTopK];
  // Shared memory to store the threshold bin.
  __shared__ int smemThresholdBinIdx[1];
  // Shared memory counter to register the candidates for the final phase.
  __shared__ int smemFinalDstIdx[1];

  // The row computed by this block.
  int rowIdx = blockIdx.x;

  // The range of logits within the row.
  int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx];
  // The length of the row.
  int rowLen = rowEnd - rowStart;

  // Shortcut if the length of the row is smaller than Top-K. Indices are not
  // sorted by their corresponding logit.
  if (rowLen <= kTopK) {
    for (int rowIt = threadIdx.x; rowIt < rowLen;
         rowIt += kNumThreadsPerBlock) {
      int idx = rowStart + rowIt;
      outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
      outLogits[rowIdx * kTopK + rowIt] =
          logits[rowIdx * stride0 + idx * stride1];
    }
    for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
         rowIt += kNumThreadsPerBlock) {
      outIndices[rowIdx * kTopK + rowIt] = -1;
      outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX;
    }
    return;
  }

  // Clear the histogram.
  if (threadIdx.x < kNumBins) {
    smemHistogram[threadIdx.x] = 0;
  }

  // Make sure the histogram is ready.
  __syncthreads();

  // Fetch elements one-by-one.
  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
       rowIt += kNumThreadsPerBlock) {
    uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]);
    atomicAdd(&smemHistogram[idx], 1);
  }

  // Make sure the histogram is ready.
  __syncthreads();

  // Read the values from SMEM.
  int binCount{0};
  if (threadIdx.x < kNumBins) {
    binCount = smemHistogram[threadIdx.x];
  }

  // Make sure each thread has read its value.
  __syncthreads();

  // Compute the prefix sum.
  int prefixSum{0}, totalSum{0};
  Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);

  // Update the histogram with the prefix sums.
  if (threadIdx.x < kNumBins) {
    smemHistogram[threadIdx.x] = prefixSum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // Find the threshold bin, i.e. the bin where the prefix sum crosses kTopK.
  if (threadIdx.x < kNumBins) {
    int nextPrefixSum =
        threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
    if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
      smemThresholdBinIdx[0] = threadIdx.x;
    }
  }

  // Clear the counter to store the items for the final phase.
  if (threadIdx.x == 0) {
    smemFinalDstIdx[0] = 0;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The threshold bin.
  int thresholdBinIdx = smemThresholdBinIdx[0];

  // Fetch elements one-by-one and populate the shared memory buffers.
  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
       rowIt += kNumThreadsPerBlock) {
    float logit = logits[rowIdx * stride0 + rowIt * stride1];
    uint16_t idx = extractBinIdx(logit);
    if (idx < thresholdBinIdx) {
      int dstIdx = atomicAdd(&smemHistogram[idx], 1);
      smemLogits[dstIdx] = logit;
      smemIndices[dstIdx] = rowIt;
    } else if (idx == thresholdBinIdx) {
      int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
      if (dstIdx < kNumFinalItems) {
        smemFinal.items.logits[dstIdx] = logit;
        smemFinal.items.indices[dstIdx] = rowIt;
      }
    }
  }

  // Make sure the elements are in shared memory.
  __syncthreads();

  // The logits of the elements to be sorted in the final pass.
  float finalLogits[kNumFinalItemsPerThread];
  // The indices of the elements to be sorted in the final pass.
  int finalIndices[kNumFinalItemsPerThread];

  // Init.
#pragma unroll
  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
    finalLogits[ii] = -FLT_MAX;
  }

  // Read the elements from SMEM.
#pragma unroll
  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
    if (srcIdx < smemFinalDstIdx[0]) {
      finalLogits[ii] = smemFinal.items.logits[srcIdx];
      finalIndices[ii] = smemFinal.items.indices[srcIdx];
    }
  }

  // Make sure the shared memory has been read.
  __syncthreads();
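  // At this point every element in a bin strictly below thresholdBinIdx is
  // guaranteed to be in the top-k and has already been scattered into
  // smemLogits/smemIndices. Only the threshold bin straddles the k-th largest
  // value, so just its candidates are sorted below; note that at most
  // kNumFinalItems (3072) of them are kept, and any overflow beyond that is
  // silently dropped.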
  // Sort the elements of the threshold bin.
  FinalSort(smemFinal.finalSort)
      .SortDescendingBlockedToStriped(finalLogits, finalIndices);

  // Copy the data back to the shared memory storage.
  int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
#pragma unroll
  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
    int dstIdx = baseIdx + srcIdx;
    if (dstIdx < kTopK) {
      smemLogits[dstIdx] = finalLogits[ii];
      smemIndices[dstIdx] = finalIndices[ii];
    }
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The top-k logits.
  float topKLogits[kNumTopKItemsPerThread];
  // The top-k indices.
  int topKIndices[kNumTopKItemsPerThread];

  // Load from shared memory.
#pragma unroll
  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
    topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x];
    topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x];
  }

  // Sort the elements.
  TopKSort(smemFinal.topKSort)
      .SortDescendingBlockedToStriped(topKLogits, topKIndices);

  // Store to global memory.
#pragma unroll
  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
    int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
    outIndices[offset] = topKIndices[ii] - rowStart;
    outLogits[offset] = topKLogits[ii];
  }
}

}  // namespace vllm

void apply_repetition_penalties_(
    torch::Tensor& logits,             // [num_seqs, vocab_size], in-place
    const torch::Tensor& prompt_mask,  // [num_seqs, vocab_size]
    const torch::Tensor& output_mask,  // [num_seqs, vocab_size]
    const torch::Tensor& repetition_penalties) {  // [num_seqs]
  TORCH_CHECK(logits.is_contiguous());
  TORCH_CHECK(prompt_mask.is_contiguous());
  TORCH_CHECK(output_mask.is_contiguous());
  TORCH_CHECK(repetition_penalties.is_contiguous());

  int vocab_size = logits.size(-1);
  int num_seqs = logits.size(0);

  if (num_seqs == 0) return;

  // Get the number of SMs on the current device.
  int sms = 0;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
                         logits.get_device());

  // Compute tile_num and tile_size.
  int tile_num =
      std::min(vocab_size, std::max(1, (sms + num_seqs - 1) / num_seqs));
  int tile_size = (vocab_size + tile_num - 1) / tile_num;

  // Each block handles one sequence and a tile of the vocabulary.
  dim3 grid(num_seqs, tile_num);
  dim3 block(std::min(tile_size, 1024));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
        vllm::apply_repetition_penalties_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                logits.data_ptr<scalar_t>(), prompt_mask.data_ptr<bool>(),
                output_mask.data_ptr<bool>(),
                repetition_penalties.data_ptr<scalar_t>(), num_seqs,
                vocab_size, tile_size);
      });
}

void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                   const torch::Tensor& rowEnds, torch::Tensor& indices,
                   torch::Tensor& values, int64_t numRows, int64_t stride0,
                   int64_t stride1) {
  // Compute the results on the device.
  constexpr int kNumThreadsPerBlock = 512;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::topKPerRow<kNumThreadsPerBlock>
      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
          logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
          rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
          values.data_ptr<float>(), static_cast<int>(stride0),
          static_cast<int>(stride1));
}
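// Usage sketch for top_k_per_row (illustrative; the tensor names and setup
// below are assumptions, not part of this file). The kernel writes exactly
// kTopK = 2048 entries per row, padding rows shorter than 2048 with index -1
// and logit -FLT_MAX, so the outputs must be preallocated accordingly:
//
//   auto opts = logits.options();
//   auto indices = torch::empty({numRows, 2048}, opts.dtype(torch::kInt32));
//   auto values = torch::empty({numRows, 2048}, opts.dtype(torch::kFloat32));
//   top_k_per_row(logits, rowStarts, rowEnds, indices, values, numRows,
//                 logits.stride(0), logits.stride(1));
//
// rowStarts/rowEnds are int32 tensors giving, per row, the [start, end)
// column range to scan; logits must be float32 and is addressed as
// logits[row * stride0 + col * stride1].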