[Deepseek v3.2] Optimize top_k_per_row (#26763)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
Daniel Cámpora 2025-10-21 10:30:07 +02:00 committed by GitHub
parent c3a2c6ac5f
commit 80e9452984
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 13 additions and 49 deletions

View File

@ -99,8 +99,7 @@ void apply_repetition_penalties_(torch::Tensor& logits,
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices, const torch::Tensor& rowEnds, torch::Tensor& indices,
torch::Tensor& values, int64_t numRows, int64_t stride0, int64_t numRows, int64_t stride0, int64_t stride1);
int64_t stride1);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale, torch::Tensor& weight, torch::Tensor& scale,

View File

@ -57,7 +57,7 @@ static inline __device__ uint16_t extractBinIdx(float x) {
template <int kNumThreadsPerBlock = 512> template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRow(const float* logits, const int* rowStarts, static __global__ void topKPerRow(const float* logits, const int* rowStarts,
const int* rowEnds, int* outIndices, const int* rowEnds, int* outIndices,
float* outLogits, int stride0, int stride1) { int stride0, int stride1) {
// The number of bins in the histogram. // The number of bins in the histogram.
static constexpr int kNumBins = 512; static constexpr int kNumBins = 512;
@ -103,8 +103,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
__shared__ int smemHistogram[kNumBins]; __shared__ int smemHistogram[kNumBins];
// Shared memory to store the selected indices. // Shared memory to store the selected indices.
__shared__ int smemIndices[kTopK]; __shared__ int smemIndices[kTopK];
// Shared memory to store the selected logits.
__shared__ float smemLogits[kTopK];
// Shared memory to store the threshold bin. // Shared memory to store the threshold bin.
__shared__ int smemThresholdBinIdx[1]; __shared__ int smemThresholdBinIdx[1];
// Shared memory counter to register the candidates for the final phase. // Shared memory counter to register the candidates for the final phase.
@ -124,13 +122,10 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
rowIt += kNumThreadsPerBlock) { rowIt += kNumThreadsPerBlock) {
int idx = rowStart + rowIt; int idx = rowStart + rowIt;
outIndices[rowIdx * kTopK + rowIt] = idx - rowStart; outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
outLogits[rowIdx * kTopK + rowIt] =
logits[rowIdx * stride0 + idx * stride1];
} }
for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK; for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
rowIt += kNumThreadsPerBlock) { rowIt += kNumThreadsPerBlock) {
outIndices[rowIdx * kTopK + rowIt] = -1; outIndices[rowIdx * kTopK + rowIt] = -1;
outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX;
} }
return; return;
} }
@ -201,7 +196,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
uint16_t idx = extractBinIdx(logit); uint16_t idx = extractBinIdx(logit);
if (idx < thresholdBinIdx) { if (idx < thresholdBinIdx) {
int dstIdx = atomicAdd(&smemHistogram[idx], 1); int dstIdx = atomicAdd(&smemHistogram[idx], 1);
smemLogits[dstIdx] = logit;
smemIndices[dstIdx] = rowIt; smemIndices[dstIdx] = rowIt;
} else if (idx == thresholdBinIdx) { } else if (idx == thresholdBinIdx) {
int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1); int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
@ -250,7 +244,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
int dstIdx = baseIdx + srcIdx; int dstIdx = baseIdx + srcIdx;
if (dstIdx < kTopK) { if (dstIdx < kTopK) {
smemLogits[dstIdx] = finalLogits[ii];
smemIndices[dstIdx] = finalIndices[ii]; smemIndices[dstIdx] = finalIndices[ii];
} }
} }
@ -258,28 +251,12 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
// Make sure the data is in shared memory. // Make sure the data is in shared memory.
__syncthreads(); __syncthreads();
// The topK logits.
float topKLogits[kNumTopKItemsPerThread];
// The topK indices.
int topKIndices[kNumTopKItemsPerThread];
// Load from shared memory.
#pragma unroll
for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x];
topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x];
}
// Sort the elements.
TopKSort(smemFinal.topKSort)
.SortDescendingBlockedToStriped(topKLogits, topKIndices);
// Store to global memory. // Store to global memory.
#pragma unroll #pragma unroll
for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) { for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x; int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
outIndices[offset] = topKIndices[ii] - rowStart; outIndices[offset] =
outLogits[offset] = topKLogits[ii]; smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart;
} }
} }
@ -328,8 +305,7 @@ void apply_repetition_penalties_(
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices, const torch::Tensor& rowEnds, torch::Tensor& indices,
torch::Tensor& values, int64_t numRows, int64_t stride0, int64_t numRows, int64_t stride0, int64_t stride1) {
int64_t stride1) {
// Compute the results on the device. // Compute the results on the device.
constexpr int kNumThreadsPerBlock = 512; constexpr int kNumThreadsPerBlock = 512;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -338,6 +314,5 @@ void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
<<<numRows, kNumThreadsPerBlock, 0, stream>>>( <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
logits.data_ptr<float>(), rowStarts.data_ptr<int>(), logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
rowEnds.data_ptr<int>(), indices.data_ptr<int>(), rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
values.data_ptr<float>(), static_cast<int>(stride0), static_cast<int>(stride0), static_cast<int>(stride1));
static_cast<int>(stride1));
} }

View File

@ -185,7 +185,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Optimized top-k per row operation // Optimized top-k per row operation
ops.def( ops.def(
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, " "top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
"Tensor! indices, Tensor! values, int numRows, int stride0, " "Tensor! indices, int numRows, int stride0, "
"int stride1) -> ()"); "int stride1) -> ()");
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row); ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);

View File

@ -39,10 +39,9 @@ def create_row_boundaries(
def compare_top_k_results( def compare_top_k_results(
logits: torch.Tensor,
cuda_indices: torch.Tensor, cuda_indices: torch.Tensor,
cuda_values: torch.Tensor,
torch_indices: torch.Tensor, torch_indices: torch.Tensor,
torch_values: torch.Tensor,
row_starts: torch.Tensor, row_starts: torch.Tensor,
row_ends: torch.Tensor, row_ends: torch.Tensor,
top_k: int, top_k: int,
@ -70,8 +69,9 @@ def compare_top_k_results(
continue continue
# Any difference in elements, compare the values # Any difference in elements, compare the values
cuda_row_values = cuda_values[row_idx][:num_valid].cpu() logits_row = logits[row_idx]
torch_row_values = torch_values[row_idx][:num_valid].cpu() cuda_row_values = [logits_row[i] for i in cuda_row_indices]
torch_row_values = [logits_row[i] for i in torch_row_indices]
cuda_only_values, torch_only_values = [], [] cuda_only_values, torch_only_values = [], []
for idx in cuda_set - torch_set: for idx in cuda_set - torch_set:
@ -115,7 +115,6 @@ def test_top_k_per_row(
# Create output tensors # Create output tensors
indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda") indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
values = torch.empty((num_rows, 2048), dtype=torch.float32, device="cuda")
# Run CUDA implementation # Run CUDA implementation
torch.ops._C.top_k_per_row( torch.ops._C.top_k_per_row(
@ -123,14 +122,13 @@ def test_top_k_per_row(
row_starts, row_starts,
row_ends, row_ends,
indices, indices,
values,
num_rows, num_rows,
logits.stride(0), logits.stride(0),
logits.stride(1), logits.stride(1),
) )
# Run reference implementation # Run reference implementation
torch_values, torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1) torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1]
mask_lo = torch_indices >= 0 mask_lo = torch_indices >= 0
mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0 mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
mask = mask_lo & mask_hi mask = mask_lo & mask_hi
@ -138,5 +136,5 @@ def test_top_k_per_row(
# Compare results # Compare results
assert compare_top_k_results( assert compare_top_k_results(
indices, values, torch_indices, torch_values, row_starts, row_ends, top_k logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk" ), "CUDA top_k_per_row results don't match torch.topk"

View File

@ -577,15 +577,11 @@ def sparse_attn_indexer(
topk_indices = torch.empty( topk_indices = torch.empty(
num_rows, topk_tokens, dtype=torch.int32, device=logits.device num_rows, topk_tokens, dtype=torch.int32, device=logits.device
) )
topk_values = torch.empty(
num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
)
torch.ops._C.top_k_per_row( torch.ops._C.top_k_per_row(
logits, logits,
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
topk_indices, topk_indices,
topk_values,
num_rows, num_rows,
logits.stride(0), logits.stride(0),
logits.stride(1), logits.stride(1),
@ -642,15 +638,11 @@ def sparse_attn_indexer(
topk_indices = torch.empty( topk_indices = torch.empty(
num_rows, topk_tokens, dtype=torch.int32, device=logits.device num_rows, topk_tokens, dtype=torch.int32, device=logits.device
) )
topk_values = torch.empty(
num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
)
torch.ops._C.top_k_per_row( torch.ops._C.top_k_per_row(
logits, logits,
torch.zeros(num_rows, dtype=torch.int32, device=logits.device), torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
index_end_pos.to(dtype=torch.int32, device=logits.device), index_end_pos.to(dtype=torch.int32, device=logits.device),
topk_indices, topk_indices,
topk_values,
num_rows, num_rows,
logits.stride(0), logits.stride(0),
logits.stride(1), logits.stride(1),