From 1986de137502d0d767cb4c1d3cad23dedbd22397 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Fri, 28 Nov 2025 17:25:05 -0500 Subject: [PATCH] [Perf] Optimize EAGLE prepare_inputs_padded with triton kernels (#28597) Signed-off-by: Benjamin Chislett Signed-off-by: Benjamin Chislett --- tests/v1/spec_decode/test_eagle.py | 30 +++----- vllm/v1/spec_decode/eagle.py | 109 +++++++++++++---------------- vllm/v1/spec_decode/utils.py | 105 +++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 63 +++++++++-------- 4 files changed, 199 insertions(+), 108 deletions(-) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index c93c59d1f4c42..9436ab471c21b 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -103,16 +103,23 @@ def test_prepare_next_token_ids(): mock_request.num_computed_tokens = 0 mock_requests[req_id] = mock_request + # explicitly discard the last request + discarded_req_mask = torch.tensor( + [False, False, False, True], dtype=torch.bool, device=device + ) sampled_token_ids = [ [0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled [0, 1, 2, 3, 4], # all accepted, "4" sampled [-1, -1, -1, -1, -1], # sampling skipped, use backup token "30" - [-1, -1, -1, -1, -1], # this request will be discarded + [0, 1, 2, -1, -1], # explicitly discarded, sampling should be ignored ] sampled_token_ids_tensor = torch.tensor( sampled_token_ids, dtype=torch.int32, device=device ) sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids] + for i in range(len(sampled_token_ids_cpu)): + if discarded_req_mask[i]: + sampled_token_ids_cpu[i] = [] expected_next_token_ids_cpu = [1, 4, 30, 40] expected_next_token_ids_tensor = torch.tensor( @@ -136,9 +143,6 @@ def test_prepare_next_token_ids(): device=device, ) - discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device) - num_discarded_reqs = 1 - expected_valid_sampled_tokens_count = torch.tensor( [2, 5, 0, 0], dtype=torch.int32, device=device ) @@ -149,8 +153,7 @@ def test_prepare_next_token_ids(): sampled_token_ids_tensor, mock_requests, mock_input_batch, - discarded_req_indices, - num_discarded_reqs, + discarded_req_mask, ) ) @@ -256,11 +259,6 @@ def test_prepare_inputs_padded(): - Request 3: query_len = 3, rejected = 2 Expected outputs: - token_indices: [0, 1, 2, - 3, 4, 5, - 6, 7, 8] - Reason: Deferred computation should not disturb the original indices. - token_indices_to_sample: [1, 5, 6] Reason: After accounting for rejections, these are the valid token positions from the original indices to sample from. @@ -268,9 +266,6 @@ def test_prepare_inputs_padded(): device = torch.device(current_platform.device_type) - expected_token_indices = torch.tensor( - [0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32, device=device - ) expected_token_indices_to_sample = torch.tensor( [1, 5, 6], dtype=torch.int32, device=device ) @@ -305,15 +300,12 @@ def test_prepare_inputs_padded(): proposer = _create_proposer("eagle", num_speculative_tokens) - output_metadata, token_indices, token_indices_to_sample = ( - proposer.prepare_inputs_padded( - common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count - ) + output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded( + common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count ) assert output_metadata.max_query_len == 3 assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc) - assert torch.equal(token_indices, expected_token_indices) assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 305abdade8da6..72f9d15bc1328 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -25,6 +25,7 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform +from vllm.triton_utils import triton from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.tree_attn import ( @@ -40,6 +41,10 @@ from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.utils import ( + eagle_prepare_inputs_padded_kernel, + eagle_prepare_next_token_padded_kernel, +) from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -555,20 +560,15 @@ class EagleProposer: sampled_token_ids: torch.Tensor, requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, - discard_request_indices: torch.Tensor, - num_discarded_requests: int, + discard_request_mask: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding. It calculates the next token ids and the number of valid sampled tokens for each request, considering the "discarded" requests whose next token - is not sampled and comes from `request.get_token_id()` instead. - It also accounts for the rejected tokens in `sampled_token_ids`. - This function must use device functions to operate on the inputs, and - should not introduce any blocking CPU-GPU synchronization. + is not sampled and comes from `request.get_token_id()` instead. This is denoted + the "backup" token id. It also counts rejected tokens via `sampled_token_ids`. """ - # TODO(Ben): Combine this into a custom fused kernel - # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs self.backup_next_token_ids.np[:num_reqs] = np.array( @@ -577,44 +577,39 @@ class EagleProposer: common_attn_metadata.seq_lens_cpu[i].item() ) for i in range(num_reqs) - ] + ], + dtype=np.int32, ) self.backup_next_token_ids.copy_to_gpu(num_reqs) + backup_tokens_gpu = self.backup_next_token_ids.gpu - # Mask out the sampled tokens indices that should not be sampled. - discard_sampled_tokens_req_indices = discard_request_indices[ - :num_discarded_requests - ] + batch_size, num_tokens = sampled_token_ids.shape + device = sampled_token_ids.device - valid_sampled_token_ids_gpu = sampled_token_ids.clone() - valid_sampled_token_ids_gpu.index_fill_( - 0, discard_sampled_tokens_req_indices, -1 + assert discard_request_mask.dtype == torch.bool + assert backup_tokens_gpu.dtype == torch.int32 + + next_token_ids = torch.empty((batch_size,), dtype=torch.int32, device=device) + valid_sampled_tokens_count = torch.empty( + (batch_size,), dtype=torch.int32, device=device ) - # Generate a mask for all valid tokens within those requests - valid_mask = (valid_sampled_token_ids_gpu != -1) & ( - valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size - ) + # Kernel grid: one program per request (row) + grid = (batch_size,) - # Count the number of valid tokens in each request - valid_sampled_tokens_count = valid_mask.sum(dim=1) - - # Get the rightmost valid index per row - last_valid_indices = valid_sampled_tokens_count - 1 - last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) - - # Get last valid token from each row - # (assume undefined state where there is no valid token) - selected_tokens = torch.gather( - valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1) - ).squeeze(1) - - # Use last token if valid, pre-computed backup if not - batch_size = valid_sampled_token_ids_gpu.shape[0] - next_token_ids = torch.where( - last_valid_indices != -1, - selected_tokens, - self.backup_next_token_ids.gpu[:batch_size], + # Find the next power of 2 for block sizes + BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens) + eagle_prepare_next_token_padded_kernel[grid]( + sampled_token_ids, + discard_request_mask, + backup_tokens_gpu, + next_token_ids, + valid_sampled_tokens_count, + gpu_input_batch.vocab_size, + num_tokens, + batch_size, + sampled_token_ids.stride(0), + BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS, ) return next_token_ids, valid_sampled_tokens_count @@ -624,35 +619,35 @@ class EagleProposer: common_attn_metadata: CommonAttentionMetadata, spec_decode_metadata: SpecDecodeMetadata, valid_sampled_tokens_count: torch.Tensor, - ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + ) -> tuple[CommonAttentionMetadata, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding It updates the common_attn_metadata for speculative decoding, but does not consider the rejected tokens. Instead, all tokens are included as inputs to the speculator, with the rejected tokens used as padding and filtered out later by `token_indices_to_sample`. - No blocking CPU operations should be introduced in this function. """ - num_draft_tokens_gpu = torch.cat( - [ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1], - ] + num_reqs = common_attn_metadata.num_reqs + device = valid_sampled_tokens_count.device + + token_indices_to_sample = torch.empty( + (num_reqs,), dtype=torch.int32, device=device ) - num_rejected_tokens_gpu = torch.where( - num_draft_tokens_gpu > 0, - num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, - torch.zeros_like(num_draft_tokens_gpu), + # Kernel grid: one program per request (row) + grid = (num_reqs,) + eagle_prepare_inputs_padded_kernel[grid]( + spec_decode_metadata.cu_num_draft_tokens, + valid_sampled_tokens_count, + common_attn_metadata.query_start_loc, + token_indices_to_sample, + num_reqs, ) query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] total_num_tokens = query_start_loc_cpu[-1].item() - token_indices = self.arange[:total_num_tokens] spec_common_attn_metadata = CommonAttentionMetadata( query_start_loc=common_attn_metadata.query_start_loc, @@ -665,16 +660,12 @@ class EagleProposer: max_query_len=new_query_len_per_req.max().item(), max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[token_indices], + slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens], causal=True, dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) - token_indices_to_sample = ( - common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu - ) - - return spec_common_attn_metadata, token_indices, token_indices_to_sample + return spec_common_attn_metadata, token_indices_to_sample def propose_tree( self, diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 1901c6fc9f14f..9d4399d00487a 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.sampling_params import SamplingParams +from vllm.triton_utils import tl, triton _SAMPLING_EPS = 1e-5 @@ -14,3 +15,107 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool: or sampling_params.min_p > _SAMPLING_EPS or sampling_params.logprobs is not None ) + + +@triton.jit +def eagle_prepare_inputs_padded_kernel( + cu_num_draft_tokens_ptr, # [num_reqs] + valid_sampled_tokens_count_ptr, # [num_reqs] + query_start_loc_gpu_ptr, # [num_reqs + 1] + token_indices_to_sample_ptr, # [num_reqs] (output) + num_reqs, # tl.int32 +): + """ + Fused kernel for Eagle prepare_input_padded. This kernel computes the + token index to sample for each request, taking into account the number + of draft tokens and the number of valid sampled tokens (which is one more than + the number of accepted tokens). + """ + req_idx = tl.program_id(axis=0) + if req_idx >= num_reqs: + return + + # Calculate num_draft_tokens from cu_num_draft_tokens, which is an inclusive + # cumulative sum (first entry is the first value, not zero). + cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + req_idx) + + num_draft_tokens = 0 + if req_idx == 0: + num_draft_tokens = cu_draft_curr + else: + cu_draft_prev = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + num_draft_tokens = cu_draft_curr - cu_draft_prev + + valid_count = tl.load(valid_sampled_tokens_count_ptr + req_idx) + num_rejected_tokens = num_draft_tokens + 1 - valid_count + num_rejected_tokens = tl.where(num_draft_tokens > 0, num_rejected_tokens, 0) + + # query_start_loc[req_idx + 1] is the start position of the next request, + # which is one past the last token of this request. + q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + req_idx + 1) - 1 + + index_to_sample = q_last_tok_idx - num_rejected_tokens + tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample) + + +@triton.jit +def eagle_prepare_next_token_padded_kernel( + sampled_token_ids_ptr, # [num_reqs, num_sampled_tokens_per_req] + discard_request_mask_ptr, # [num_reqs] + backup_next_token_ids_ptr, # [num_reqs] + next_token_ids_ptr, # [num_reqs] (output) + valid_sampled_tokens_count_ptr, # [num_reqs] (output) + vocab_size, # tl.int32 + num_sampled_tokens_per_req, # tl.int32 (num_spec_tokens + 1) + num_reqs, # tl.int32 + stride_sampled_token_ids, # tl.int32 (stride for dim 0) + BLOCK_SIZE_TOKENS: tl.constexpr, # Power-of-2 >= num_sampled_tokens_per_req +): + """ + Fused kernel for Eagle prepare_next_token_ids_padded. This kernel computes the + number of valid (1 + accepted) tokens for each request, and the corresponding + "next" token id to sample from during speculative decoding. This is the + "last accepted token" from the sampled tokens, or the backup token if no + tokens were accepted or if the request is marked as discarded. + """ + req_idx = tl.program_id(axis=0) + if req_idx >= num_reqs: + return + + # Check if this request is discarded. + is_discarded = tl.load(discard_request_mask_ptr + req_idx) + + if is_discarded: + backup_token = tl.load(backup_next_token_ids_ptr + req_idx) + valid_count = tl.full((), 0, dtype=tl.uint32) + tl.store(next_token_ids_ptr + req_idx, backup_token) + tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count) + else: + # Count the number of valid tokens among the sampled tokens. + token_offs = tl.arange(0, BLOCK_SIZE_TOKENS) + token_mask = token_offs < num_sampled_tokens_per_req + + row_ptr = sampled_token_ids_ptr + req_idx * stride_sampled_token_ids + token_ids = tl.load(row_ptr + token_offs, mask=token_mask, other=-1) + + # Rejected tokens are -1, valid tokens are in [0, vocab_size) + is_valid_mask = (token_ids != -1) & (token_ids < vocab_size) & token_mask + valid_count = tl.sum(is_valid_mask) + + if valid_count > 0: + # Guaranteed to be well-defined since + # valid_count > 0 implies is_valid_mask is not empty + last_valid_index = tl.max(tl.where(is_valid_mask, token_offs, -1)) + + # Select the token at that index, using a sum trick since + # we don't want to load again to access token_ids[last_valid_index]. + last_valid_token = tl.sum( + tl.where(token_offs == last_valid_index, token_ids, 0) + ) + tl.store(next_token_ids_ptr + req_idx, last_valid_token) + else: + # No valid tokens found, use backup token + backup_token = tl.load(backup_next_token_ids_ptr + req_idx) + tl.store(next_token_ids_ptr + req_idx, backup_token) + + tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6bff83658b45a..9b0fb07297ac3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -488,11 +488,9 @@ class GPUModelRunner( self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False ) self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool) - self.discard_request_indices = self._make_buffer( - self.max_num_reqs, dtype=torch.int64 + self.discard_request_mask = self._make_buffer( + self.max_num_reqs, dtype=torch.bool ) - self.num_discarded_requests = 0 - self.num_decode_draft_tokens = self._make_buffer( self.max_num_reqs, dtype=torch.int32 ) @@ -1369,16 +1367,12 @@ class GPUModelRunner( num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens_np = np.array(num_tokens, dtype=np.int32) - # Record the index of requests that should not be sampled, + # Record which requests should not be sampled, # so that we could clear the sampled tokens before returning - discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np - discard_request_indices = np.nonzero(discard_requests_mask)[0] - self.num_discarded_requests = len(discard_request_indices) - self.discard_request_indices.np[: self.num_discarded_requests] = ( - discard_request_indices + self.discard_request_mask.np[:num_reqs] = ( + self.seq_lens.np[:num_reqs] < num_tokens_np ) - - self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) + self.discard_request_mask.copy_to_gpu(num_reqs) # Copy the tensors to the GPU. self._prepare_input_ids( @@ -2548,9 +2542,10 @@ class GPUModelRunner( if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - discard_sampled_tokens_req_indices = self.discard_request_indices.np[ - : self.num_discarded_requests - ] + num_reqs = self.input_batch.num_reqs + discard_sampled_tokens_req_indices = np.nonzero( + self.discard_request_mask.np[:num_reqs] + )[0] for i in discard_sampled_tokens_req_indices: gen = self.input_batch.generators.get(int(i)) if gen is not None: @@ -3131,8 +3126,7 @@ class GPUModelRunner( sampled_token_ids, self.requests, self.input_batch, - self.discard_request_indices.gpu, - self.num_discarded_requests, + self.discard_request_mask.gpu, ) ) self._copy_valid_sampled_token_count( @@ -3335,8 +3329,7 @@ class GPUModelRunner( sampled_token_ids, self.requests, self.input_batch, - self.discard_request_indices.gpu, - self.num_discarded_requests, + self.discard_request_mask.gpu, ) ) self._copy_valid_sampled_token_count( @@ -3363,24 +3356,34 @@ class GPUModelRunner( sampled_token_ids, spec_decode_metadata.num_draft_tokens, ) + target_token_ids = self.input_ids.gpu[token_indices] + target_positions = self._get_positions(token_indices) + if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], dim=-1 + ) + else: + target_hidden_states = hidden_states[token_indices] else: - common_attn_metadata, token_indices, token_indices_to_sample = ( + common_attn_metadata, token_indices_to_sample = ( self.drafter.prepare_inputs_padded( common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count, ) ) - - target_token_ids = self.input_ids.gpu[token_indices] - target_positions = self._get_positions(token_indices) - if self.use_aux_hidden_state_outputs: - assert aux_hidden_states is not None - target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1 - ) - else: - target_hidden_states = hidden_states[token_indices] + total_num_tokens = common_attn_metadata.num_actual_tokens + # When padding the batch, token_indices is just a range + target_token_ids = self.input_ids.gpu[:total_num_tokens] + target_positions = self._get_positions(total_num_tokens) + if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None + target_hidden_states = torch.cat( + [h[:total_num_tokens] for h in aux_hidden_states], dim=-1 + ) + else: + target_hidden_states = hidden_states[:total_num_tokens] if self.supports_mm_inputs: mm_embed_inputs = self._gather_mm_embeddings(