From 80e94529845d78284047e52f6253889c534bce13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20C=C3=A1mpora?= <961215+dcampora@users.noreply.github.com> Date: Tue, 21 Oct 2025 10:30:07 +0200 Subject: [PATCH] [Deepseek v3.2] Optimize top_k_per_row (#26763) Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- csrc/ops.h | 3 +- csrc/sampler.cu | 35 ++++------------------- csrc/torch_bindings.cpp | 2 +- tests/kernels/test_top_k_per_row.py | 14 ++++----- vllm/model_executor/models/deepseek_v2.py | 8 ------ 5 files changed, 13 insertions(+), 49 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index c135a1404294f..64abc4922ba69 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -99,8 +99,7 @@ void apply_repetition_penalties_(torch::Tensor& logits, void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, const torch::Tensor& rowEnds, torch::Tensor& indices, - torch::Tensor& values, int64_t numRows, int64_t stride0, - int64_t stride1); + int64_t numRows, int64_t stride0, int64_t stride1); void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale, diff --git a/csrc/sampler.cu b/csrc/sampler.cu index bc589d99d04bf..92c8095c71e2a 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -57,7 +57,7 @@ static inline __device__ uint16_t extractBinIdx(float x) { template static __global__ void topKPerRow(const float* logits, const int* rowStarts, const int* rowEnds, int* outIndices, - float* outLogits, int stride0, int stride1) { + int stride0, int stride1) { // The number of bins in the histogram. static constexpr int kNumBins = 512; @@ -103,8 +103,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts, __shared__ int smemHistogram[kNumBins]; // Shared memory to store the selected indices. __shared__ int smemIndices[kTopK]; - // Shared memory to store the selected logits. - __shared__ float smemLogits[kTopK]; // Shared memory to store the threshold bin. __shared__ int smemThresholdBinIdx[1]; // Shared memory counter to register the candidates for the final phase. @@ -124,13 +122,10 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts, rowIt += kNumThreadsPerBlock) { int idx = rowStart + rowIt; outIndices[rowIdx * kTopK + rowIt] = idx - rowStart; - outLogits[rowIdx * kTopK + rowIt] = - logits[rowIdx * stride0 + idx * stride1]; } for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK; rowIt += kNumThreadsPerBlock) { outIndices[rowIdx * kTopK + rowIt] = -1; - outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX; } return; } @@ -201,7 +196,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts, uint16_t idx = extractBinIdx(logit); if (idx < thresholdBinIdx) { int dstIdx = atomicAdd(&smemHistogram[idx], 1); - smemLogits[dstIdx] = logit; smemIndices[dstIdx] = rowIt; } else if (idx == thresholdBinIdx) { int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1); @@ -250,7 +244,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts, int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; int dstIdx = baseIdx + srcIdx; if (dstIdx < kTopK) { - smemLogits[dstIdx] = finalLogits[ii]; smemIndices[dstIdx] = finalIndices[ii]; } } @@ -258,28 +251,12 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts, // Make sure the data is in shared memory. __syncthreads(); - // The topK logits. - float topKLogits[kNumTopKItemsPerThread]; - // The topK indices. - int topKIndices[kNumTopKItemsPerThread]; - -// Load from shared memory. -#pragma unroll - for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) { - topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x]; - topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x]; - } - - // Sort the elements. - TopKSort(smemFinal.topKSort) - .SortDescendingBlockedToStriped(topKLogits, topKIndices); - // Store to global memory. #pragma unroll for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) { int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x; - outIndices[offset] = topKIndices[ii] - rowStart; - outLogits[offset] = topKLogits[ii]; + outIndices[offset] = + smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart; } } @@ -328,8 +305,7 @@ void apply_repetition_penalties_( void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, const torch::Tensor& rowEnds, torch::Tensor& indices, - torch::Tensor& values, int64_t numRows, int64_t stride0, - int64_t stride1) { + int64_t numRows, int64_t stride0, int64_t stride1) { // Compute the results on the device. constexpr int kNumThreadsPerBlock = 512; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -338,6 +314,5 @@ void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, <<>>( logits.data_ptr(), rowStarts.data_ptr(), rowEnds.data_ptr(), indices.data_ptr(), - values.data_ptr(), static_cast(stride0), - static_cast(stride1)); + static_cast(stride0), static_cast(stride1)); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 2bc526097d153..c710d8ef65372 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -185,7 +185,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Optimized top-k per row operation ops.def( "top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, " - "Tensor! indices, Tensor! values, int numRows, int stride0, " + "Tensor! indices, int numRows, int stride0, " "int stride1) -> ()"); ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row); diff --git a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py index ccef9d7123640..dc64b0499e685 100644 --- a/tests/kernels/test_top_k_per_row.py +++ b/tests/kernels/test_top_k_per_row.py @@ -39,10 +39,9 @@ def create_row_boundaries( def compare_top_k_results( + logits: torch.Tensor, cuda_indices: torch.Tensor, - cuda_values: torch.Tensor, torch_indices: torch.Tensor, - torch_values: torch.Tensor, row_starts: torch.Tensor, row_ends: torch.Tensor, top_k: int, @@ -70,8 +69,9 @@ def compare_top_k_results( continue # Any difference in elements, compare the values - cuda_row_values = cuda_values[row_idx][:num_valid].cpu() - torch_row_values = torch_values[row_idx][:num_valid].cpu() + logits_row = logits[row_idx] + cuda_row_values = [logits_row[i] for i in cuda_row_indices] + torch_row_values = [logits_row[i] for i in torch_row_indices] cuda_only_values, torch_only_values = [], [] for idx in cuda_set - torch_set: @@ -115,7 +115,6 @@ def test_top_k_per_row( # Create output tensors indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda") - values = torch.empty((num_rows, 2048), dtype=torch.float32, device="cuda") # Run CUDA implementation torch.ops._C.top_k_per_row( @@ -123,14 +122,13 @@ def test_top_k_per_row( row_starts, row_ends, indices, - values, num_rows, logits.stride(0), logits.stride(1), ) # Run reference implementation - torch_values, torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1) + torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1] mask_lo = torch_indices >= 0 mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0 mask = mask_lo & mask_hi @@ -138,5 +136,5 @@ def test_top_k_per_row( # Compare results assert compare_top_k_results( - indices, values, torch_indices, torch_values, row_starts, row_ends, top_k + logits, indices, torch_indices, row_starts, row_ends, top_k ), "CUDA top_k_per_row results don't match torch.topk" diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 5827e606b4a5e..cdaa26441af31 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -577,15 +577,11 @@ def sparse_attn_indexer( topk_indices = torch.empty( num_rows, topk_tokens, dtype=torch.int32, device=logits.device ) - topk_values = torch.empty( - num_rows, topk_tokens, dtype=logits.dtype, device=logits.device - ) torch.ops._C.top_k_per_row( logits, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, topk_indices, - topk_values, num_rows, logits.stride(0), logits.stride(1), @@ -642,15 +638,11 @@ def sparse_attn_indexer( topk_indices = torch.empty( num_rows, topk_tokens, dtype=torch.int32, device=logits.device ) - topk_values = torch.empty( - num_rows, topk_tokens, dtype=logits.dtype, device=logits.device - ) torch.ops._C.top_k_per_row( logits, torch.zeros(num_rows, dtype=torch.int32, device=logits.device), index_end_pos.to(dtype=torch.int32, device=logits.device), topk_indices, - topk_values, num_rows, logits.stride(0), logits.stride(1),