mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 20:45:15 +08:00
Add topk logits torch op for DS3.2. (#25945)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> Signed-off-by: Daniel Cámpora <961215+dcampora@users.noreply.github.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
d100d78eb3
commit
e1098ced95
@ -100,6 +100,11 @@ void apply_repetition_penalties_(torch::Tensor& logits,
|
|||||||
const torch::Tensor& output_mask,
|
const torch::Tensor& output_mask,
|
||||||
const torch::Tensor& repetition_penalties);
|
const torch::Tensor& repetition_penalties);
|
||||||
|
|
||||||
|
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
|
||||||
|
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
||||||
|
torch::Tensor& values, int64_t numRows, int64_t stride0,
|
||||||
|
int64_t stride1);
|
||||||
|
|
||||||
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||||
torch::Tensor& weight, torch::Tensor& scale,
|
torch::Tensor& weight, torch::Tensor& scale,
|
||||||
double epsilon);
|
double epsilon);
|
||||||
|
|||||||
257
csrc/sampler.cu
257
csrc/sampler.cu
@ -44,6 +44,245 @@ __global__ void apply_repetition_penalties_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Maps a logit to a histogram bin index in [0, 511].
//
// The value is rounded to half precision and its 16-bit pattern is turned into
// an order-preserving unsigned key; the 9 most significant bits of that key
// select the bin. The result is reversed (511 - ...) so that larger logits
// fall into lower-numbered bins.
static inline __device__ uint16_t extractBinIdx(float x) {
  // Raw bit pattern of x rounded to the nearest half-precision value.
  const uint16_t halfBits = __half_as_ushort(__float2half_rn(x));

  // Negative inputs get every bit flipped, all other inputs get the sign bit
  // set, so that the unsigned key grows with the value of x.
  uint16_t key;
  if (x < 0.f) {
    key = static_cast<uint16_t>(~halfBits);
  } else {
    key = static_cast<uint16_t>(halfBits | 0x8000);
  }

  // Keep the top 9 bits and reverse the bin order.
  return static_cast<uint16_t>(511 - (key >> 7));
}
|
||||||
|
|
||||||
|
// Selects the kTopK (= 2048) largest logits of one row per thread block.
//
// Launch layout: one block per row (the row is blockIdx.x) with exactly
// kNumThreadsPerBlock threads; only static shared memory is used.
//
// Inputs:
//   logits      element (row, col) is read at logits[row * stride0 + col * stride1]
//   rowStarts   first column (inclusive) to consider for each row
//   rowEnds     last column (exclusive) to consider for each row
// Outputs (both addressed as row * kTopK + slot):
//   outIndices  selected column indices, relative to the row's start column
//   outLogits   the corresponding logits
//
// Rows with at most kTopK candidates are copied in column order (not sorted)
// and padded with index -1 / logit -FLT_MAX. Longer rows are processed with a
// 512-bin histogram on a half-precision key (see extractBinIdx), followed by a
// sort of the threshold bin and a final descending sort of the kTopK winners.
//
// Limitation: at most kNumFinalItems candidates of the threshold bin are kept;
// if more elements fall into that bin the extra ones are dropped.
template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRow(const float* logits, const int* rowStarts,
                                  const int* rowEnds, int* outIndices,
                                  float* outLogits, int stride0, int stride1) {
  // The number of bins in the histogram.
  static constexpr int kNumBins = 512;

  // The top-k width.
  static constexpr int kTopK = 2048;
  // The number of elements per thread for the final top-k sort.
  static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
  // The class to sort the elements during the final top-k sort.
  using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
                                       kNumTopKItemsPerThread, int>;

  // The number of slots for the final pass.
  static constexpr int kNumFinalItems = 3072;
  // The number of elements per thread for the final sort.
  static constexpr int kNumFinalItemsPerThread =
      kNumFinalItems / kNumThreadsPerBlock;
  // The class to sort the elements during the final pass.
  using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
                                        kNumFinalItemsPerThread, int>;

  // Each histogram bin is owned by one thread, and the per-thread tiles below
  // must cover the shared buffers exactly.
  static_assert(kNumThreadsPerBlock >= kNumBins,
                "topKPerRow needs at least one thread per histogram bin");
  static_assert(kTopK % kNumThreadsPerBlock == 0,
                "kTopK must be a multiple of the block size");
  static_assert(kNumFinalItems % kNumThreadsPerBlock == 0,
                "kNumFinalItems must be a multiple of the block size");

  // The class to compute the prefix-sum over the histogram.
  using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;

  // Shared memory to compute the block scan.
  __shared__ typename Scan::TempStorage smemScan;

  // The structure to store the final items (for the final pass).
  struct FinalItems {
    // Shared memory to store the indices for the final pass.
    int indices[kNumFinalItems];
    // Shared memory to store the logits for the final pass.
    float logits[kNumFinalItems];
  };

  // Shared memory to compute the block sort. The candidate buffers and the
  // two sort scratch areas are never live at the same time, so they share
  // storage; every hand-over is separated by a __syncthreads().
  __shared__ union {
    FinalItems items;
    typename FinalSort::TempStorage finalSort;
    typename TopKSort::TempStorage topKSort;
  } smemFinal;

  // Shared memory to store the histogram.
  __shared__ int smemHistogram[kNumBins];
  // Shared memory to store the selected indices.
  __shared__ int smemIndices[kTopK];
  // Shared memory to store the selected logits.
  __shared__ float smemLogits[kTopK];
  // Shared memory to store the threshold bin.
  __shared__ int smemThresholdBinIdx[1];
  // Shared memory counter to register the candidates for the final phase.
  __shared__ int smemFinalDstIdx[1];

  // The row computed by this block.
  const int rowIdx = blockIdx.x;
  // The range of logits within the row.
  const int rowStart = rowStarts[rowIdx];
  const int rowEnd = rowEnds[rowIdx];
  // The length of the row. An inverted range is treated as empty; otherwise
  // the padding loop below would start at a negative output offset.
  const int rowLen = rowEnd > rowStart ? rowEnd - rowStart : 0;

  // 64-bit base offsets: row * stride products can exceed the int range.
  const int64_t logitsRowOffset = static_cast<int64_t>(rowIdx) * stride0;
  const int64_t outRowOffset = static_cast<int64_t>(rowIdx) * kTopK;

  // Shortcut if the length of the row is smaller than Top-K. Indices are not
  // sorted by their corresponding logit.
  if (rowLen <= kTopK) {
    for (int rowIt = threadIdx.x; rowIt < rowLen;
         rowIt += kNumThreadsPerBlock) {
      const int idx = rowStart + rowIt;
      outIndices[outRowOffset + rowIt] = idx - rowStart;
      outLogits[outRowOffset + rowIt] =
          logits[logitsRowOffset + static_cast<int64_t>(idx) * stride1];
    }
    // Pad the unused tail of the output row.
    for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
         rowIt += kNumThreadsPerBlock) {
      outIndices[outRowOffset + rowIt] = -1;
      outLogits[outRowOffset + rowIt] = -FLT_MAX;
    }
    // rowLen is uniform across the block, so no thread is left behind at a
    // later barrier.
    return;
  }

  // Clear the histogram.
  if (threadIdx.x < kNumBins) {
    smemHistogram[threadIdx.x] = 0;
  }

  // Make sure the histogram is ready.
  __syncthreads();

  // Fetch elements one-by-one and count them per bin.
  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
       rowIt += kNumThreadsPerBlock) {
    const uint16_t idx = extractBinIdx(
        logits[logitsRowOffset + static_cast<int64_t>(rowIt) * stride1]);
    atomicAdd(&smemHistogram[idx], 1);
  }

  // Make sure the histogram is ready.
  __syncthreads();

  // Read the values from SMEM.
  int binCount{0};
  if (threadIdx.x < kNumBins) {
    binCount = smemHistogram[threadIdx.x];
  }

  // Make sure each thread has read its value.
  __syncthreads();

  // Compute the prefix sum.
  int prefixSum{0}, totalSum{0};
  Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);

  // Update the histogram with the prefix sums. After this, smemHistogram[b]
  // is the number of elements in bins [0, b), i.e. the first output slot of
  // bin b.
  if (threadIdx.x < kNumBins) {
    smemHistogram[threadIdx.x] = prefixSum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // Find the last valid bin: the one in which the running count crosses kTopK.
  if (threadIdx.x < kNumBins) {
    const int nextPrefixSum =
        threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
    if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
      smemThresholdBinIdx[0] = threadIdx.x;
    }
  }

  // Clear the counter to store the items for the final phase.
  if (threadIdx.x == 0) {
    smemFinalDstIdx[0] = 0;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The threshold bin.
  const int thresholdBinIdx = smemThresholdBinIdx[0];

  // Fetch elements one-by-one and populate the shared memory buffers.
  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
       rowIt += kNumThreadsPerBlock) {
    const float logit =
        logits[logitsRowOffset + static_cast<int64_t>(rowIt) * stride1];
    const uint16_t idx = extractBinIdx(logit);
    if (idx < thresholdBinIdx) {
      // Bins before the threshold are selected as a whole; the per-bin prefix
      // sum doubles as the write cursor of the bin.
      const int dstIdx = atomicAdd(&smemHistogram[idx], 1);
      smemLogits[dstIdx] = logit;
      smemIndices[dstIdx] = rowIt;
    } else if (idx == thresholdBinIdx) {
      // Elements of the threshold bin become candidates for the final pass.
      // Candidates beyond the buffer capacity are dropped.
      const int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
      if (dstIdx < kNumFinalItems) {
        smemFinal.items.logits[dstIdx] = logit;
        smemFinal.items.indices[dstIdx] = rowIt;
      }
    }
  }

  // Make sure the elements are in shared memory.
  __syncthreads();

  // The number of candidates registered for the final pass.
  const int numFinalCandidates = smemFinalDstIdx[0];

  // The logits of the elements to be sorted in the final pass.
  float finalLogits[kNumFinalItemsPerThread];
  // The indices of the elements to be sorted in the final pass.
  int finalIndices[kNumFinalItemsPerThread];

  // Init. Padding slots carry index -1 so that a padding entry that survives
  // the sort (it ranks above real -inf logits) never exposes an
  // uninitialized index.
#pragma unroll
  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
    finalLogits[ii] = -FLT_MAX;
    finalIndices[ii] = -1;
  }

  // Read the elements from SMEM.
#pragma unroll
  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
    const int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
    if (srcIdx < numFinalCandidates) {
      finalLogits[ii] = smemFinal.items.logits[srcIdx];
      finalIndices[ii] = smemFinal.items.indices[srcIdx];
    }
  }

  // Make sure the shared memory has been read (the sort reuses it).
  __syncthreads();

  // Sort the elements.
  FinalSort(smemFinal.finalSort)
      .SortDescendingBlockedToStriped(finalLogits, finalIndices);

  // Copy the data back to the shared memory storage. The write cursor of the
  // bin preceding the threshold bin now points just past that bin's last
  // element, which is the first free slot of the selected buffer.
  const int baseIdx =
      thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
#pragma unroll
  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
    const int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
    const int dstIdx = baseIdx + srcIdx;
    if (dstIdx < kTopK) {
      smemLogits[dstIdx] = finalLogits[ii];
      smemIndices[dstIdx] = finalIndices[ii];
    }
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The topK logits.
  float topKLogits[kNumTopKItemsPerThread];
  // The topK indices.
  int topKIndices[kNumTopKItemsPerThread];

  // Load from shared memory.
#pragma unroll
  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
    topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x];
    topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x];
  }

  // Sort the elements.
  TopKSort(smemFinal.topKSort)
      .SortDescendingBlockedToStriped(topKLogits, topKIndices);

  // Store to global memory.
#pragma unroll
  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
    const int64_t offset =
        outRowOffset + ii * kNumThreadsPerBlock + threadIdx.x;
    // Real entries are made relative to the row start; padding stays -1, the
    // same sentinel the short-row path uses.
    outIndices[offset] = topKIndices[ii] >= 0 ? topKIndices[ii] - rowStart : -1;
    outLogits[offset] = topKLogits[ii];
  }
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void apply_repetition_penalties_(
|
void apply_repetition_penalties_(
|
||||||
@ -85,4 +324,20 @@ void apply_repetition_penalties_(
|
|||||||
repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
|
repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
|
||||||
tile_size);
|
tile_size);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Host entry point of the per-row top-k selection.
//
// Launches vllm::topKPerRow with one block of 512 threads per row on the
// current CUDA stream. For every row r the kernel reads the logits at
// logits[r * stride0 + c * stride1] for c in [rowStarts[r], rowEnds[r]) and
// writes 2048 entries into row r of `indices` (int32) and `values` (float32).
//
// Arguments:
//   logits              float32 tensor holding the scores
//   rowStarts, rowEnds  int32 tensors with at least numRows elements
//   indices, values     contiguous output tensors with at least
//                       numRows * 2048 elements
//   numRows             number of rows to process (0 is a no-op)
//   stride0, stride1    element strides of `logits`; must fit in an int
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                   const torch::Tensor& rowEnds, torch::Tensor& indices,
                   torch::Tensor& values, int64_t numRows, int64_t stride0,
                   int64_t stride1) {
  constexpr int kNumThreadsPerBlock = 512;
  // Output width written by the kernel; must match kTopK in topKPerRow.
  constexpr int64_t kTopK = 2048;

  TORCH_CHECK(numRows >= 0, "top_k_per_row: numRows must be non-negative");
  // A grid of zero blocks is an invalid launch configuration.
  if (numRows == 0) {
    return;
  }

  // The kernel takes the grid size and strides as 32-bit integers.
  TORCH_CHECK(static_cast<int64_t>(static_cast<int>(numRows)) == numRows,
              "top_k_per_row: numRows does not fit in an int");
  TORCH_CHECK(static_cast<int64_t>(static_cast<int>(stride0)) == stride0 &&
                  static_cast<int64_t>(static_cast<int>(stride1)) == stride1,
              "top_k_per_row: strides do not fit in an int");

  TORCH_CHECK(rowStarts.numel() >= numRows && rowEnds.numel() >= numRows,
              "top_k_per_row: rowStarts/rowEnds need at least numRows "
              "elements");

  // The kernel addresses the outputs as row * 2048 + slot.
  TORCH_CHECK(indices.is_contiguous() && values.is_contiguous(),
              "top_k_per_row: indices and values must be contiguous");
  TORCH_CHECK(indices.numel() >= numRows * kTopK &&
                  values.numel() >= numRows * kTopK,
              "top_k_per_row: indices and values need at least numRows * ",
              kTopK, " elements");

  // Compute the results on the device.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  vllm::topKPerRow<kNumThreadsPerBlock>
      <<<static_cast<unsigned int>(numRows), kNumThreadsPerBlock, 0, stream>>>(
          logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
          rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
          values.data_ptr<float>(), static_cast<int>(stride0),
          static_cast<int>(stride1));

  // Launch-configuration errors are only reported through the runtime.
  const cudaError_t launchStatus = cudaGetLastError();
  TORCH_CHECK(launchStatus == cudaSuccess,
              "top_k_per_row: kernel launch failed: ",
              cudaGetErrorString(launchStatus));
}
|
||||||
|
|||||||
@ -188,6 +188,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.impl("apply_repetition_penalties_", torch::kCUDA,
|
ops.impl("apply_repetition_penalties_", torch::kCUDA,
|
||||||
&apply_repetition_penalties_);
|
&apply_repetition_penalties_);
|
||||||
|
|
||||||
|
// Optimized top-k per row operation
|
||||||
|
ops.def(
|
||||||
|
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
||||||
|
"Tensor! indices, Tensor! values, int numRows, int stride0, "
|
||||||
|
"int stride1) -> ()");
|
||||||
|
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
|
||||||
|
|
||||||
// Layernorm-quant
|
// Layernorm-quant
|
||||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||||
ops.def(
|
ops.def(
|
||||||
|
|||||||
143
tests/kernels/test_top_k_per_row.py
Normal file
143
tests/kernels/test_top_k_per_row.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
# Test parameters
|
||||||
|
NUM_ROWS = [1, 32, 2050]
|
||||||
|
TOP_K_VALUES = [2048]
|
||||||
|
|
||||||
|
|
||||||
|
def create_random_logits(
    row_starts: torch.Tensor,
    row_ends: torch.Tensor,
    vocab_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str = "cuda",
) -> torch.Tensor:
    """Create a random logits tensor for testing.

    The result has one row per entry of ``row_starts`` and ``max(row_ends)``
    columns. In row ``i`` every column at or beyond ``row_ends[i]`` is set to
    ``-inf``; the remaining entries are drawn from ``torch.randn``.

    Args:
        row_starts: Per-row start indices; only its length is used.
        row_ends: Per-row exclusive end indices.
        vocab_size: Unused; kept for signature compatibility.
        dtype: Data type of the generated logits.
        seed: Seed applied to both the torch and the numpy generators.
        device: Device on which the logits are created.

    Returns:
        The ``(num_rows, max(row_ends))`` logits tensor.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    num_rows = row_starts.shape[0]
    num_cols = int(row_ends.max().item())
    logits = torch.randn(num_rows, num_cols, dtype=dtype, device=device)

    # Mask everything at or past the end of each row in one shot instead of
    # looping over the rows (each iteration would force a host sync).
    column_ids = torch.arange(num_cols, device=device)
    past_row_end = column_ids[None, :] >= row_ends.to(device)[:, None]
    logits.masked_fill_(past_row_end, float("-inf"))
    return logits
|
||||||
|
|
||||||
|
|
||||||
|
def create_row_boundaries(
    seq_len: int, vocab_size: int, device: str = "cuda"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create row start and end indices for testing.

    Row ``i`` spans the columns ``[0, i + 1)``, so the row lengths grow from 1
    up to ``seq_len``.

    Args:
        seq_len: Number of rows.
        vocab_size: Unused; kept for signature compatibility.
        device: Device on which the tensors are created.

    Returns:
        A ``(row_starts, row_ends)`` pair of int32 tensors of length
        ``seq_len``.
    """
    row_starts = torch.zeros(seq_len, dtype=torch.int32, device=device)
    row_ends = torch.arange(1, seq_len + 1, device=device, dtype=torch.int32)
    return row_starts, row_ends
|
||||||
|
|
||||||
|
|
||||||
|
def compare_top_k_results(
    cuda_indices: torch.Tensor,
    cuda_values: torch.Tensor,
    torch_indices: torch.Tensor,
    torch_values: torch.Tensor,
    row_starts: torch.Tensor,
    row_ends: torch.Tensor,
    top_k: int,
    tolerance: float = 1e-5,
) -> bool:
    """
    Check the CUDA top_k_per_row output against the torch.topk reference.

    For every row only the first ``min(top_k, row length)`` entries are
    inspected. A row passes when both sides selected the same set of indices.
    If the sets differ, the values of the indices found on one side only are
    compared pairwise (within ``tolerance``).
    """

    def lookup_values(row_indices, row_values, wanted):
        # Value stored at the first position where each wanted index occurs.
        return [
            row_values[(row_indices == idx).nonzero(as_tuple=True)[0][0]]
            for idx in wanted
        ]

    for row_idx in range(cuda_indices.shape[0]):
        # Number of meaningful entries in this row.
        row_length = row_ends[row_idx].item() - row_starts[row_idx].item()
        num_valid = min(top_k, row_length)

        cuda_row_indices = cuda_indices[row_idx][:num_valid].cpu()
        torch_row_indices = torch_indices[row_idx][:num_valid].cpu()

        # Identical index sets: nothing more to check for this row.
        cuda_set = set(cuda_row_indices.tolist())
        torch_set = set(torch_row_indices.tolist())
        if cuda_set != torch_set:
            # The selections differ: compare the values of the elements that
            # only one of the two implementations picked.
            cuda_row_values = cuda_values[row_idx][:num_valid].cpu()
            torch_row_values = torch_values[row_idx][:num_valid].cpu()

            cuda_only_values = lookup_values(
                cuda_row_indices, cuda_row_values, cuda_set - torch_set
            )
            torch_only_values = lookup_values(
                torch_row_indices, torch_row_values, torch_set - cuda_set
            )

            if len(cuda_only_values) != len(torch_only_values):
                return False

            values_agree = torch.allclose(
                torch.tensor(cuda_only_values),
                torch.tensor(torch_only_values),
                rtol=tolerance,
                atol=tolerance,
            )
            if not values_agree:
                return False

    return True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_rows", NUM_ROWS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row(
    num_rows: int,
    top_k: int,
) -> None:
    """
    Test top_k_per_row against torch.topk.

    Row ``i`` of the generated logits has ``i + 1`` valid entries, so the
    parametrization covers both rows shorter than ``top_k`` and (for the
    largest ``num_rows``) rows longer than ``top_k``.
    """
    torch.set_default_device("cuda:0")

    # The kernel always writes 2048 entries per row.
    assert top_k == 2048, "top_k_per_row assumes size 2048"

    # Create test data
    vocab_size = 20000
    row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
    logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)

    # Create output tensors
    indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
    values = torch.empty((num_rows, top_k), dtype=torch.float32, device="cuda")

    # Run CUDA implementation. The op takes no top-k argument: its schema is
    # (logits, rowStarts, rowEnds, indices, values, numRows, stride0, stride1).
    torch.ops._C.top_k_per_row(
        logits,
        row_starts,
        row_ends,
        indices,
        values,
        num_rows,
        logits.stride(0),
        logits.stride(1),
    )

    # Run reference implementation
    torch_values, torch_indices = logits.topk(min(top_k, logits.shape[-1]), dim=-1)
    # Mark reference indices outside of the valid range of their row with -1.
    mask_lo = torch_indices >= 0
    mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
    mask = mask_lo & mask_hi
    torch_indices = torch_indices.masked_fill(~mask, -1)

    # Compare results
    assert compare_top_k_results(
        indices, values, torch_indices, torch_values, row_starts, row_ends, top_k
    ), "CUDA top_k_per_row results don't match torch.topk"
|
||||||
@ -643,17 +643,24 @@ def sparse_attn_indexer(
|
|||||||
chunk.cu_seqlen_ks,
|
chunk.cu_seqlen_ks,
|
||||||
chunk.cu_seqlen_ke,
|
chunk.cu_seqlen_ke,
|
||||||
)
|
)
|
||||||
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), dim=-1)[1]
|
num_rows = logits.shape[0]
|
||||||
topk_indices -= chunk.cu_seqlen_ks[:, None]
|
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
|
||||||
mask_lo = topk_indices >= 0
|
topk_indices = torch.empty(
|
||||||
mask_hi = (
|
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
|
||||||
topk_indices - (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks)[:, None] < 0
|
|
||||||
)
|
)
|
||||||
mask = torch.full_like(
|
topk_values = torch.empty(
|
||||||
topk_indices, False, dtype=torch.bool, device=topk_indices.device
|
num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
|
||||||
|
)
|
||||||
|
torch.ops._C.top_k_per_row(
|
||||||
|
logits,
|
||||||
|
chunk.cu_seqlen_ks,
|
||||||
|
chunk.cu_seqlen_ke,
|
||||||
|
topk_indices,
|
||||||
|
topk_values,
|
||||||
|
num_rows,
|
||||||
|
logits.stride(0),
|
||||||
|
logits.stride(1),
|
||||||
)
|
)
|
||||||
mask = mask_lo & mask_hi
|
|
||||||
topk_indices = topk_indices.masked_fill(~mask, -1)
|
|
||||||
topk_indices_buffer[
|
topk_indices_buffer[
|
||||||
chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
|
chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
|
||||||
] = topk_indices.to(dtype=torch.int32)
|
] = topk_indices.to(dtype=torch.int32)
|
||||||
@ -693,28 +700,32 @@ def sparse_attn_indexer(
|
|||||||
# padded query len
|
# padded query len
|
||||||
current_device = padded_q_fp8_decode_tokens.device
|
current_device = padded_q_fp8_decode_tokens.device
|
||||||
padded_num_tokens = batch_size * next_n
|
padded_num_tokens = batch_size * next_n
|
||||||
positions = (
|
|
||||||
torch.arange(max_model_len, device=current_device)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.expand(batch_size * next_n, -1)
|
|
||||||
)
|
|
||||||
row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n
|
row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n
|
||||||
next_n_offset = (
|
next_n_offset = (
|
||||||
torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device)
|
torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device)
|
||||||
% next_n
|
% next_n
|
||||||
)
|
)
|
||||||
index_end_pos = (
|
index_end_pos = (
|
||||||
decode_metadata.seq_lens[row_indices] - next_n + next_n_offset
|
decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1
|
||||||
).unsqueeze(1)
|
).unsqueeze(1)
|
||||||
# index_end_pos: [B * N, 1]
|
num_rows = logits.shape[0]
|
||||||
mask = positions <= index_end_pos
|
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
|
||||||
# mask: [B * N, L]
|
topk_indices = torch.empty(
|
||||||
logits = logits.masked_fill(~mask, float("-inf"))
|
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
|
||||||
topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K]
|
)
|
||||||
# ensure we don't set indices for the top k
|
topk_values = torch.empty(
|
||||||
# that is out of range(masked already)
|
num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
|
||||||
# this will happen if context length is shorter than K
|
)
|
||||||
topk_indices[topk_indices > index_end_pos] = -1
|
torch.ops._C.top_k_per_row(
|
||||||
|
logits,
|
||||||
|
torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
|
||||||
|
index_end_pos.to(dtype=torch.int32, device=logits.device),
|
||||||
|
topk_indices,
|
||||||
|
topk_values,
|
||||||
|
num_rows,
|
||||||
|
logits.stride(0),
|
||||||
|
logits.stride(1),
|
||||||
|
)
|
||||||
if decode_metadata.requires_padding:
|
if decode_metadata.requires_padding:
|
||||||
# if padded, we need to unpack
|
# if padded, we need to unpack
|
||||||
# the topk indices removing padded tokens
|
# the topk indices removing padded tokens
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user