diff --git a/csrc/ops.h b/csrc/ops.h
index 64abc4922ba6..eb3d60b77e60 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -101,6 +101,10 @@ void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                    const torch::Tensor& rowEnds, torch::Tensor& indices,
                    int64_t numRows, int64_t stride0, int64_t stride1);
 
+void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
+                          const torch::Tensor& seq_lens, torch::Tensor& indices,
+                          int64_t numRows, int64_t stride0, int64_t stride1);
+
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,
                                double epsilon);
diff --git a/csrc/sampler.cu b/csrc/sampler.cu
index 92c8095c71e2..410b8988f493 100644
--- a/csrc/sampler.cu
+++ b/csrc/sampler.cu
@@ -54,15 +54,10 @@ static inline __device__ uint16_t extractBinIdx(float x) {
   return 511 - (tmp.u16 >> 7);
 }
 
-template <int kNumThreadsPerBlock>
-static __global__ void topKPerRow(const float* logits, const int* rowStarts,
-                                  const int* rowEnds, int* outIndices,
-                                  int stride0, int stride1) {
-  // The number of bins in the histogram.
-  static constexpr int kNumBins = 512;
-
-  // The top-k width.
-  static constexpr int kTopK = 2048;
+template <int kNumThreadsPerBlock, int kNumBins, int kTopK>
+__device__ void topKPerRowJob(const float* logits, const int rowStart,
+                              const int rowEnd, const int rowIdx,
+                              int* outIndices, int stride0, int stride1) {
   // The number of elements per thread for the final top-k sort.
   static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
   // The class to sort the elements during the final top-k sort.
@@ -108,10 +103,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
   // Shared memory counter to register the candidates for the final phase.
   __shared__ int smemFinalDstIdx[1];
 
-  // The row computed by this block.
-  int rowIdx = blockIdx.x;
-  // The range of logits within the row.
-  int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx];
   // The length of the row.
   int rowLen = rowEnd - rowStart;
 
@@ -260,6 +251,49 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
   }
 }
 
+template <int kNumThreadsPerBlock>
+static __global__ void topKPerRow(const float* logits, const int* rowStarts,
+                                  const int* rowEnds, int* outIndices,
+                                  int stride0, int stride1) {
+  // The number of bins in the histogram.
+  static constexpr int kNumBins = 512;
+
+  // The top-k width.
+  static constexpr int kTopK = 2048;
+
+  // The row computed by this block.
+  int rowIdx = blockIdx.x;
+
+  // The range of logits within the row.
+  int rowStart = rowStarts[rowIdx];
+  int rowEnd = rowEnds[rowIdx];
+
+  topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
+      logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
+}
+
+template <int kNumThreadsPerBlock>
+static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
+                                        int* outIndices, int stride0,
+                                        int stride1, int next_n) {
+  // The number of bins in the histogram.
+  static constexpr int kNumBins = 512;
+
+  // The top-k width.
+  static constexpr int kTopK = 2048;
+
+  // The row computed by this block.
+  int rowIdx = blockIdx.x;
+
+  // The range of logits within the row.
+  int rowStart = 0;
+  int seq_len = seqLens[rowIdx / next_n];
+  int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
+
+  topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
+      logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
+}
+
 }  // namespace vllm
 
 void apply_repetition_penalties_(
@@ -303,6 +337,20 @@ void apply_repetition_penalties_(
   });
 }
 
+void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
+                          const torch::Tensor& seqLens, torch::Tensor& indices,
+                          int64_t numRows, int64_t stride0, int64_t stride1) {
+  // Compute the results on the device.
+  constexpr int kNumThreadsPerBlock = 512;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  vllm::topKPerRowDecode<kNumThreadsPerBlock>
+      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
+          logits.data_ptr<float>(), seqLens.data_ptr<int>(),
+          indices.data_ptr<int>(), static_cast<int>(stride0),
+          static_cast<int>(stride1), static_cast<int>(next_n));
+}
+
 void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                    const torch::Tensor& rowEnds, torch::Tensor& indices,
                    int64_t numRows, int64_t stride0, int64_t stride1) {
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index c710d8ef6537..7e8660349dad 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -189,6 +189,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "int stride1) -> ()");
   ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
 
+  ops.def(
+      "top_k_per_row_decode(Tensor logits, int next_n, "
+      "Tensor seq_lens, Tensor! indices, int numRows, "
+      "int stride0, int stride1) -> ()");
+  ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
+
   // Layernorm-quant
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
diff --git a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py
index dc64b0499e68..cadda27b49e9 100644
--- a/tests/kernels/test_top_k_per_row.py
+++ b/tests/kernels/test_top_k_per_row.py
@@ -10,6 +10,8 @@ from vllm.platforms import current_platform
 # Test parameters
 NUM_ROWS = [1, 32, 2050]
 TOP_K_VALUES = [2048]
+BATCH_SIZE = [1, 2, 4, 2048, 4096]
+NEXT_N = [1, 2, 4, 8]
 
 
 def create_random_logits(
@@ -114,7 +116,7 @@ def test_top_k_per_row(
     logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
 
     # Create output tensors
-    indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
+    indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
 
     # Run CUDA implementation
     torch.ops._C.top_k_per_row(
@@ -138,3 +140,59 @@ def test_top_k_per_row(
     assert compare_top_k_results(
         logits, indices, torch_indices, row_starts, row_ends, top_k
     ), "CUDA top_k_per_row results don't match torch.topk"
+
+
+@pytest.mark.parametrize("top_k", TOP_K_VALUES)
+@pytest.mark.parametrize("batch_size", BATCH_SIZE)
+@pytest.mark.parametrize("next_n", NEXT_N)
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
+@torch.inference_mode()
+def test_top_k_per_row_decode(
+    top_k: int,
+    batch_size: int,
+    next_n: int,
+) -> None:
+    """
+    Test top_k_per_row_decode with seq_lens tensor.
+    """
+    torch.set_default_device("cuda:0")
+
+    # Create test data
+    num_rows = batch_size * next_n
+    vocab_size = 20000
+    seq_lens = torch.randint(
+        vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
+    )
+    row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
+    row_indices = torch.arange(num_rows, device="cuda") // next_n
+    next_n_offset = torch.arange(num_rows, device="cuda") % next_n
+    row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
+    logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
+
+    # Create output tensors
+    indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
+
+    # Run CUDA implementation
+    torch.ops._C.top_k_per_row_decode(
+        logits,
+        next_n,
+        seq_lens,
+        indices,
+        num_rows,
+        logits.stride(0),
+        logits.stride(1),
+    )
+
+    torch.cuda.synchronize()
+
+    # Run reference implementation
+    torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1]
+    mask_lo = torch_indices >= 0
+    mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
+    mask = mask_lo & mask_hi
+    torch_indices = torch_indices.masked_fill(~mask, -1)
+
+    # Compare results
+    assert compare_top_k_results(
+        logits, indices, torch_indices, row_starts, row_ends, top_k
+    ), "CUDA top_k_per_row_decode results don't match torch.topk"
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 6e287e087c0e..db7b86ffaf96 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -580,9 +580,9 @@ def sparse_attn_indexer(
         )
         num_rows = logits.shape[0]
         assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
-        topk_indices = torch.empty(
-            num_rows, topk_tokens, dtype=torch.int32, device=logits.device
-        )
+        topk_indices = topk_indices_buffer[
+            chunk.token_start : chunk.token_end, :topk_tokens
+        ]
         torch.ops._C.top_k_per_row(
             logits,
             chunk.cu_seqlen_ks,
@@ -592,9 +592,6 @@
             logits.stride(0),
             logits.stride(1),
         )
-        topk_indices_buffer[
-            chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
-        ] = topk_indices.to(dtype=torch.int32)
 
     if has_decode:
         decode_metadata = attn_metadata.decode
@@ -628,26 +625,14 @@
             decode_metadata.schedule_metadata,
             max_model_len=max_model_len,
         )
-        # padded query len
-        current_device = padded_q_fp8_decode_tokens.device
-        padded_num_tokens = batch_size * next_n
-        row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n
-        next_n_offset = (
-            torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device)
-            % next_n
-        )
-        index_end_pos = (
-            decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1
-        ).unsqueeze(1)
         num_rows = logits.shape[0]
         assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
-        topk_indices = torch.empty(
-            num_rows, topk_tokens, dtype=torch.int32, device=logits.device
-        )
-        torch.ops._C.top_k_per_row(
+        topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
+
+        torch.ops._C.top_k_per_row_decode(
             logits,
-            torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
-            index_end_pos.to(dtype=torch.int32, device=logits.device),
+            next_n,
+            decode_metadata.seq_lens,
             topk_indices,
             num_rows,
             logits.stride(0),
             logits.stride(1),
@@ -660,9 +645,9 @@
                 topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                 decode_lens,
             )
-        topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
-            topk_indices.to(dtype=torch.int32)
-        )
+        topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
+            topk_indices
+        )
 
     return topk_indices_buffer
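
A note on the histogram phase that topKPerRowJob inherits unchanged: extractBinIdx (visible in the context lines of the first sampler.cu hunk) buckets each logit into one of kNumBins = 512 bins so that bin 0 receives the largest values, by remapping the value's fp16 bit pattern into an order-preserving unsigned integer and keeping its top 9 bits (the `511 - (tmp.u16 >> 7)` above). Below is a minimal sketch of the same trick applied to fp32 bits; the fp16 rounding detail is elided, and the helper names are illustrative rather than part of the PR.

import struct

def float_to_sortable_u32(x: float) -> int:
    """Order-preserving map: a bigger float yields a bigger uint32."""
    (u,) = struct.unpack("<I", struct.pack("<f", x))
    # Negative floats: flip every bit; non-negative: set the sign bit.
    return (~u & 0xFFFFFFFF) if (u & 0x80000000) else (u | 0x80000000)

def bin_idx(x: float, num_bins: int = 512) -> int:
    """Keep the top 9 bits, reversed so bin 0 holds the largest logits."""
    return (num_bins - 1) - (float_to_sortable_u32(x) >> 23)

assert bin_idx(100.0) < bin_idx(1.0) < bin_idx(-1.0) < bin_idx(-100.0)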
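
The decode variant drops the explicit rowStarts/rowEnds tensors: block rowIdx serves speculative position rowIdx % next_n of request rowIdx / next_n, and scans the logits in [0, seq_len - next_n + (rowIdx % next_n) + 1). A host-side sketch of that on-device indexing (the helper name is mine), which is also exactly how the new test constructs its row_ends:

import torch

def decode_row_ranges(seq_lens: torch.Tensor, next_n: int):
    """Mirror topKPerRowDecode's per-row range computation on the host."""
    num_rows = seq_lens.numel() * next_n
    rows = torch.arange(num_rows, device=seq_lens.device)
    row_starts = torch.zeros(num_rows, dtype=torch.int32, device=seq_lens.device)
    row_ends = (seq_lens[rows // next_n] - next_n + (rows % next_n) + 1).to(torch.int32)
    return row_starts, row_ends

# Two requests with next_n=2 draft positions each: every later draft
# position gets to see one more token of its request's sequence.
starts, ends = decode_row_ranges(torch.tensor([10, 7], dtype=torch.int32), next_n=2)
assert ends.tolist() == [9, 10, 6, 7]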
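
For cross-checking the new op from Python, a torch.topk emulation along the lines of the test's reference path can look as follows. This is a sketch under the -1 padding convention the test's masking assumes, not the kernel's actual algorithm: tie order may differ, and the function name is made up.

import torch

def top_k_per_row_decode_ref(logits: torch.Tensor, seq_lens: torch.Tensor,
                             next_n: int, top_k: int = 2048) -> torch.Tensor:
    num_rows, vocab = logits.shape
    rows = torch.arange(num_rows, device=logits.device)
    row_ends = seq_lens[rows // next_n] - next_n + (rows % next_n) + 1
    # Positions at or past each row's end can never be selected.
    cols = torch.arange(vocab, device=logits.device)
    masked = logits.masked_fill(cols[None, :] >= row_ends[:, None], float("-inf"))
    indices = masked.topk(min(top_k, vocab), dim=-1).indices.to(torch.int32)
    # Winners drawn from the masked tail (rows with fewer than k valid
    # logits) are replaced by the -1 sentinel.
    indices[indices >= row_ends[:, None]] = -1
    return indices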
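
The deepseek_v2.py side of the change relies on basic slicing returning a view: topk_indices now aliases a slab of topk_indices_buffer, so both kernels write their results directly into the persistent buffer. The torch.empty temporaries, the unconditional copy-back assignments, and the host-built index_end_pos tensors drop out of the hot path; only the pack_seq_triton branch still copies, since packing produces a new tensor. A tiny illustration of that aliasing, with made-up shapes:

import torch

topk_indices_buffer = torch.full((8, 2048), -1, dtype=torch.int32)
topk_indices = topk_indices_buffer[:4, :2048]  # a view: no allocation, no copy
topk_indices.fill_(7)  # stand-in for the kernel writing through the view
assert topk_indices_buffer[:4].eq(7).all() and topk_indices_buffer[4:].eq(-1).all()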