[Perf] Optimize EAGLE prepare_inputs_padded with triton kernels (#28597)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Benjamin Chislett 2025-11-28 17:25:05 -05:00 committed by GitHub
parent 3461e7efd8
commit 1986de1375
4 changed files with 199 additions and 108 deletions

View File

@@ -103,16 +103,23 @@ def test_prepare_next_token_ids():
mock_request.num_computed_tokens = 0
mock_requests[req_id] = mock_request
# explicitly discard the last request
discarded_req_mask = torch.tensor(
[False, False, False, True], dtype=torch.bool, device=device
)
sampled_token_ids = [
[0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled
[0, 1, 2, 3, 4], # all accepted, "4" sampled
[-1, -1, -1, -1, -1], # sampling skipped, use backup token "30"
[-1, -1, -1, -1, -1], # this request will be discarded
[0, 1, 2, -1, -1], # explicitly discarded, sampling should be ignored
]
sampled_token_ids_tensor = torch.tensor(
sampled_token_ids, dtype=torch.int32, device=device
)
sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
for i in range(len(sampled_token_ids_cpu)):
if discarded_req_mask[i]:
sampled_token_ids_cpu[i] = []
expected_next_token_ids_cpu = [1, 4, 30, 40]
expected_next_token_ids_tensor = torch.tensor(
@@ -136,9 +143,6 @@ def test_prepare_next_token_ids():
device=device,
)
discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
num_discarded_reqs = 1
expected_valid_sampled_tokens_count = torch.tensor(
[2, 5, 0, 0], dtype=torch.int32, device=device
)
@@ -149,8 +153,7 @@ def test_prepare_next_token_ids():
sampled_token_ids_tensor,
mock_requests,
mock_input_batch,
discarded_req_indices,
num_discarded_reqs,
discarded_req_mask,
)
)
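For reference (not part of the diff), the expected values above follow from a plain-PyTorch rendering of the same acceptance rules; the backup ids for the first two requests are hypothetical placeholders since they are never used:

# Illustrative sketch: recompute the test's expected outputs on CPU.
import torch

sampled = torch.tensor(
    [[0, 1, -1, -1, -1],
     [0, 1, 2, 3, 4],
     [-1, -1, -1, -1, -1],
     [0, 1, 2, -1, -1]],
    dtype=torch.int32,
)
discard_mask = torch.tensor([False, False, False, True])
backup = torch.tensor([10, 20, 30, 40], dtype=torch.int32)  # first two values are hypothetical

valid = (sampled != -1) & ~discard_mask.unsqueeze(1)  # vocab-size check omitted for brevity
counts = valid.sum(dim=1)                             # -> [2, 5, 0, 0]
last_tok = sampled.gather(1, (counts - 1).clamp(min=0).unsqueeze(1)).squeeze(1)
next_ids = torch.where(counts > 0, last_tok, backup)  # -> [1, 4, 30, 40]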
@@ -256,11 +259,6 @@ def test_prepare_inputs_padded():
- Request 3: query_len = 3, rejected = 2
Expected outputs:
token_indices: [0, 1, 2,
3, 4, 5,
6, 7, 8]
Reason: Deferred computation should not disturb the original indices.
token_indices_to_sample: [1, 5, 6]
Reason: After accounting for rejections, these are the valid token positions
from the original indices to sample from.
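As a quick check (not part of the diff), the [1, 5, 6] expectation is just the last-token index minus the rejected count per request; a minimal sketch assuming query_lens = [3, 3, 3] and rejected = [1, 0, 2], consistent with the setup this test describes:

import numpy as np

query_lens = np.array([3, 3, 3])      # assumed per-request query lengths
num_rejected = np.array([1, 0, 2])    # assumed rejected drafts per request
query_start_loc = np.concatenate(([0], np.cumsum(query_lens)))  # [0, 3, 6, 9]
print(query_start_loc[1:] - 1 - num_rejected)                   # [1 5 6]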
@@ -268,9 +266,6 @@ def test_prepare_inputs_padded():
device = torch.device(current_platform.device_type)
expected_token_indices = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32, device=device
)
expected_token_indices_to_sample = torch.tensor(
[1, 5, 6], dtype=torch.int32, device=device
)
@@ -305,15 +300,12 @@ def test_prepare_inputs_padded():
proposer = _create_proposer("eagle", num_speculative_tokens)
output_metadata, token_indices, token_indices_to_sample = (
proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
)
output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
)
assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
assert torch.equal(token_indices, expected_token_indices)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)

View File

@@ -25,6 +25,7 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.tree_attn import (
@@ -40,6 +41,10 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
eagle_prepare_inputs_padded_kernel,
eagle_prepare_next_token_padded_kernel,
)
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -555,20 +560,15 @@ class EagleProposer:
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_indices: torch.Tensor,
num_discarded_requests: int,
discard_request_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids and the number of valid sampled tokens
for each request, considering the "discarded" requests whose next token
is not sampled and comes from `request.get_token_id()` instead.
It also accounts for the rejected tokens in `sampled_token_ids`.
This function must use device functions to operate on the inputs, and
should not introduce any blocking CPU-GPU synchronization.
is not sampled and comes from `request.get_token_id()` instead. This is denoted
the "backup" token id. It also counts rejected tokens via `sampled_token_ids`.
"""
# TODO(Ben): Combine this into a custom fused kernel
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
self.backup_next_token_ids.np[:num_reqs] = np.array(
@@ -577,44 +577,39 @@ class EagleProposer:
common_attn_metadata.seq_lens_cpu[i].item()
)
for i in range(num_reqs)
]
],
dtype=np.int32,
)
self.backup_next_token_ids.copy_to_gpu(num_reqs)
backup_tokens_gpu = self.backup_next_token_ids.gpu
# Mask out the sampled tokens indices that should not be sampled.
discard_sampled_tokens_req_indices = discard_request_indices[
:num_discarded_requests
]
batch_size, num_tokens = sampled_token_ids.shape
device = sampled_token_ids.device
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
valid_sampled_token_ids_gpu.index_fill_(
0, discard_sampled_tokens_req_indices, -1
assert discard_request_mask.dtype == torch.bool
assert backup_tokens_gpu.dtype == torch.int32
next_token_ids = torch.empty((batch_size,), dtype=torch.int32, device=device)
valid_sampled_tokens_count = torch.empty(
(batch_size,), dtype=torch.int32, device=device
)
# Generate a mask for all valid tokens within those requests
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
)
# Kernel grid: one program per request (row)
grid = (batch_size,)
# Count the number of valid tokens in each request
valid_sampled_tokens_count = valid_mask.sum(dim=1)
# Get the rightmost valid index per row
last_valid_indices = valid_sampled_tokens_count - 1
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
# Get last valid token from each row
# (assume undefined state where there is no valid token)
selected_tokens = torch.gather(
valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
).squeeze(1)
# Use last token if valid, pre-computed backup if not
batch_size = valid_sampled_token_ids_gpu.shape[0]
next_token_ids = torch.where(
last_valid_indices != -1,
selected_tokens,
self.backup_next_token_ids.gpu[:batch_size],
# Find the next power of 2 for block sizes
BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
eagle_prepare_next_token_padded_kernel[grid](
sampled_token_ids,
discard_request_mask,
backup_tokens_gpu,
next_token_ids,
valid_sampled_tokens_count,
gpu_input_batch.vocab_size,
num_tokens,
batch_size,
sampled_token_ids.stride(0),
BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
)
return next_token_ids, valid_sampled_tokens_count
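One design note on the launch above (not part of the diff): tl.arange inside the kernel requires a power-of-two extent, so the block size is rounded up from num_tokens and the tail is masked in the kernel. A tiny illustration:

from vllm.triton_utils import triton

# e.g. num_spec_tokens + 1 = 5 sampled slots per request -> block of 8
assert triton.next_power_of_2(5) == 8
assert triton.next_power_of_2(8) == 8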
@@ -624,35 +619,35 @@ class EagleProposer:
common_attn_metadata: CommonAttentionMetadata,
spec_decode_metadata: SpecDecodeMetadata,
valid_sampled_tokens_count: torch.Tensor,
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It updates the common_attn_metadata for speculative decoding,
but does not consider the rejected tokens. Instead, all tokens
are included as inputs to the speculator, with the rejected tokens
used as padding and filtered out later by `token_indices_to_sample`.
No blocking CPU operations should be introduced in this function.
"""
num_draft_tokens_gpu = torch.cat(
[
spec_decode_metadata.cu_num_draft_tokens[0:1],
spec_decode_metadata.cu_num_draft_tokens[1:]
- spec_decode_metadata.cu_num_draft_tokens[:-1],
]
num_reqs = common_attn_metadata.num_reqs
device = valid_sampled_tokens_count.device
token_indices_to_sample = torch.empty(
(num_reqs,), dtype=torch.int32, device=device
)
num_rejected_tokens_gpu = torch.where(
num_draft_tokens_gpu > 0,
num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
torch.zeros_like(num_draft_tokens_gpu),
# Kernel grid: one program per request (row)
grid = (num_reqs,)
eagle_prepare_inputs_padded_kernel[grid](
spec_decode_metadata.cu_num_draft_tokens,
valid_sampled_tokens_count,
common_attn_metadata.query_start_loc,
token_indices_to_sample,
num_reqs,
)
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
total_num_tokens = query_start_loc_cpu[-1].item()
token_indices = self.arange[:total_num_tokens]
spec_common_attn_metadata = CommonAttentionMetadata(
query_start_loc=common_attn_metadata.query_start_loc,
@@ -665,16 +660,12 @@ class EagleProposer:
max_query_len=new_query_len_per_req.max().item(),
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
causal=True,
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
)
token_indices_to_sample = (
common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu
)
return spec_common_attn_metadata, token_indices, token_indices_to_sample
return spec_common_attn_metadata, token_indices_to_sample
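For context (not part of the diff), the rejection count that eagle_prepare_inputs_padded_kernel derives from the inclusive cumulative draft-token sum matches the tensor arithmetic of the removed PyTorch path; a minimal sketch with assumed values (two drafts per request, as in the test above):

import torch

cu_num_draft_tokens = torch.tensor([2, 4, 6], dtype=torch.int32)         # inclusive cumsum (assumed)
valid_sampled_tokens_count = torch.tensor([2, 3, 1], dtype=torch.int32)  # assumed
num_draft = torch.diff(cu_num_draft_tokens, prepend=torch.zeros(1, dtype=torch.int32))
num_rejected = torch.where(num_draft > 0,
                           num_draft + 1 - valid_sampled_tokens_count,
                           torch.zeros_like(num_draft))
# num_rejected -> [1, 0, 2]; the kernel then samples at query_start_loc[1:] - 1 - num_rejected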
def propose_tree(
self,

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
_SAMPLING_EPS = 1e-5
@@ -14,3 +15,107 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
or sampling_params.min_p > _SAMPLING_EPS
or sampling_params.logprobs is not None
)
@triton.jit
def eagle_prepare_inputs_padded_kernel(
cu_num_draft_tokens_ptr, # [num_reqs]
valid_sampled_tokens_count_ptr, # [num_reqs]
query_start_loc_gpu_ptr, # [num_reqs + 1]
token_indices_to_sample_ptr, # [num_reqs] (output)
num_reqs, # tl.int32
):
"""
Fused kernel for Eagle prepare_inputs_padded. This kernel computes the
token index to sample for each request, taking into account the number
of draft tokens and the number of valid sampled tokens (which is one more than
the number of accepted tokens).
"""
req_idx = tl.program_id(axis=0)
if req_idx >= num_reqs:
return
# Calculate num_draft_tokens from cu_num_draft_tokens, which is an inclusive
# cumulative sum (first entry is the first value, not zero).
cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = 0
if req_idx == 0:
num_draft_tokens = cu_draft_curr
else:
cu_draft_prev = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
num_draft_tokens = cu_draft_curr - cu_draft_prev
valid_count = tl.load(valid_sampled_tokens_count_ptr + req_idx)
num_rejected_tokens = num_draft_tokens + 1 - valid_count
num_rejected_tokens = tl.where(num_draft_tokens > 0, num_rejected_tokens, 0)
# query_start_loc[req_idx + 1] is the start position of the next request,
# which is one past the last token of this request.
q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + req_idx + 1) - 1
index_to_sample = q_last_tok_idx - num_rejected_tokens
tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
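A minimal standalone launch of this kernel (not part of the diff; assumes a CUDA-capable device and reuses the numbers from the prepare_inputs_padded test):

import torch
from vllm.v1.spec_decode.utils import eagle_prepare_inputs_padded_kernel

device = "cuda"  # assumption: Triton kernels need a GPU device
cu_num_draft_tokens = torch.tensor([2, 4, 6], dtype=torch.int32, device=device)
valid_sampled_tokens_count = torch.tensor([2, 3, 1], dtype=torch.int32, device=device)
query_start_loc = torch.tensor([0, 3, 6, 9], dtype=torch.int32, device=device)
num_reqs = 3
token_indices_to_sample = torch.empty(num_reqs, dtype=torch.int32, device=device)
eagle_prepare_inputs_padded_kernel[(num_reqs,)](
    cu_num_draft_tokens,
    valid_sampled_tokens_count,
    query_start_loc,
    token_indices_to_sample,
    num_reqs,
)
# token_indices_to_sample -> [1, 5, 6]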
@triton.jit
def eagle_prepare_next_token_padded_kernel(
sampled_token_ids_ptr, # [num_reqs, num_sampled_tokens_per_req]
discard_request_mask_ptr, # [num_reqs]
backup_next_token_ids_ptr, # [num_reqs]
next_token_ids_ptr, # [num_reqs] (output)
valid_sampled_tokens_count_ptr, # [num_reqs] (output)
vocab_size, # tl.int32
num_sampled_tokens_per_req, # tl.int32 (num_spec_tokens + 1)
num_reqs, # tl.int32
stride_sampled_token_ids, # tl.int32 (stride for dim 0)
BLOCK_SIZE_TOKENS: tl.constexpr, # Power-of-2 >= num_sampled_tokens_per_req
):
"""
Fused kernel for Eagle prepare_next_token_ids_padded. This kernel computes the
number of valid (1 + accepted) tokens for each request, and the corresponding
"next" token id to sample from during speculative decoding. This is the
"last accepted token" from the sampled tokens, or the backup token if no
tokens were accepted or if the request is marked as discarded.
"""
req_idx = tl.program_id(axis=0)
if req_idx >= num_reqs:
return
# Check if this request is discarded.
is_discarded = tl.load(discard_request_mask_ptr + req_idx)
if is_discarded:
backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
valid_count = tl.full((), 0, dtype=tl.uint32)
tl.store(next_token_ids_ptr + req_idx, backup_token)
tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
else:
# Count the number of valid tokens among the sampled tokens.
token_offs = tl.arange(0, BLOCK_SIZE_TOKENS)
token_mask = token_offs < num_sampled_tokens_per_req
row_ptr = sampled_token_ids_ptr + req_idx * stride_sampled_token_ids
token_ids = tl.load(row_ptr + token_offs, mask=token_mask, other=-1)
# Rejected tokens are -1, valid tokens are in [0, vocab_size)
is_valid_mask = (token_ids != -1) & (token_ids < vocab_size) & token_mask
valid_count = tl.sum(is_valid_mask)
if valid_count > 0:
# Guaranteed to be well-defined since
# valid_count > 0 implies is_valid_mask is not empty
last_valid_index = tl.max(tl.where(is_valid_mask, token_offs, -1))
# Select the token at that index, using a sum trick since
# we don't want to load again to access token_ids[last_valid_index].
last_valid_token = tl.sum(
tl.where(token_offs == last_valid_index, token_ids, 0)
)
tl.store(next_token_ids_ptr + req_idx, last_valid_token)
else:
# No valid tokens found, use backup token
backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
tl.store(next_token_ids_ptr + req_idx, backup_token)
tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
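Similarly, a standalone launch of this kernel with the test's inputs (not part of the diff; assumes a CUDA device, an arbitrary vocab size of 32000, and hypothetical backup ids for the first two requests):

import torch
from vllm.triton_utils import triton
from vllm.v1.spec_decode.utils import eagle_prepare_next_token_padded_kernel

device = "cuda"
sampled_token_ids = torch.tensor(
    [[0, 1, -1, -1, -1],
     [0, 1, 2, 3, 4],
     [-1, -1, -1, -1, -1],
     [0, 1, 2, -1, -1]],
    dtype=torch.int32, device=device,
)
discard_request_mask = torch.tensor([False, False, False, True], device=device)
backup_next_token_ids = torch.tensor([10, 20, 30, 40], dtype=torch.int32, device=device)
batch_size, num_tokens = sampled_token_ids.shape
next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
valid_sampled_tokens_count = torch.empty(batch_size, dtype=torch.int32, device=device)
eagle_prepare_next_token_padded_kernel[(batch_size,)](
    sampled_token_ids,
    discard_request_mask,
    backup_next_token_ids,
    next_token_ids,
    valid_sampled_tokens_count,
    32000,  # vocab_size (assumed)
    num_tokens,
    batch_size,
    sampled_token_ids.stride(0),
    BLOCK_SIZE_TOKENS=triton.next_power_of_2(num_tokens),
)
# next_token_ids -> [1, 4, 30, 40]; valid_sampled_tokens_count -> [2, 5, 0, 0]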

View File

@@ -488,11 +488,9 @@ class GPUModelRunner(
self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False
)
self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
self.discard_request_indices = self._make_buffer(
self.max_num_reqs, dtype=torch.int64
self.discard_request_mask = self._make_buffer(
self.max_num_reqs, dtype=torch.bool
)
self.num_discarded_requests = 0
self.num_decode_draft_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
@@ -1369,16 +1367,12 @@ class GPUModelRunner(
num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
num_tokens_np = np.array(num_tokens, dtype=np.int32)
# Record the index of requests that should not be sampled,
# Record which requests should not be sampled,
# so that we could clear the sampled tokens before returning
discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
discard_request_indices = np.nonzero(discard_requests_mask)[0]
self.num_discarded_requests = len(discard_request_indices)
self.discard_request_indices.np[: self.num_discarded_requests] = (
discard_request_indices
self.discard_request_mask.np[:num_reqs] = (
self.seq_lens.np[:num_reqs] < num_tokens_np
)
self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
self.discard_request_mask.copy_to_gpu(num_reqs)
# Copy the tensors to the GPU.
self._prepare_input_ids(
@@ -2548,9 +2542,10 @@ class GPUModelRunner(
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
num_nans_in_logits = self._get_nans_in_logits(logits)
discard_sampled_tokens_req_indices = self.discard_request_indices.np[
: self.num_discarded_requests
]
num_reqs = self.input_batch.num_reqs
discard_sampled_tokens_req_indices = np.nonzero(
self.discard_request_mask.np[:num_reqs]
)[0]
for i in discard_sampled_tokens_req_indices:
gen = self.input_batch.generators.get(int(i))
if gen is not None:
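Since only the boolean mask is kept now, the host-side index list is recovered on demand; a trivial sketch (not part of the diff):

import numpy as np

mask = np.array([False, False, False, True])  # assumed example mask
print(np.nonzero(mask)[0])                    # [3]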
@@ -3131,8 +3126,7 @@ class GPUModelRunner(
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_indices.gpu,
self.num_discarded_requests,
self.discard_request_mask.gpu,
)
)
self._copy_valid_sampled_token_count(
@@ -3335,8 +3329,7 @@ class GPUModelRunner(
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_indices.gpu,
self.num_discarded_requests,
self.discard_request_mask.gpu,
)
)
self._copy_valid_sampled_token_count(
@@ -3363,24 +3356,34 @@ class GPUModelRunner(
sampled_token_ids,
spec_decode_metadata.num_draft_tokens,
)
target_token_ids = self.input_ids.gpu[token_indices]
target_positions = self._get_positions(token_indices)
if self.use_aux_hidden_state_outputs:
assert aux_hidden_states is not None
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1
)
else:
target_hidden_states = hidden_states[token_indices]
else:
common_attn_metadata, token_indices, token_indices_to_sample = (
common_attn_metadata, token_indices_to_sample = (
self.drafter.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count,
)
)
target_token_ids = self.input_ids.gpu[token_indices]
target_positions = self._get_positions(token_indices)
if self.use_aux_hidden_state_outputs:
assert aux_hidden_states is not None
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1
)
else:
target_hidden_states = hidden_states[token_indices]
total_num_tokens = common_attn_metadata.num_actual_tokens
# When padding the batch, token_indices is just a range
target_token_ids = self.input_ids.gpu[:total_num_tokens]
target_positions = self._get_positions(total_num_tokens)
if self.use_aux_hidden_state_outputs:
assert aux_hidden_states is not None
target_hidden_states = torch.cat(
[h[:total_num_tokens] for h in aux_hidden_states], dim=-1
)
else:
target_hidden_states = hidden_states[:total_num_tokens]
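The slicing above relies on the padded layout: the gather indices would just be 0..total_num_tokens-1, so a contiguous slice is equivalent and skips building an index tensor. A one-line check under that assumption (not part of the diff, with hypothetical stand-in values):

import torch

x = torch.randn(16, 4)   # stand-in for input_ids / hidden states (hypothetical)
n = 9                    # stand-in for total_num_tokens (hypothetical)
assert torch.equal(x[:n], x[torch.arange(n)])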
if self.supports_mm_inputs:
mm_embed_inputs = self._gather_mm_embeddings(