[Model Runner V2] Implement get_num_sampled_and_rejected kernel (#30029)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-12-04 07:19:42 -08:00 committed by GitHub
parent 5c32a06a04
commit cc050558f4
3 changed files with 65 additions and 29 deletions

vllm/v1/worker/gpu/input_batch.py

@@ -354,6 +354,55 @@ def combine_sampled_and_draft_tokens(
     return logits_indices


+@triton.jit
+def _get_num_sampled_and_rejected_kernel(
+    num_sampled_ptr,
+    num_rejected_ptr,
+    seq_lens_ptr,
+    cu_num_logits_ptr,
+    idx_mapping_ptr,
+    prefill_len_ptr,
+):
+    batch_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+
+    seq_len = tl.load(seq_lens_ptr + batch_idx)
+    prefill_len = tl.load(prefill_len_ptr + req_state_idx)
+    is_chunked_prefilling = seq_len < prefill_len
+
+    num_sampled = tl.load(num_sampled_ptr + batch_idx)
+    num_sampled = tl.where(is_chunked_prefilling, 0, num_sampled)
+    tl.store(num_sampled_ptr + batch_idx, num_sampled)
+
+    logits_start = tl.load(cu_num_logits_ptr + batch_idx)
+    logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
+    num_logits = logits_end - logits_start
+    num_rejected = num_logits - num_sampled
+    num_rejected = tl.where(is_chunked_prefilling, 0, num_rejected)
+    tl.store(num_rejected_ptr + batch_idx, num_rejected)
+
+
+def get_num_sampled_and_rejected(
+    num_sampled: torch.Tensor,
+    seq_lens: torch.Tensor,
+    cu_num_logits: torch.Tensor,
+    idx_mapping: torch.Tensor,
+    prefill_len: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    num_reqs = idx_mapping.shape[0]
+    num_rejected = torch.empty_like(num_sampled)
+    _get_num_sampled_and_rejected_kernel[(num_reqs,)](
+        num_sampled,
+        num_rejected,
+        seq_lens,
+        cu_num_logits,
+        idx_mapping,
+        prefill_len,
+    )
+    return num_sampled, num_rejected
+
+
 @triton.jit
 def _post_update_kernel(
     idx_mapping_ptr,
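
The kernel above runs one Triton program per scheduled request and zeroes both counts for requests that are still chunked-prefilling (seq_len < prefill_len). As a host-side mental model, the sketch below reproduces the same arithmetic with plain PyTorch ops; the helper name _reference_num_sampled_and_rejected is hypothetical and not part of this commit, and unlike the kernel it does not update num_sampled in place.

import torch


def _reference_num_sampled_and_rejected(
    num_sampled: torch.Tensor,    # [num_reqs] tokens accepted for each request
    seq_lens: torch.Tensor,       # [num_reqs] current sequence lengths
    cu_num_logits: torch.Tensor,  # [num_reqs + 1] cumulative logit counts
    idx_mapping: torch.Tensor,    # [num_reqs] batch index -> request state index
    prefill_len: torch.Tensor,    # [max_num_reqs] full prompt length per request state
) -> tuple[torch.Tensor, torch.Tensor]:
    # A request is still chunked-prefilling if it has not consumed its whole prompt.
    is_chunked_prefilling = seq_lens < prefill_len[idx_mapping]
    zeros = torch.zeros_like(num_sampled)
    # No token is sampled for chunked prefills.
    num_sampled = torch.where(is_chunked_prefilling, zeros, num_sampled)
    # Positions that produced logits but were not accepted count as rejected.
    num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
    num_rejected = torch.where(is_chunked_prefilling, zeros, num_logits - num_sampled)
    return num_sampled, num_rejected

Compared with the Python-side masking removed from the model runner below, the Triton kernel fuses the gather over idx_mapping and both maskings into a single launch of num_reqs programs.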

vllm/v1/worker/gpu/model_runner.py

@@ -43,6 +43,7 @@ from vllm.v1.worker.gpu.input_batch import (
     InputBatch,
     InputBuffers,
     combine_sampled_and_draft_tokens,
+    get_num_sampled_and_rejected,
     post_update,
     prepare_pos_seq_lens,
     prepare_prefill_inputs,
@@ -54,10 +55,7 @@ from vllm.v1.worker.gpu.sample.metadata import (
 )
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
-from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
-    get_num_rejected,
-    rejection_sample,
-)
+from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
 from vllm.v1.worker.gpu.states import RequestState
 from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -621,16 +619,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Sample tokens and compute logprobs (if needed).
         sampler_output = self.sampler(logits, sampling_metadata)

-        # Get the number of sampled tokens.
-        prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
-        is_chunked_prefilling = input_batch.seq_lens < prefill_len
         if input_batch.num_draft_tokens == 0:
             # No draft tokens (common case).
-            # 0 if chunked-prefilling, 1 if not.
-            num_sampled = (~is_chunked_prefilling).int()
-            num_rejected = torch.zeros_like(num_sampled)
+            num_sampled = torch.ones(
+                input_batch.num_reqs, dtype=torch.int32, device=self.device
+            )
         else:
-            # Draft tokens for spec decoding.
+            # Rejection sampling for spec decoding.
             input_ids = input_batch.input_ids[input_batch.logits_indices]
             sampled_tokens, num_sampled = rejection_sample(
                 sampler_output.sampled_token_ids,
@@ -638,13 +633,17 @@
                 input_batch.cu_num_logits,
                 self.num_speculative_steps,
             )
-            num_sampled *= ~is_chunked_prefilling
-            num_rejected = get_num_rejected(
-                input_batch.cu_num_logits,
-                num_sampled,
-            )
             sampler_output.sampled_token_ids = sampled_tokens
             # TODO(woosuk): Support logprobs with spec decoding.
+        # Get the number of sampled and rejected tokens.
+        # For chunked prefills, num_sampled and num_rejected are both 0.
+        num_sampled, num_rejected = get_num_sampled_and_rejected(
+            num_sampled,
+            input_batch.seq_lens,
+            input_batch.cu_num_logits,
+            input_batch.idx_mapping,
+            self.req_states.prefill_len.gpu,
+        )
         return sampler_output, num_sampled, num_rejected

     def compute_prompt_logprobs(
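
To make the new call concrete, here is a small made-up example; the numbers are illustrative only and not from the commit, and in practice all tensors live on the GPU because the helper launches a Triton kernel.

import torch

# Hypothetical batch of 3 requests:
#   req 0: decode with 3 draft tokens, 2 drafts accepted -> rejection_sample gives 3
#   req 1: decode with 3 draft tokens, 0 drafts accepted -> rejection_sample gives 1
#   req 2: chunked prefill (seq_len < prefill_len)       -> must be forced to 0
cu_num_logits = torch.tensor([0, 4, 8, 9], dtype=torch.int32)  # 4, 4, 1 logits per req
num_sampled = torch.tensor([3, 1, 1], dtype=torch.int32)       # before the kernel
seq_lens = torch.tensor([100, 57, 16], dtype=torch.int32)
prefill_len = torch.tensor([90, 50, 48], dtype=torch.int32)
idx_mapping = torch.tensor([0, 1, 2], dtype=torch.int32)

# get_num_sampled_and_rejected(num_sampled, seq_lens, cu_num_logits,
#                              idx_mapping, prefill_len) would then return:
#   num_sampled  == [3, 1, 0]
#   num_rejected == [1, 3, 0]   # num_logits - num_sampled, zeroed for the prefill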

vllm/v1/worker/gpu/spec_decode/rejection_sample.py

@@ -69,15 +69,3 @@ def rejection_sample(
         num_warps=1,
     )
     return sampled, num_sampled
-
-
-@torch.compile(dynamic=True)
-def get_num_rejected(
-    cu_num_logits: torch.Tensor,
-    num_sampled: torch.Tensor,
-) -> torch.Tensor:
-    num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
-    num_rejected = num_logits - num_sampled
-    # No token is rejected for chunked prefills.
-    num_rejected *= num_sampled > 0
-    return num_rejected
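
The deleted helper computed the same quantity for requests that were not chunked-prefilling: num_sampled had already been zeroed upstream for chunked prefills, so the num_sampled > 0 factor doubled as the prefill mask. A quick check with the made-up numbers from the example above (again illustrative only, not from the commit):

import torch

cu_num_logits = torch.tensor([0, 4, 8, 9], dtype=torch.int32)
num_sampled = torch.tensor([3, 1, 0], dtype=torch.int32)  # already masked upstream
num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
num_rejected = (num_logits - num_sampled) * (num_sampled > 0)
print(num_rejected)  # tensor([1, 3, 0], dtype=torch.int32) -- matches the new kernel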