[Deepseek v3.2] Remove extra logics in indexer (#26465)
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: Lain <siyuanf@nvidia.com>
Co-authored-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
parent 6c2eef5a5d
commit 09a7e6f617
@@ -101,6 +101,10 @@ void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                    const torch::Tensor& rowEnds, torch::Tensor& indices,
                    int64_t numRows, int64_t stride0, int64_t stride1);
 
+void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
+                          const torch::Tensor& seq_lens, torch::Tensor& indices,
+                          int64_t numRows, int64_t stride0, int64_t stride1);
+
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,
                                double epsilon);
@@ -54,15 +54,10 @@ static inline __device__ uint16_t extractBinIdx(float x) {
   return 511 - (tmp.u16 >> 7);
 }
 
-template <int kNumThreadsPerBlock = 512>
-static __global__ void topKPerRow(const float* logits, const int* rowStarts,
-                                  const int* rowEnds, int* outIndices,
-                                  int stride0, int stride1) {
-  // The number of bins in the histogram.
-  static constexpr int kNumBins = 512;
-
-  // The top-k width.
-  static constexpr int kTopK = 2048;
+template <int kNumThreadsPerBlock = 512, int kNumBins = 512, int kTopK = 2048>
+__device__ void topKPerRowJob(const float* logits, const int rowStart,
+                              const int rowEnd, const int rowIdx,
+                              int* outIndices, int stride0, int stride1) {
   // The number of elements per thread for the final top-k sort.
   static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
   // The class to sort the elements during the final top-k sort.
@@ -108,10 +103,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
   // Shared memory counter to register the candidates for the final phase.
   __shared__ int smemFinalDstIdx[1];
 
-  // The row computed by this block.
-  int rowIdx = blockIdx.x;
-  // The range of logits within the row.
-  int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx];
   // The length of the row.
   int rowLen = rowEnd - rowStart;
 
@@ -260,6 +251,49 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
     }
   }
 
+template <int kNumThreadsPerBlock = 512>
+static __global__ void topKPerRow(const float* logits, const int* rowStarts,
+                                  const int* rowEnds, int* outIndices,
+                                  int stride0, int stride1) {
+  // The number of bins in the histogram.
+  static constexpr int kNumBins = 512;
+
+  // The top-k width.
+  static constexpr int kTopK = 2048;
+
+  // The row computed by this block.
+  int rowIdx = blockIdx.x;
+
+  // The range of logits within the row.
+  int rowStart = rowStarts[rowIdx];
+  int rowEnd = rowEnds[rowIdx];
+
+  topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
+      logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
+}
+
+template <int kNumThreadsPerBlock = 512>
+static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
+                                        int* outIndices, int stride0,
+                                        int stride1, int next_n) {
+  // The number of bins in the histogram.
+  static constexpr int kNumBins = 512;
+
+  // The top-k width.
+  static constexpr int kTopK = 2048;
+
+  // The row computed by this block.
+  int rowIdx = blockIdx.x;
+
+  // The range of logits within the row.
+  int rowStart = 0;
+  int seq_len = seqLens[rowIdx / next_n];
+  int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
+
+  topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
+      logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
+}
+
 } // namespace vllm
 
 void apply_repetition_penalties_(
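The decode kernel derives each row's logit range from seqLens instead of explicit rowStarts/rowEnds arrays: row i belongs to request i / next_n, and its exclusive end is seq_len - next_n + (i % next_n) + 1. A minimal Python sketch of that index math, with illustrative values only:

def row_ends(seq_lens, next_n):
    # Each request contributes next_n consecutive rows; the owning
    # request of row i is i // next_n, and each later speculative row
    # may attend one token further than the previous one.
    num_rows = len(seq_lens) * next_n
    return [
        seq_lens[i // next_n] - next_n + (i % next_n) + 1
        for i in range(num_rows)
    ]

# Two requests with lengths 10 and 7, next_n = 2: rows 0..1 belong to
# request 0 (ends 9, 10), rows 2..3 to request 1 (ends 6, 7).
assert row_ends([10, 7], 2) == [9, 10, 6, 7]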
@@ -303,6 +337,20 @@ void apply_repetition_penalties_(
   });
 }
 
+void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
+                          const torch::Tensor& seqLens, torch::Tensor& indices,
+                          int64_t numRows, int64_t stride0, int64_t stride1) {
+  // Compute the results on the device.
+  constexpr int kNumThreadsPerBlock = 512;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  vllm::topKPerRowDecode<kNumThreadsPerBlock>
+      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
+          logits.data_ptr<float>(), seqLens.data_ptr<int>(),
+          indices.data_ptr<int>(), static_cast<int>(stride0),
+          static_cast<int>(stride1), static_cast<int>(next_n));
+}
+
 void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                    const torch::Tensor& rowEnds, torch::Tensor& indices,
                    int64_t numRows, int64_t stride0, int64_t stride1) {
@@ -189,6 +189,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "int stride1) -> ()");
   ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
 
+  ops.def(
+      "top_k_per_row_decode(Tensor logits, int next_n, "
+      "Tensor seq_lens, Tensor! indices, int numRows, "
+      "int stride0, int stride1) -> ()");
+  ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
+
   // Layernorm-quant
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
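Once registered, the op is reachable from Python as torch.ops._C.top_k_per_row_decode; the test added below exercises exactly this path. A hedged usage sketch (shapes inferred from the schema and the kernel: float32 logits of shape [num_rows, vocab], int32 seq_lens of shape [batch_size], a preallocated int32 [num_rows, 2048] output; assumes vLLM's _C extension is built and a CUDA device is available):

import torch

batch_size, next_n, vocab = 4, 2, 20000
num_rows = batch_size * next_n
logits = torch.randn(num_rows, vocab, dtype=torch.float32, device="cuda")
seq_lens = torch.randint(1, vocab, (batch_size,), dtype=torch.int32, device="cuda")
indices = torch.empty(num_rows, 2048, dtype=torch.int32, device="cuda")

# Writes the top-2048 indices of each row into `indices` in place.
torch.ops._C.top_k_per_row_decode(
    logits, next_n, seq_lens, indices,
    num_rows, logits.stride(0), logits.stride(1),
)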
@@ -10,6 +10,8 @@ from vllm.platforms import current_platform
 # Test parameters
 NUM_ROWS = [1, 32, 2050]
 TOP_K_VALUES = [2048]
+BATCH_SIZE = [1, 2, 4, 2048, 4096]
+NEXT_N = [1, 2, 4, 8]
 
 
 def create_random_logits(
@@ -114,7 +116,7 @@ def test_top_k_per_row(
     logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
 
     # Create output tensors
-    indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
+    indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
 
     # Run CUDA implementation
     torch.ops._C.top_k_per_row(
@@ -138,3 +140,59 @@ def test_top_k_per_row(
     assert compare_top_k_results(
         logits, indices, torch_indices, row_starts, row_ends, top_k
     ), "CUDA top_k_per_row results don't match torch.topk"
+
+
+@pytest.mark.parametrize("top_k", TOP_K_VALUES)
+@pytest.mark.parametrize("batch_size", BATCH_SIZE)
+@pytest.mark.parametrize("next_n", NEXT_N)
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
+@torch.inference_mode()
+def test_top_k_per_row_decode(
+    top_k: int,
+    batch_size: int,
+    next_n: int,
+) -> None:
+    """
+    Test top_k_per_row with seq_lens tensor.
+    """
+    torch.set_default_device("cuda:0")
+
+    # Create test data
+    num_rows = batch_size * next_n
+    vocab_size = 20000
+    seq_lens = torch.randint(
+        vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
+    )
+    row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
+    row_indices = torch.arange(num_rows, device="cuda") // next_n
+    next_n_offset = torch.arange(num_rows, device="cuda") % next_n
+    row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
+    logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
+
+    # Create output tensors
+    indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
+
+    # Run CUDA implementation
+    torch.ops._C.top_k_per_row_decode(
+        logits,
+        next_n,
+        seq_lens,
+        indices,
+        num_rows,
+        logits.stride(0),
+        logits.stride(1),
+    )
+
+    torch.cuda.synchronize()
+
+    # Run reference implementation
+    torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1]
+    mask_lo = torch_indices >= 0
+    mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
+    mask = mask_lo & mask_hi
+    torch_indices = torch_indices.masked_fill(~mask, -1)
+
+    # Compare results
+    assert compare_top_k_results(
+        logits, indices, torch_indices, row_starts, row_ends, top_k
+    ), "CUDA top_k_per_row results don't match torch.topk"
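The reference side of the new test masks the torch.topk output instead of slicing each row: any returned index at or beyond a row's valid length is replaced with -1 so compare_top_k_results can skip it (with row_starts all zero, row_ends - row_starts is just the row length). A small self-contained sketch of that masking, with hypothetical shapes and CPU tensors:

import torch

logits = torch.randn(2, 8)
row_lens = torch.tensor([5, 3])               # valid prefix length per row

topk_idx = logits.topk(4, dim=-1)[1]          # indices over the full width
valid = topk_idx < row_lens[:, None]          # keep only in-range indices
topk_idx = topk_idx.masked_fill(~valid, -1)   # out-of-range slots become -1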
@@ -580,9 +580,9 @@ def sparse_attn_indexer(
             )
             num_rows = logits.shape[0]
             assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
-            topk_indices = torch.empty(
-                num_rows, topk_tokens, dtype=torch.int32, device=logits.device
-            )
+            topk_indices = topk_indices_buffer[
+                chunk.token_start : chunk.token_end, :topk_tokens
+            ]
             torch.ops._C.top_k_per_row(
                 logits,
                 chunk.cu_seqlen_ks,
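Because topk_indices_buffer[chunk.token_start : chunk.token_end, :topk_tokens] is a view into the preallocated buffer, top_k_per_row now writes its results directly in place, which is what lets the copy-back in the next hunk be deleted. A minimal illustration of the view semantics this relies on (illustrative shapes):

import torch

buffer = torch.full((8, 4), -1, dtype=torch.int32)
out = buffer[2:5, :]      # slicing returns a view, not a copy
out.fill_(7)              # writing through the view mutates the buffer
assert buffer[2:5].eq(7).all()
assert buffer[:2].eq(-1).all() and buffer[5:].eq(-1).all()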
@@ -592,9 +592,6 @@ def sparse_attn_indexer(
                 logits.stride(0),
                 logits.stride(1),
             )
-            topk_indices_buffer[
-                chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
-            ] = topk_indices.to(dtype=torch.int32)
 
     if has_decode:
         decode_metadata = attn_metadata.decode
@@ -628,26 +625,14 @@ def sparse_attn_indexer(
             decode_metadata.schedule_metadata,
             max_model_len=max_model_len,
         )
-        # padded query len
-        current_device = padded_q_fp8_decode_tokens.device
-        padded_num_tokens = batch_size * next_n
-        row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n
-        next_n_offset = (
-            torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device)
-            % next_n
-        )
-        index_end_pos = (
-            decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1
-        ).unsqueeze(1)
         num_rows = logits.shape[0]
         assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
-        topk_indices = torch.empty(
-            num_rows, topk_tokens, dtype=torch.int32, device=logits.device
-        )
-        torch.ops._C.top_k_per_row(
+        topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
+        torch.ops._C.top_k_per_row_decode(
             logits,
-            torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
-            index_end_pos.to(dtype=torch.int32, device=logits.device),
+            next_n,
+            decode_metadata.seq_lens,
             topk_indices,
             num_rows,
             logits.stride(0),
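The deleted host-side index_end_pos computation and the new in-kernel formula agree row for row: both evaluate seq_lens[i // next_n] - next_n + (i % next_n) + 1 for each padded decode row, so moving the math into top_k_per_row_decode only drops the intermediate tensors and the extra torch.arange launches. A quick equivalence check, sketched with hypothetical values:

import torch

next_n = 2
seq_lens = torch.tensor([10, 7], dtype=torch.int32)
num_rows = seq_lens.numel() * next_n

# Host-side computation, as removed above.
rows = torch.arange(num_rows)
index_end_pos = seq_lens[rows // next_n] - next_n + (rows % next_n) + 1

# In-kernel computation, replayed per row on the host.
kernel_ends = torch.tensor(
    [int(seq_lens[i // next_n]) - next_n + (i % next_n) + 1 for i in range(num_rows)]
)
assert torch.equal(index_end_pos.to(torch.int64), kernel_ends)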
@@ -660,9 +645,9 @@ def sparse_attn_indexer(
             topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
             decode_lens,
         )
         topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
-            topk_indices.to(dtype=torch.int32)
+            topk_indices
         )
 
     return topk_indices_buffer
 