mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 12:22:17 +08:00
[Deepseek v3.2] Optimize top_k_per_row (#26763)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
parent
c3a2c6ac5f
commit
80e9452984
@ -99,8 +99,7 @@ void apply_repetition_penalties_(torch::Tensor& logits,
|
|||||||
|
|
||||||
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
|
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
|
||||||
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
||||||
torch::Tensor& values, int64_t numRows, int64_t stride0,
|
int64_t numRows, int64_t stride0, int64_t stride1);
|
||||||
int64_t stride1);
|
|
||||||
|
|
||||||
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||||
torch::Tensor& weight, torch::Tensor& scale,
|
torch::Tensor& weight, torch::Tensor& scale,
|
||||||
|
|||||||
@ -57,7 +57,7 @@ static inline __device__ uint16_t extractBinIdx(float x) {
|
|||||||
template <int kNumThreadsPerBlock = 512>
|
template <int kNumThreadsPerBlock = 512>
|
||||||
static __global__ void topKPerRow(const float* logits, const int* rowStarts,
|
static __global__ void topKPerRow(const float* logits, const int* rowStarts,
|
||||||
const int* rowEnds, int* outIndices,
|
const int* rowEnds, int* outIndices,
|
||||||
float* outLogits, int stride0, int stride1) {
|
int stride0, int stride1) {
|
||||||
// The number of bins in the histogram.
|
// The number of bins in the histogram.
|
||||||
static constexpr int kNumBins = 512;
|
static constexpr int kNumBins = 512;
|
||||||
|
|
||||||
@ -103,8 +103,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
|
|||||||
__shared__ int smemHistogram[kNumBins];
|
__shared__ int smemHistogram[kNumBins];
|
||||||
// Shared memory to store the selected indices.
|
// Shared memory to store the selected indices.
|
||||||
__shared__ int smemIndices[kTopK];
|
__shared__ int smemIndices[kTopK];
|
||||||
// Shared memory to store the selected logits.
|
|
||||||
__shared__ float smemLogits[kTopK];
|
|
||||||
// Shared memory to store the threshold bin.
|
// Shared memory to store the threshold bin.
|
||||||
__shared__ int smemThresholdBinIdx[1];
|
__shared__ int smemThresholdBinIdx[1];
|
||||||
// Shared memory counter to register the candidates for the final phase.
|
// Shared memory counter to register the candidates for the final phase.
|
||||||
@ -124,13 +122,10 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
|
|||||||
rowIt += kNumThreadsPerBlock) {
|
rowIt += kNumThreadsPerBlock) {
|
||||||
int idx = rowStart + rowIt;
|
int idx = rowStart + rowIt;
|
||||||
outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
|
outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
|
||||||
outLogits[rowIdx * kTopK + rowIt] =
|
|
||||||
logits[rowIdx * stride0 + idx * stride1];
|
|
||||||
}
|
}
|
||||||
for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
|
for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
|
||||||
rowIt += kNumThreadsPerBlock) {
|
rowIt += kNumThreadsPerBlock) {
|
||||||
outIndices[rowIdx * kTopK + rowIt] = -1;
|
outIndices[rowIdx * kTopK + rowIt] = -1;
|
||||||
outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX;
|
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -201,7 +196,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
|
|||||||
uint16_t idx = extractBinIdx(logit);
|
uint16_t idx = extractBinIdx(logit);
|
||||||
if (idx < thresholdBinIdx) {
|
if (idx < thresholdBinIdx) {
|
||||||
int dstIdx = atomicAdd(&smemHistogram[idx], 1);
|
int dstIdx = atomicAdd(&smemHistogram[idx], 1);
|
||||||
smemLogits[dstIdx] = logit;
|
|
||||||
smemIndices[dstIdx] = rowIt;
|
smemIndices[dstIdx] = rowIt;
|
||||||
} else if (idx == thresholdBinIdx) {
|
} else if (idx == thresholdBinIdx) {
|
||||||
int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
|
int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
|
||||||
@ -250,7 +244,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
|
|||||||
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
|
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
|
||||||
int dstIdx = baseIdx + srcIdx;
|
int dstIdx = baseIdx + srcIdx;
|
||||||
if (dstIdx < kTopK) {
|
if (dstIdx < kTopK) {
|
||||||
smemLogits[dstIdx] = finalLogits[ii];
|
|
||||||
smemIndices[dstIdx] = finalIndices[ii];
|
smemIndices[dstIdx] = finalIndices[ii];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -258,28 +251,12 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
|
|||||||
// Make sure the data is in shared memory.
|
// Make sure the data is in shared memory.
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// The topK logits.
|
|
||||||
float topKLogits[kNumTopKItemsPerThread];
|
|
||||||
// The topK indices.
|
|
||||||
int topKIndices[kNumTopKItemsPerThread];
|
|
||||||
|
|
||||||
// Load from shared memory.
|
|
||||||
#pragma unroll
|
|
||||||
for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
|
|
||||||
topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x];
|
|
||||||
topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort the elements.
|
|
||||||
TopKSort(smemFinal.topKSort)
|
|
||||||
.SortDescendingBlockedToStriped(topKLogits, topKIndices);
|
|
||||||
|
|
||||||
// Store to global memory.
|
// Store to global memory.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
|
for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
|
||||||
int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
|
int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
|
||||||
outIndices[offset] = topKIndices[ii] - rowStart;
|
outIndices[offset] =
|
||||||
outLogits[offset] = topKLogits[ii];
|
smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -328,8 +305,7 @@ void apply_repetition_penalties_(
|
|||||||
|
|
||||||
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
|
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
|
||||||
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
||||||
torch::Tensor& values, int64_t numRows, int64_t stride0,
|
int64_t numRows, int64_t stride0, int64_t stride1) {
|
||||||
int64_t stride1) {
|
|
||||||
// Compute the results on the device.
|
// Compute the results on the device.
|
||||||
constexpr int kNumThreadsPerBlock = 512;
|
constexpr int kNumThreadsPerBlock = 512;
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
@ -338,6 +314,5 @@ void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
|
|||||||
<<<numRows, kNumThreadsPerBlock, 0, stream>>>(
|
<<<numRows, kNumThreadsPerBlock, 0, stream>>>(
|
||||||
logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
|
logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
|
||||||
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
|
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
|
||||||
values.data_ptr<float>(), static_cast<int>(stride0),
|
static_cast<int>(stride0), static_cast<int>(stride1));
|
||||||
static_cast<int>(stride1));
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -185,7 +185,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
// Optimized top-k per row operation
|
// Optimized top-k per row operation
|
||||||
ops.def(
|
ops.def(
|
||||||
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
||||||
"Tensor! indices, Tensor! values, int numRows, int stride0, "
|
"Tensor! indices, int numRows, int stride0, "
|
||||||
"int stride1) -> ()");
|
"int stride1) -> ()");
|
||||||
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
|
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
|
||||||
|
|
||||||
|
|||||||
@ -39,10 +39,9 @@ def create_row_boundaries(
|
|||||||
|
|
||||||
|
|
||||||
def compare_top_k_results(
|
def compare_top_k_results(
|
||||||
|
logits: torch.Tensor,
|
||||||
cuda_indices: torch.Tensor,
|
cuda_indices: torch.Tensor,
|
||||||
cuda_values: torch.Tensor,
|
|
||||||
torch_indices: torch.Tensor,
|
torch_indices: torch.Tensor,
|
||||||
torch_values: torch.Tensor,
|
|
||||||
row_starts: torch.Tensor,
|
row_starts: torch.Tensor,
|
||||||
row_ends: torch.Tensor,
|
row_ends: torch.Tensor,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
@ -70,8 +69,9 @@ def compare_top_k_results(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Any difference in elements, compare the values
|
# Any difference in elements, compare the values
|
||||||
cuda_row_values = cuda_values[row_idx][:num_valid].cpu()
|
logits_row = logits[row_idx]
|
||||||
torch_row_values = torch_values[row_idx][:num_valid].cpu()
|
cuda_row_values = [logits_row[i] for i in cuda_row_indices]
|
||||||
|
torch_row_values = [logits_row[i] for i in torch_row_indices]
|
||||||
|
|
||||||
cuda_only_values, torch_only_values = [], []
|
cuda_only_values, torch_only_values = [], []
|
||||||
for idx in cuda_set - torch_set:
|
for idx in cuda_set - torch_set:
|
||||||
@ -115,7 +115,6 @@ def test_top_k_per_row(
|
|||||||
|
|
||||||
# Create output tensors
|
# Create output tensors
|
||||||
indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
|
indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
|
||||||
values = torch.empty((num_rows, 2048), dtype=torch.float32, device="cuda")
|
|
||||||
|
|
||||||
# Run CUDA implementation
|
# Run CUDA implementation
|
||||||
torch.ops._C.top_k_per_row(
|
torch.ops._C.top_k_per_row(
|
||||||
@ -123,14 +122,13 @@ def test_top_k_per_row(
|
|||||||
row_starts,
|
row_starts,
|
||||||
row_ends,
|
row_ends,
|
||||||
indices,
|
indices,
|
||||||
values,
|
|
||||||
num_rows,
|
num_rows,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run reference implementation
|
# Run reference implementation
|
||||||
torch_values, torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)
|
torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1]
|
||||||
mask_lo = torch_indices >= 0
|
mask_lo = torch_indices >= 0
|
||||||
mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
|
mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
|
||||||
mask = mask_lo & mask_hi
|
mask = mask_lo & mask_hi
|
||||||
@ -138,5 +136,5 @@ def test_top_k_per_row(
|
|||||||
|
|
||||||
# Compare results
|
# Compare results
|
||||||
assert compare_top_k_results(
|
assert compare_top_k_results(
|
||||||
indices, values, torch_indices, torch_values, row_starts, row_ends, top_k
|
logits, indices, torch_indices, row_starts, row_ends, top_k
|
||||||
), "CUDA top_k_per_row results don't match torch.topk"
|
), "CUDA top_k_per_row results don't match torch.topk"
|
||||||
|
|||||||
@ -577,15 +577,11 @@ def sparse_attn_indexer(
|
|||||||
topk_indices = torch.empty(
|
topk_indices = torch.empty(
|
||||||
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
|
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
|
||||||
)
|
)
|
||||||
topk_values = torch.empty(
|
|
||||||
num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
|
|
||||||
)
|
|
||||||
torch.ops._C.top_k_per_row(
|
torch.ops._C.top_k_per_row(
|
||||||
logits,
|
logits,
|
||||||
chunk.cu_seqlen_ks,
|
chunk.cu_seqlen_ks,
|
||||||
chunk.cu_seqlen_ke,
|
chunk.cu_seqlen_ke,
|
||||||
topk_indices,
|
topk_indices,
|
||||||
topk_values,
|
|
||||||
num_rows,
|
num_rows,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
@ -642,15 +638,11 @@ def sparse_attn_indexer(
|
|||||||
topk_indices = torch.empty(
|
topk_indices = torch.empty(
|
||||||
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
|
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
|
||||||
)
|
)
|
||||||
topk_values = torch.empty(
|
|
||||||
num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
|
|
||||||
)
|
|
||||||
torch.ops._C.top_k_per_row(
|
torch.ops._C.top_k_per_row(
|
||||||
logits,
|
logits,
|
||||||
torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
|
torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
|
||||||
index_end_pos.to(dtype=torch.int32, device=logits.device),
|
index_end_pos.to(dtype=torch.int32, device=logits.device),
|
||||||
topk_indices,
|
topk_indices,
|
||||||
topk_values,
|
|
||||||
num_rows,
|
num_rows,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user