diff --git a/csrc/ops.h b/csrc/ops.h
index 2ada7905da4b..9dd302faf5b8 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -100,6 +100,11 @@ void apply_repetition_penalties_(torch::Tensor& logits,
                                  const torch::Tensor& output_mask,
                                  const torch::Tensor& repetition_penalties);
 
+void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
+                   const torch::Tensor& rowEnds, torch::Tensor& indices,
+                   torch::Tensor& values, int64_t numRows, int64_t stride0,
+                   int64_t stride1);
+
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,
                                double epsilon);
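Note (reviewer sketch, not part of the patch): the kernel added in csrc/sampler.cu below selects the 2048 largest logits per row in two passes. It first histograms each row into 512 buckets using a key derived from the fp16 bit pattern, runs an exclusive prefix sum over the histogram to locate the bucket containing the 2048th largest value, then breaks ties inside that threshold bucket with a CUB block radix sort. A rough Python equivalent of the bucket key, assuming numpy's float16 rounding matches __float2half_rn (the helper name extract_bin_idx is illustrative):

    import numpy as np

    def extract_bin_idx(x: float) -> int:
        # Map the fp16 bit pattern to an unsigned key that grows with the float
        # value, keep the top 9 bits (512 bins), and flip so that larger logits
        # land in lower-numbered bins.
        u16 = int(np.array(x, dtype=np.float16).view(np.uint16))
        key = (~u16 & 0xFFFF) if x < 0.0 else (u16 | 0x8000)
        return 511 - (key >> 7)

    # Larger logits get smaller bin indices, so bins are scanned from 0 upward.
    assert extract_bin_idx(5.0) < extract_bin_idx(1.0) < extract_bin_idx(-2.0)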
diff --git a/csrc/sampler.cu b/csrc/sampler.cu
index b0cce2e98d22..bc589d99d04b 100644
--- a/csrc/sampler.cu
+++ b/csrc/sampler.cu
@@ -44,6 +44,245 @@ __global__ void apply_repetition_penalties_kernel(
   }
 }
 
+static inline __device__ uint16_t extractBinIdx(float x) {
+  union {
+    __half h;
+    uint16_t u16;
+  } tmp;
+  tmp.h = __float2half_rn(x);
+  tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000);
+  return 511 - (tmp.u16 >> 7);
+}
+
+template <int kNumThreadsPerBlock>
+static __global__ void topKPerRow(const float* logits, const int* rowStarts,
+                                  const int* rowEnds, int* outIndices,
+                                  float* outLogits, int stride0, int stride1) {
+  // The number of bins in the histogram.
+  static constexpr int kNumBins = 512;
+
+  // The top-k width.
+  static constexpr int kTopK = 2048;
+  // The number of elements per thread for the final top-k sort.
+  static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
+  // The class to sort the elements during the final top-k sort.
+  using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
+                                       kNumTopKItemsPerThread, int>;
+
+  // The number of slots for the final pass.
+  static constexpr int kNumFinalItems = 3072;
+  // The number of elements per thread for the final sort.
+  static constexpr int kNumFinalItemsPerThread =
+      kNumFinalItems / kNumThreadsPerBlock;
+  // The class to sort the elements during the final pass.
+  using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
+                                        kNumFinalItemsPerThread, int>;
+
+  // The class to compute the exclusive prefix-sum over the histogram.
+  using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
+
+  // Shared memory to compute the block scan.
+  __shared__ typename Scan::TempStorage smemScan;
+
+  // The structure to store the final items (for the final pass).
+  struct FinalItems {
+    // Shared memory to store the indices for the final pass.
+    int indices[kNumFinalItems];
+    // Shared memory to store the logits for the final pass.
+    float logits[kNumFinalItems];
+  };
+
+  // Shared memory to compute the block sort.
+  __shared__ union {
+    FinalItems items;
+    typename FinalSort::TempStorage finalSort;
+    typename TopKSort::TempStorage topKSort;
+  } smemFinal;
+
+  // Shared memory to store the histogram.
+  __shared__ int smemHistogram[kNumBins];
+  // Shared memory to store the selected indices.
+  __shared__ int smemIndices[kTopK];
+  // Shared memory to store the selected logits.
+  __shared__ float smemLogits[kTopK];
+  // Shared memory to store the threshold bin.
+  __shared__ int smemThresholdBinIdx[1];
+  // Shared memory counter to register the candidates for the final phase.
+  __shared__ int smemFinalDstIdx[1];
+
+  // The row computed by this block.
+  int rowIdx = blockIdx.x;
+  // The range of logits within the row.
+  int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx];
+  // The length of the row.
+  int rowLen = rowEnd - rowStart;
+
+  // Shortcut if the length of the row is smaller than Top-K. Indices are not
+  // sorted by their corresponding logit.
+  if (rowLen <= kTopK) {
+    for (int rowIt = threadIdx.x; rowIt < rowLen;
+         rowIt += kNumThreadsPerBlock) {
+      int idx = rowStart + rowIt;
+      outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
+      outLogits[rowIdx * kTopK + rowIt] =
+          logits[rowIdx * stride0 + idx * stride1];
+    }
+    for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
+         rowIt += kNumThreadsPerBlock) {
+      outIndices[rowIdx * kTopK + rowIt] = -1;
+      outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX;
+    }
+    return;
+  }
+
+  // Clear the histogram.
+  if (threadIdx.x < kNumBins) {
+    smemHistogram[threadIdx.x] = 0;
+  }
+
+  // Make sure the histogram is ready.
+  __syncthreads();
+
+  // Fetch elements one-by-one.
+  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
+       rowIt += kNumThreadsPerBlock) {
+    uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]);
+    atomicAdd(&smemHistogram[idx], 1);
+  }
+
+  // Make sure the histogram is ready.
+  __syncthreads();
+
+  // Read the values from SMEM.
+  int binCount{0};
+  if (threadIdx.x < kNumBins) {
+    binCount = smemHistogram[threadIdx.x];
+  }
+
+  // Make sure each thread has read its value.
+  __syncthreads();
+
+  // Compute the prefix sum.
+  int prefixSum{0}, totalSum{0};
+  Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);
+
+  // Update the histogram with the prefix sums.
+  if (threadIdx.x < kNumBins) {
+    smemHistogram[threadIdx.x] = prefixSum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // Find the last valid bin.
+  if (threadIdx.x < kNumBins) {
+    int nextPrefixSum =
+        threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
+    if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
+      smemThresholdBinIdx[0] = threadIdx.x;
+    }
+  }
+
+  // Clear the counter to store the items for the final phase.
+  if (threadIdx.x == 0) {
+    smemFinalDstIdx[0] = 0;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The threshold bin.
+  int thresholdBinIdx = smemThresholdBinIdx[0];
+
+  // Fetch elements one-by-one and populate the shared memory buffers.
+  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
+       rowIt += kNumThreadsPerBlock) {
+    float logit = logits[rowIdx * stride0 + rowIt * stride1];
+    uint16_t idx = extractBinIdx(logit);
+    if (idx < thresholdBinIdx) {
+      int dstIdx = atomicAdd(&smemHistogram[idx], 1);
+      smemLogits[dstIdx] = logit;
+      smemIndices[dstIdx] = rowIt;
+    } else if (idx == thresholdBinIdx) {
+      int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
+      if (dstIdx < kNumFinalItems) {
+        smemFinal.items.logits[dstIdx] = logit;
+        smemFinal.items.indices[dstIdx] = rowIt;
+      }
+    }
+  }
+
+  // Make sure the elements are in shared memory.
+  __syncthreads();
+
+  // The logits of the elements to be sorted in the final pass.
+  float finalLogits[kNumFinalItemsPerThread];
+  // The indices of the elements to be sorted in the final pass.
+  int finalIndices[kNumFinalItemsPerThread];
+
+// Init.
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    finalLogits[ii] = -FLT_MAX;
+  }
+
+// Read the elements from SMEM.
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
+    if (srcIdx < smemFinalDstIdx[0]) {
+      finalLogits[ii] = smemFinal.items.logits[srcIdx];
+      finalIndices[ii] = smemFinal.items.indices[srcIdx];
+    }
+  }
+
+  // Make sure the shared memory has been read.
+  __syncthreads();
+
+  // Sort the elements.
+  FinalSort(smemFinal.finalSort)
+      .SortDescendingBlockedToStriped(finalLogits, finalIndices);
+
+  // Copy the data back to the shared memory storage.
+  int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
+    int dstIdx = baseIdx + srcIdx;
+    if (dstIdx < kTopK) {
+      smemLogits[dstIdx] = finalLogits[ii];
+      smemIndices[dstIdx] = finalIndices[ii];
+    }
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The topK logits.
+  float topKLogits[kNumTopKItemsPerThread];
+  // The topK indices.
+  int topKIndices[kNumTopKItemsPerThread];
+
+// Load from shared memory.
+#pragma unroll
+  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
+    topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x];
+    topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x];
+  }
+
+  // Sort the elements.
+  TopKSort(smemFinal.topKSort)
+      .SortDescendingBlockedToStriped(topKLogits, topKIndices);
+
+// Store to global memory.
+#pragma unroll
+  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
+    int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
+    outIndices[offset] = topKIndices[ii] - rowStart;
+    outLogits[offset] = topKLogits[ii];
+  }
+}
+
 }  // namespace vllm
 
 void apply_repetition_penalties_(
@@ -85,4 +324,20 @@ void apply_repetition_penalties_(
         repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
         tile_size);
   });
-}
\ No newline at end of file
+}
+
+void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
+                   const torch::Tensor& rowEnds, torch::Tensor& indices,
+                   torch::Tensor& values, int64_t numRows, int64_t stride0,
+                   int64_t stride1) {
+  // Compute the results on the device.
+  constexpr int kNumThreadsPerBlock = 512;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  vllm::topKPerRow<kNumThreadsPerBlock>
+      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
+          logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
+          rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
+          values.data_ptr<float>(), static_cast<int>(stride0),
+          static_cast<int>(stride1));
+}
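Note (usage sketch, not part of the patch): the wrapper above launches one 512-thread block per row and always produces a fixed top-k width of 2048, which is why the Python callers further down assert topk_tokens == 2048. A minimal invocation through the binding registered below, assuming the _C extension is built with this op; shapes and values are illustrative only:

    import torch

    num_rows, vocab = 4, 20000
    logits = torch.randn(num_rows, vocab, dtype=torch.float32, device="cuda")
    row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
    row_ends = torch.full((num_rows,), vocab, dtype=torch.int32, device="cuda")
    indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
    values = torch.empty((num_rows, 2048), dtype=torch.float32, device="cuda")
    torch.ops._C.top_k_per_row(
        logits, row_starts, row_ends, indices, values,
        num_rows, logits.stride(0), logits.stride(1),
    )
    # indices[i] holds positions relative to row_starts[i]; unused slots are -1.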
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 64a345eb66cc..6f939e2eb403 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -188,6 +188,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("apply_repetition_penalties_", torch::kCUDA,
            &apply_repetition_penalties_);
 
+  // Optimized top-k per row operation
+  ops.def(
+      "top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
+      "Tensor! indices, Tensor! values, int numRows, int stride0, "
+      "int stride1) -> ()");
+  ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
+
   // Layernorm-quant
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
diff --git a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py
new file mode 100644
index 000000000000..f52cddc8c370
--- /dev/null
+++ b/tests/kernels/test_top_k_per_row.py
@@ -0,0 +1,142 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import numpy as np
+import pytest
+import torch
+
+from vllm.platforms import current_platform
+
+# Test parameters
+NUM_ROWS = [1, 32, 2050]
+TOP_K_VALUES = [2048]
+
+
+def create_random_logits(
+    row_starts: torch.Tensor,
+    row_ends: torch.Tensor,
+    vocab_size: int,
+    dtype: torch.dtype,
+    seed: int,
+) -> torch.Tensor:
+    """Create random logits tensor for testing."""
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    # Generate logits with some structure to make testing more meaningful
+    logits = torch.randn(row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda")
+    for i, end in enumerate(row_ends):
+        logits[i, end:] = float("-inf")
+    return logits
+
+
+def create_row_boundaries(
+    seq_len: int, vocab_size: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Create row start and end indices for testing."""
+    row_starts = torch.zeros(seq_len, dtype=torch.int32, device="cuda")
+    row_ends = torch.arange(1, seq_len + 1, device="cuda", dtype=torch.int32)
+    return row_starts, row_ends
+
+
+def compare_top_k_results(
+    cuda_indices: torch.Tensor,
+    cuda_values: torch.Tensor,
+    torch_indices: torch.Tensor,
+    torch_values: torch.Tensor,
+    row_starts: torch.Tensor,
+    row_ends: torch.Tensor,
+    top_k: int,
+    tolerance: float = 1e-5,
+) -> bool:
+    """
+    Compare results from CUDA top_k_per_row with torch.topk.
+    Both results should be sorted and contain the same top-k elements.
+    """
+    num_rows = cuda_indices.shape[0]
+
+    for row_idx in range(num_rows):
+        # Get valid elements using row boundaries
+        row_start = row_starts[row_idx].item()
+        row_end = row_ends[row_idx].item()
+        row_length = row_end - row_start
+        num_valid = min(top_k, row_length)
+        cuda_row_indices = cuda_indices[row_idx][:num_valid].cpu()
+        torch_row_indices = torch_indices[row_idx][:num_valid].cpu()
+
+        # Compare the sets of indices first
+        cuda_set = set(cuda_row_indices.tolist())
+        torch_set = set(torch_row_indices.tolist())
+        if cuda_set == torch_set:
+            continue
+
+        # Any difference in elements, compare the values
+        cuda_row_values = cuda_values[row_idx][:num_valid].cpu()
+        torch_row_values = torch_values[row_idx][:num_valid].cpu()
+
+        cuda_only_values, torch_only_values = [], []
+        for idx in cuda_set - torch_set:
+            cuda_pos = (cuda_row_indices == idx).nonzero(as_tuple=True)[0]
+            cuda_only_values.append(cuda_row_values[cuda_pos[0]])
+
+        for idx in torch_set - cuda_set:
+            torch_pos = (torch_row_indices == idx).nonzero(as_tuple=True)[0]
+            torch_only_values.append(torch_row_values[torch_pos[0]])
+
+        if len(cuda_only_values) != len(torch_only_values):
+            return False
+        if not torch.allclose(
+            torch.tensor(cuda_only_values),
+            torch.tensor(torch_only_values),
+            rtol=tolerance,
+            atol=tolerance,
+        ):
+            return False
+
+    return True
+
+
+@pytest.mark.parametrize("num_rows", NUM_ROWS)
+@pytest.mark.parametrize("top_k", TOP_K_VALUES)
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
+@torch.inference_mode()
+def test_top_k_per_row(
+    num_rows: int,
+    top_k: int,
+) -> None:
+    """
+    Test top_k_per_row.
+    """
+    torch.set_default_device("cuda:0")
+
+    # Create test data
+    vocab_size = 20000
+    row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
+    logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
+
+    # Create output tensors
+    indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
+    values = torch.empty((num_rows, 2048), dtype=torch.float32, device="cuda")
+
+    # Run CUDA implementation
+    torch.ops._C.top_k_per_row(
+        logits,
+        row_starts,
+        row_ends,
+        indices,
+        values,
+        num_rows,
+        logits.stride(0),
+        logits.stride(1),
+    )
+
+    # Run reference implementation
+    torch_values, torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)
+    mask_lo = torch_indices >= 0
+    mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
+    mask = mask_lo & mask_hi
+    torch_indices = torch_indices.masked_fill(~mask, -1)
+
+    # Compare results
+    assert compare_top_k_results(
+        indices, values, torch_indices, torch_values, row_starts, row_ends, top_k
+    ), "CUDA top_k_per_row results don't match torch.topk"
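Note (reference sketch, not part of the patch): a plain-PyTorch statement of the contract the test above checks. Per row, the kernel returns the top 2048 logits of logits[i, start:end), with indices relative to start, unused index slots set to -1 and unused value slots set to -FLT_MAX; rows shorter than 2048 are copied through unsorted. The helper name reference_top_k_per_row is illustrative:

    import torch

    def reference_top_k_per_row(logits, row_starts, row_ends, top_k=2048):
        num_rows = logits.shape[0]
        indices = torch.full(
            (num_rows, top_k), -1, dtype=torch.int32, device=logits.device
        )
        values = torch.full(
            (num_rows, top_k), torch.finfo(logits.dtype).min,
            dtype=logits.dtype, device=logits.device,
        )
        for i in range(num_rows):
            start, end = int(row_starts[i]), int(row_ends[i])
            row = logits[i, start:end]
            k = min(top_k, row.numel())
            v, idx = row.topk(k)
            values[i, :k] = v
            indices[i, :k] = idx.to(torch.int32)  # already relative to start
        return indices, values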
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 5b05d0e3a532..a2fb0cfe6000 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -643,17 +643,24 @@ def sparse_attn_indexer(
                 chunk.cu_seqlen_ks,
                 chunk.cu_seqlen_ke,
             )
-            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), dim=-1)[1]
-            topk_indices -= chunk.cu_seqlen_ks[:, None]
-            mask_lo = topk_indices >= 0
-            mask_hi = (
-                topk_indices - (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks)[:, None] < 0
+            num_rows = logits.shape[0]
+            assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
+            topk_indices = torch.empty(
+                num_rows, topk_tokens, dtype=torch.int32, device=logits.device
             )
-            mask = torch.full_like(
-                topk_indices, False, dtype=torch.bool, device=topk_indices.device
+            topk_values = torch.empty(
+                num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
+            )
+            torch.ops._C.top_k_per_row(
+                logits,
+                chunk.cu_seqlen_ks,
+                chunk.cu_seqlen_ke,
+                topk_indices,
+                topk_values,
+                num_rows,
+                logits.stride(0),
+                logits.stride(1),
             )
-            mask = mask_lo & mask_hi
-            topk_indices = topk_indices.masked_fill(~mask, -1)
             topk_indices_buffer[
                 chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
             ] = topk_indices.to(dtype=torch.int32)
@@ -693,28 +700,32 @@ def sparse_attn_indexer(
             # padded query len
             current_device = padded_q_fp8_decode_tokens.device
             padded_num_tokens = batch_size * next_n
-            positions = (
-                torch.arange(max_model_len, device=current_device)
-                .unsqueeze(0)
-                .expand(batch_size * next_n, -1)
-            )
             row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n
             next_n_offset = (
                 torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device)
                 % next_n
             )
             index_end_pos = (
-                decode_metadata.seq_lens[row_indices] - next_n + next_n_offset
+                decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1
            ).unsqueeze(1)
-            # index_end_pos: [B * N, 1]
-            mask = positions <= index_end_pos
-            # mask: [B * N, L]
-            logits = logits.masked_fill(~mask, float("-inf"))
-            topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32)  # [B * N, K]
-            # ensure we don't set indices for the top k
-            # that is out of range(masked already)
-            # this will happen if context length is shorter than K
-            topk_indices[topk_indices > index_end_pos] = -1
+            num_rows = logits.shape[0]
+            assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
+            topk_indices = torch.empty(
+                num_rows, topk_tokens, dtype=torch.int32, device=logits.device
+            )
+            topk_values = torch.empty(
+                num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
+            )
+            torch.ops._C.top_k_per_row(
+                logits,
+                torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
+                index_end_pos.to(dtype=torch.int32, device=logits.device),
+                topk_indices,
+                topk_values,
+                num_rows,
+                logits.stride(0),
+                logits.stride(1),
+            )
             if decode_metadata.requires_padding:
                 # if padded, we need to unpack
                 # the topk indices removing padded tokens
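Note (not part of the patch): in the decode path above, the old code used index_end_pos as an inclusive position bound (positions <= index_end_pos), while top_k_per_row takes rowEnds as an exclusive bound, which is why the expression gains a trailing + 1. A small arithmetic check with illustrative numbers (seq_len 10, next_n 2, second speculative token):

    seq_len, next_n, next_n_offset = 10, 2, 1
    index_end_pos_old = seq_len - next_n + next_n_offset   # 9: keeps positions 0..9 inclusive
    row_end_new = seq_len - next_n + next_n_offset + 1     # 10: exclusive end, same positions 0..9
    assert row_end_new == index_end_pos_old + 1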