[Model Runner V2] Implement get_num_sampled_and_rejected kernel (#30029)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit cc050558f4
parent 5c32a06a04
@@ -354,6 +354,55 @@ def combine_sampled_and_draft_tokens(
     return logits_indices
 
 
+@triton.jit
+def _get_num_sampled_and_rejected_kernel(
+    num_sampled_ptr,
+    num_rejected_ptr,
+    seq_lens_ptr,
+    cu_num_logits_ptr,
+    idx_mapping_ptr,
+    prefill_len_ptr,
+):
+    batch_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+
+    seq_len = tl.load(seq_lens_ptr + batch_idx)
+    prefill_len = tl.load(prefill_len_ptr + req_state_idx)
+    is_chunked_prefilling = seq_len < prefill_len
+
+    num_sampled = tl.load(num_sampled_ptr + batch_idx)
+    num_sampled = tl.where(is_chunked_prefilling, 0, num_sampled)
+    tl.store(num_sampled_ptr + batch_idx, num_sampled)
+
+    logits_start = tl.load(cu_num_logits_ptr + batch_idx)
+    logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
+    num_logits = logits_end - logits_start
+
+    num_rejected = num_logits - num_sampled
+    num_rejected = tl.where(is_chunked_prefilling, 0, num_rejected)
+    tl.store(num_rejected_ptr + batch_idx, num_rejected)
+
+
+def get_num_sampled_and_rejected(
+    num_sampled: torch.Tensor,
+    seq_lens: torch.Tensor,
+    cu_num_logits: torch.Tensor,
+    idx_mapping: torch.Tensor,
+    prefill_len: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    num_reqs = idx_mapping.shape[0]
+    num_rejected = torch.empty_like(num_sampled)
+    _get_num_sampled_and_rejected_kernel[(num_reqs,)](
+        num_sampled,
+        num_rejected,
+        seq_lens,
+        cu_num_logits,
+        idx_mapping,
+        prefill_len,
+    )
+    return num_sampled, num_rejected
+
+
 @triton.jit
 def _post_update_kernel(
     idx_mapping_ptr,
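For readability, here is a plain-PyTorch sketch of what the new kernel computes per request. This is not part of the commit: the argument names mirror the wrapper above, but unlike the real wrapper (which launches one program per request and updates num_sampled in place on the GPU), this reference returns fresh tensors.

    import torch

    def get_num_sampled_and_rejected_ref(
        num_sampled: torch.Tensor,     # [num_reqs] tokens sampled per request this step
        seq_lens: torch.Tensor,        # [num_reqs] current sequence length per request
        cu_num_logits: torch.Tensor,   # [num_reqs + 1] cumulative count of logits positions
        idx_mapping: torch.Tensor,     # [num_reqs] batch index -> persistent request-state index
        prefill_len: torch.Tensor,     # [num_states] full prompt length per request state
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # A request is still chunked-prefilling if it has not yet consumed its whole prompt.
        is_chunked_prefilling = seq_lens < prefill_len[idx_mapping.long()]
        # Chunked prefills emit no token this step; other requests keep their count.
        num_sampled = torch.where(is_chunked_prefilling, 0, num_sampled)
        # Each request owns cu_num_logits[i + 1] - cu_num_logits[i] logits positions;
        # the positions that did not yield an accepted token count as rejected.
        num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
        num_rejected = torch.where(is_chunked_prefilling, 0, num_logits - num_sampled)
        return num_sampled, num_rejected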
@@ -43,6 +43,7 @@ from vllm.v1.worker.gpu.input_batch import (
     InputBatch,
     InputBuffers,
     combine_sampled_and_draft_tokens,
+    get_num_sampled_and_rejected,
     post_update,
     prepare_pos_seq_lens,
     prepare_prefill_inputs,
@@ -54,10 +55,7 @@ from vllm.v1.worker.gpu.sample.metadata import (
 )
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
-from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
-    get_num_rejected,
-    rejection_sample,
-)
+from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
 from vllm.v1.worker.gpu.states import RequestState
 from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -621,16 +619,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Sample tokens and compute logprobs (if needed).
         sampler_output = self.sampler(logits, sampling_metadata)
 
-        # Get the number of sampled tokens.
-        prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
-        is_chunked_prefilling = input_batch.seq_lens < prefill_len
         if input_batch.num_draft_tokens == 0:
             # No draft tokens (common case).
-            # 0 if chunked-prefilling, 1 if not.
-            num_sampled = (~is_chunked_prefilling).int()
-            num_rejected = torch.zeros_like(num_sampled)
+            num_sampled = torch.ones(
+                input_batch.num_reqs, dtype=torch.int32, device=self.device
+            )
         else:
-            # Draft tokens for spec decoding.
+            # Rejection sampling for spec decoding.
             input_ids = input_batch.input_ids[input_batch.logits_indices]
             sampled_tokens, num_sampled = rejection_sample(
                 sampler_output.sampled_token_ids,
@@ -638,13 +633,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 input_batch.cu_num_logits,
                 self.num_speculative_steps,
             )
-            num_sampled *= ~is_chunked_prefilling
-            num_rejected = get_num_rejected(
-                input_batch.cu_num_logits,
-                num_sampled,
-            )
             sampler_output.sampled_token_ids = sampled_tokens
-            # TODO(woosuk): Support logprobs with spec decoding.
+
+        # Get the number of sampled and rejected tokens.
+        # For chunked prefills, num_sampled and num_rejected are both 0.
+        num_sampled, num_rejected = get_num_sampled_and_rejected(
+            num_sampled,
+            input_batch.seq_lens,
+            input_batch.cu_num_logits,
+            input_batch.idx_mapping,
+            self.req_states.prefill_len.gpu,
+        )
         return sampler_output, num_sampled, num_rejected
 
     def compute_prompt_logprobs(
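A small worked example with made-up numbers (not from the commit) of what the runner now gets back from this call: request 0 is mid chunked-prefill, request 1 is a plain decode, and request 2 verified four draft tokens of which two were accepted (plus the bonus token).

    import torch

    # Hypothetical batch of three requests (illustrative values only).
    num_sampled   = torch.tensor([1, 1, 3], dtype=torch.int32)   # raw counts from sampling
    seq_lens      = torch.tensor([8, 20, 24], dtype=torch.int32)
    prefill_len   = torch.tensor([16, 12, 10], dtype=torch.int32) # per request state
    idx_mapping   = torch.tensor([0, 1, 2])
    cu_num_logits = torch.tensor([0, 1, 2, 7], dtype=torch.int32) # req 2 has 1 + 4 draft logits

    is_chunked = seq_lens < prefill_len[idx_mapping]              # [True, False, False]
    num_sampled = torch.where(is_chunked, 0, num_sampled)         # [0, 1, 3]
    num_logits = cu_num_logits[1:] - cu_num_logits[:-1]           # [1, 1, 5]
    num_rejected = torch.where(is_chunked, 0, num_logits - num_sampled)  # [0, 0, 2]

So the chunked-prefill request reports zero sampled and zero rejected tokens, the plain decode reports one sampled and zero rejected, and the speculative request reports three sampled and two rejected positions.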
@@ -69,15 +69,3 @@ def rejection_sample(
         num_warps=1,
     )
     return sampled, num_sampled
-
-
-@torch.compile(dynamic=True)
-def get_num_rejected(
-    cu_num_logits: torch.Tensor,
-    num_sampled: torch.Tensor,
-) -> torch.Tensor:
-    num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
-    num_rejected = num_logits - num_sampled
-    # No token is rejected for chunked prefills.
-    num_rejected *= num_sampled > 0
-    return num_rejected
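The torch.compile'd helper removed here, together with the explicit chunked-prefill masking that previously lived in the model runner, is now handled by the new _get_num_sampled_and_rejected_kernel earlier in this diff. A quick sanity sketch (my own, not part of the commit) showing the two formulations of the rejected count agree when num_sampled is already zeroed for chunked prefills:

    import torch

    def get_num_rejected_old(cu_num_logits, num_sampled):
        # Formulation being removed in this commit.
        num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
        num_rejected = num_logits - num_sampled
        # No token is rejected for chunked prefills (num_sampled == 0 there).
        num_rejected *= num_sampled > 0
        return num_rejected

    def get_num_rejected_new(cu_num_logits, num_sampled, is_chunked_prefilling):
        # Same arithmetic the new Triton kernel performs per batch element.
        num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
        return torch.where(is_chunked_prefilling, 0, num_logits - num_sampled)

    cu_num_logits = torch.tensor([0, 1, 2, 7], dtype=torch.int32)
    num_sampled = torch.tensor([0, 1, 3], dtype=torch.int32)  # request 0 is a chunked prefill
    is_chunked = num_sampled == 0  # holds since non-chunked requests sample at least one token

    old = get_num_rejected_old(cu_num_logits, num_sampled)
    new = get_num_rejected_new(cu_num_logits, num_sampled, is_chunked)
    assert (old == new).all()  # both: tensor([0, 0, 2])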