From cc050558f424714f9548774cc2c661b3916d96ca Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 4 Dec 2025 07:19:42 -0800
Subject: [PATCH] [Model Runner V2] Implement get_num_sampled_and_rejected kernel (#30029)

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu/input_batch.py             | 49 +++++++++++++++++++
 vllm/v1/worker/gpu/model_runner.py            | 33 ++++++-------
 .../gpu/spec_decode/rejection_sample.py       | 12 -----
 3 files changed, 65 insertions(+), 29 deletions(-)

diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index 8ae887fe82cf..1b78734fba78 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -354,6 +354,55 @@ def combine_sampled_and_draft_tokens(
     return logits_indices
 
 
+@triton.jit
+def _get_num_sampled_and_rejected_kernel(
+    num_sampled_ptr,
+    num_rejected_ptr,
+    seq_lens_ptr,
+    cu_num_logits_ptr,
+    idx_mapping_ptr,
+    prefill_len_ptr,
+):
+    batch_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+
+    seq_len = tl.load(seq_lens_ptr + batch_idx)
+    prefill_len = tl.load(prefill_len_ptr + req_state_idx)
+    is_chunked_prefilling = seq_len < prefill_len
+
+    num_sampled = tl.load(num_sampled_ptr + batch_idx)
+    num_sampled = tl.where(is_chunked_prefilling, 0, num_sampled)
+    tl.store(num_sampled_ptr + batch_idx, num_sampled)
+
+    logits_start = tl.load(cu_num_logits_ptr + batch_idx)
+    logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
+    num_logits = logits_end - logits_start
+
+    num_rejected = num_logits - num_sampled
+    num_rejected = tl.where(is_chunked_prefilling, 0, num_rejected)
+    tl.store(num_rejected_ptr + batch_idx, num_rejected)
+
+
+def get_num_sampled_and_rejected(
+    num_sampled: torch.Tensor,
+    seq_lens: torch.Tensor,
+    cu_num_logits: torch.Tensor,
+    idx_mapping: torch.Tensor,
+    prefill_len: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    num_reqs = idx_mapping.shape[0]
+    num_rejected = torch.empty_like(num_sampled)
+    _get_num_sampled_and_rejected_kernel[(num_reqs,)](
+        num_sampled,
+        num_rejected,
+        seq_lens,
+        cu_num_logits,
+        idx_mapping,
+        prefill_len,
+    )
+    return num_sampled, num_rejected
+
+
 @triton.jit
 def _post_update_kernel(
     idx_mapping_ptr,
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 9bf345053c30..464f7b7bd353 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -43,6 +43,7 @@ from vllm.v1.worker.gpu.input_batch import (
     InputBatch,
     InputBuffers,
     combine_sampled_and_draft_tokens,
+    get_num_sampled_and_rejected,
     post_update,
     prepare_pos_seq_lens,
     prepare_prefill_inputs,
@@ -54,10 +55,7 @@ from vllm.v1.worker.gpu.sample.metadata import (
 )
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
-from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
-    get_num_rejected,
-    rejection_sample,
-)
+from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
 from vllm.v1.worker.gpu.states import RequestState
 from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -621,16 +619,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Sample tokens and compute logprobs (if needed).
         sampler_output = self.sampler(logits, sampling_metadata)
-        # Get the number of sampled tokens.
-        prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
-        is_chunked_prefilling = input_batch.seq_lens < prefill_len
         if input_batch.num_draft_tokens == 0:
             # No draft tokens (common case).
-            # 0 if chunked-prefilling, 1 if not.
-            num_sampled = (~is_chunked_prefilling).int()
-            num_rejected = torch.zeros_like(num_sampled)
+            num_sampled = torch.ones(
+                input_batch.num_reqs, dtype=torch.int32, device=self.device
+            )
         else:
-            # Draft tokens for spec decoding.
+            # Rejection sampling for spec decoding.
             input_ids = input_batch.input_ids[input_batch.logits_indices]
             sampled_tokens, num_sampled = rejection_sample(
                 sampler_output.sampled_token_ids,
@@ -638,13 +633,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 input_batch.cu_num_logits,
                 self.num_speculative_steps,
             )
-            num_sampled *= ~is_chunked_prefilling
-            num_rejected = get_num_rejected(
-                input_batch.cu_num_logits,
-                num_sampled,
-            )
             sampler_output.sampled_token_ids = sampled_tokens
-        # TODO(woosuk): Support logprobs with spec decoding.
+
+        # Get the number of sampled and rejected tokens.
+        # For chunked prefills, num_sampled and num_rejected are both 0.
+        num_sampled, num_rejected = get_num_sampled_and_rejected(
+            num_sampled,
+            input_batch.seq_lens,
+            input_batch.cu_num_logits,
+            input_batch.idx_mapping,
+            self.req_states.prefill_len.gpu,
+        )
         return sampler_output, num_sampled, num_rejected
 
     def compute_prompt_logprobs(
diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py
index 43c6ac518bcc..8a7bf28bacbd 100644
--- a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py
+++ b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py
@@ -69,15 +69,3 @@ def rejection_sample(
         num_warps=1,
     )
     return sampled, num_sampled
-
-
-@torch.compile(dynamic=True)
-def get_num_rejected(
-    cu_num_logits: torch.Tensor,
-    num_sampled: torch.Tensor,
-) -> torch.Tensor:
-    num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
-    num_rejected = num_logits - num_sampled
-    # No token is rejected for chunked prefills.
-    num_rejected *= num_sampled > 0
-    return num_rejected
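---

For reference, the new kernel launches one program per request and only rewrites two small int32 tensors. Below is a minimal pure-PyTorch sketch of the same per-request logic, for illustration only and not part of the patch: the `_ref` function name is made up here, the shape comments are assumptions inferred from how the kernel indexes its arguments, and unlike the kernel (which updates num_sampled in place) the sketch returns fresh tensors.

import torch


def get_num_sampled_and_rejected_ref(
    num_sampled: torch.Tensor,     # assumed [num_reqs], int32
    seq_lens: torch.Tensor,        # assumed [num_reqs], tokens computed so far
    cu_num_logits: torch.Tensor,   # assumed [num_reqs + 1], cumulative logits count
    idx_mapping: torch.Tensor,     # assumed [num_reqs], batch index -> request-state index
    prefill_len: torch.Tensor,     # assumed [max_num_reqs], prompt length per request state
) -> tuple[torch.Tensor, torch.Tensor]:
    # A request is still chunked-prefilling when its sequence has not yet
    # covered the full prompt; such requests sample and reject nothing.
    is_chunked_prefilling = seq_lens < prefill_len[idx_mapping.long()]
    zero = torch.zeros_like(num_sampled)
    num_sampled = torch.where(is_chunked_prefilling, zero, num_sampled)
    # Each request contributed num_logits candidate positions; every position
    # beyond the accepted tokens counts as rejected.
    num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
    num_rejected = torch.where(is_chunked_prefilling, zero, num_logits - num_sampled)
    return num_sampled, num_rejected

Fusing this into a single Triton kernel presumably lets the patch drop the separate torch.compile'd get_num_rejected helper and the extra elementwise ops that previously ran on the hot path.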