From f32c7d6f5455de2684686c7238f9c7ecca6b58b7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Nov 2025 13:54:59 -0800 Subject: [PATCH] [Model Runner V2] Simplify Eagle bookkeeping with num_rejected (#29347) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/input_batch.py | 19 ++++-------- vllm/v1/worker/gpu/model_runner.py | 30 ++++++++++++++----- vllm/v1/worker/gpu/spec_decode/eagle.py | 19 ++++++------ .../gpu/spec_decode/rejection_sample.py | 12 ++++++++ 4 files changed, 50 insertions(+), 30 deletions(-) diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index 3ac43ea4952de..2a7048ae3c0e0 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -344,8 +344,8 @@ def _post_update_kernel( sampled_tokens_ptr, sampled_tokens_stride, num_sampled_ptr, + num_rejected_ptr, query_start_loc_ptr, - cu_num_logits_ptr, ): req_id = tl.program_id(0) req_state_idx = tl.load(idx_mapping_ptr + req_id) @@ -360,17 +360,10 @@ def _post_update_kernel( query_start = tl.load(query_start_loc_ptr + req_id) query_end = tl.load(query_start_loc_ptr + req_id + 1) query_len = query_end - query_start + num_rejected = tl.load(num_rejected_ptr + req_id) num_computed = tl.load(num_computed_tokens_ptr + req_state_idx) - num_computed += query_len - # Consider the rejected tokens in spec decoding. - if num_sampled > 0: - # NOTE(woosuk): We must skip num_sampled == 0 to account for chunked prefills. - logits_start = tl.load(cu_num_logits_ptr + req_id) - logits_end = tl.load(cu_num_logits_ptr + req_id + 1) - num_logits = logits_end - logits_start - num_rejected = num_logits - num_sampled - num_computed -= num_rejected + num_computed += query_len - num_rejected tl.store(num_computed_tokens_ptr + req_state_idx, num_computed) @@ -385,10 +378,10 @@ def post_update( sampled_tokens: torch.Tensor, # [num_reqs] num_sampled: torch.Tensor, + # [num_reqs] + num_rejected: torch.Tensor, # [num_reqs + 1] query_start_loc: torch.Tensor, - # [num_reqs + 1] - cu_num_logits: torch.Tensor, ) -> None: num_reqs = idx_mapping.shape[0] _post_update_kernel[(num_reqs,)]( @@ -398,7 +391,7 @@ def post_update( sampled_tokens, sampled_tokens.stride(0), num_sampled, + num_rejected, query_start_loc, - cu_num_logits, num_warps=1, ) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index e0ed183d3c5b0..e34a45f979807 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -46,7 +46,10 @@ from vllm.v1.worker.gpu.input_batch import ( ) from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs from vllm.v1.worker.gpu.spec_decode import init_speculator -from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample +from vllm.v1.worker.gpu.spec_decode.rejection_sample import ( + get_num_rejected, + rejection_sample, +) from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin @@ -311,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): device=self.device, ) num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device) + num_rejected = torch.zeros(num_reqs, dtype=torch.int32, device=self.device) self.propose_draft( input_batch=input_batch, sampling_metadata=sampling_metadata, last_hidden_states=hidden_states, aux_hidden_states=aux_hidden_states, num_sampled=num_sampled, + 
num_rejected=num_rejected, ) @torch.inference_mode() @@ -606,7 +611,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): input_batch: InputBatch, sampling_metadata: SamplingMetadata, grammar_output: GrammarOutput | None, - ) -> tuple[SamplerOutput, torch.Tensor]: + ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]: sample_hidden_states = hidden_states[input_batch.logits_indices] logits = self.model.compute_logits(sample_hidden_states) if grammar_output is not None: @@ -632,6 +637,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # No draft tokens (common case). # 0 if chunked-prefilling, 1 if not. num_sampled = (~is_chunked_prefilling).int() + num_rejected = torch.zeros_like(num_sampled) else: # Draft tokens for spec decoding. input_ids = input_batch.input_ids[input_batch.logits_indices] @@ -642,9 +648,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.num_speculative_steps, ) num_sampled *= ~is_chunked_prefilling + num_rejected = get_num_rejected( + input_batch.cu_num_logits, + num_sampled, + ) sampler_output.sampled_token_ids = sampled_tokens # TODO(woosuk): Support logprobs with spec decoding. - return sampler_output, num_sampled + return sampler_output, num_sampled, num_rejected def compute_prompt_logprobs( self, @@ -750,6 +760,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): input_batch: InputBatch, sampled_tokens: torch.Tensor, num_sampled: torch.Tensor, + num_rejected: torch.Tensor, ) -> None: # Update the number of computed tokens. post_update( @@ -758,8 +769,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.req_states.last_sampled_tokens, sampled_tokens, num_sampled, + num_rejected, input_batch.query_start_loc, - input_batch.cu_num_logits, ) # Update the number of computed prefill tokens. @@ -779,6 +790,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): last_hidden_states: torch.Tensor, aux_hidden_states: list[torch.Tensor] | None, num_sampled: torch.Tensor, + num_rejected: torch.Tensor, ) -> torch.Tensor: num_reqs = input_batch.num_reqs idx_mapping_np = input_batch.idx_mapping_np @@ -800,6 +812,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): last_hidden_states, aux_hidden_states, num_sampled, + num_rejected, self.req_states.last_sampled_tokens, next_prefill_tokens, ) @@ -958,7 +971,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.execute_model_state = None # type: ignore assert sampling_metadata is not None - sampler_output, num_sampled_tokens = self.sample( + sampler_output, num_sampled, num_rejected = self.sample( hidden_states, input_batch, sampling_metadata, grammar_output ) prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch) @@ -979,7 +992,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): async_output = AsyncOutput( model_runner_output=model_runner_output, sampler_output=sampler_output, - num_sampled_tokens=num_sampled_tokens, + num_sampled_tokens=num_sampled, copy_stream=self.output_copy_stream, copy_event=self.output_copy_event, ) @@ -990,7 +1003,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # This sequencing may slightly reduce latency as async D2H copy does not # need to wait for the postprocess to finish. 
self.postprocess( - input_batch, sampler_output.sampled_token_ids, num_sampled_tokens + input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected ) if self.do_spec_decode: _ = self.propose_draft( @@ -998,7 +1011,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sampling_metadata, hidden_states, None, # aux_hidden_states - num_sampled_tokens, + num_sampled, + num_rejected, ) if self.use_async_scheduling: diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py index 59d0f313d96a2..3c8621cc69c97 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle.py @@ -60,6 +60,8 @@ class EagleSpeculator: aux_hidden_states: list[torch.Tensor] | None, # [num_reqs] num_sampled: torch.Tensor, + # [num_reqs] + num_rejected: torch.Tensor, # [max_num_reqs, 1] last_sampled: torch.Tensor, # [num_reqs] @@ -84,6 +86,7 @@ class EagleSpeculator: self.input_ids, input_batch, num_sampled, + num_rejected, last_sampled, next_prefill_tokens, ) @@ -139,8 +142,8 @@ def _prepare_eagle_inputs_kernel( last_sampled_ptr, next_prefill_tokens_ptr, num_sampled_ptr, + num_rejected_ptr, query_start_loc_ptr, - cu_num_logits_ptr, BLOCK_SIZE: tl.constexpr, ): batch_idx = tl.program_id(0) @@ -149,17 +152,13 @@ def _prepare_eagle_inputs_kernel( query_len = query_end - query_start # Get the true query length and next token after accounting for rejected tokens. + num_rejected = tl.load(num_rejected_ptr + batch_idx) + query_len -= num_rejected + num_sampled = tl.load(num_sampled_ptr + batch_idx) if num_sampled > 0: req_state_idx = tl.load(idx_mapping_ptr + batch_idx) next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32) - - logits_start = tl.load(cu_num_logits_ptr + batch_idx) - logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1) - num_logits = logits_end - logits_start - - num_rejected = num_logits - num_sampled - query_len -= num_rejected else: # Chunked prefilling. # Get the next prefill token. @@ -182,6 +181,8 @@ def prepare_eagle_inputs( input_batch: InputBatch, # [num_reqs] num_sampled: torch.Tensor, + # [num_reqs] + num_rejected: torch.Tensor, # [max_num_reqs, 1] last_sampled: torch.Tensor, # [max_num_reqs] @@ -201,8 +202,8 @@ def prepare_eagle_inputs( last_sampled, next_prefill_tokens, num_sampled, + num_rejected, input_batch.query_start_loc, - input_batch.cu_num_logits, BLOCK_SIZE=1024, ) return last_token_indices diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py index 8a7bf28bacbd4..43c6ac518bccc 100644 --- a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py +++ b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py @@ -69,3 +69,15 @@ def rejection_sample( num_warps=1, ) return sampled, num_sampled + + +@torch.compile(dynamic=True) +def get_num_rejected( + cu_num_logits: torch.Tensor, + num_sampled: torch.Tensor, +) -> torch.Tensor: + num_logits = cu_num_logits[1:] - cu_num_logits[:-1] + num_rejected = num_logits - num_sampled + # No token is rejected for chunked prefills. + num_rejected *= num_sampled > 0 + return num_rejected
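
Note (not part of the patch): the new `get_num_rejected` helper centralizes the arithmetic that `_post_update_kernel` and `_prepare_eagle_inputs_kernel` previously repeated inline with `cu_num_logits`. Below is a minimal, self-contained PyTorch sketch of that arithmetic; the tensor names mirror the patch, but the concrete values are hypothetical and for illustration only.

```python
import torch

# Hypothetical example: 3 requests. The first two each have 3 draft tokens
# (4 logits positions); the last is a chunked prefill (1 logits position,
# nothing sampled this step).
cu_num_logits = torch.tensor([0, 4, 8, 9], dtype=torch.int32)  # prefix sums of logits per request
num_sampled = torch.tensor([2, 4, 0], dtype=torch.int32)       # accepted drafts + bonus token; 0 => chunked prefill

# Same arithmetic as get_num_rejected() in rejection_sample.py (without torch.compile):
num_logits = cu_num_logits[1:] - cu_num_logits[:-1]   # logits positions per request: [4, 4, 1]
num_rejected = num_logits - num_sampled                # [2, 0, 1]
num_rejected *= num_sampled > 0                        # chunked prefills reject nothing: [2, 0, 0]

print(num_rejected.tolist())  # [2, 0, 0]
```

Computing `num_rejected` once per step and threading it through `post_update` and `prepare_eagle_inputs` is what lets both Triton kernels drop their `cu_num_logits` lookups and simplifies the `num_sampled > 0` special-casing when updating the number of computed tokens.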