[Model Runner V2] Simplify Eagle bookkeeping with num_rejected (#29347)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 3cfa63ad99
commit f32c7d6f54
@@ -344,8 +344,8 @@ def _post_update_kernel(
     sampled_tokens_ptr,
     sampled_tokens_stride,
     num_sampled_ptr,
+    num_rejected_ptr,
     query_start_loc_ptr,
-    cu_num_logits_ptr,
 ):
     req_id = tl.program_id(0)
     req_state_idx = tl.load(idx_mapping_ptr + req_id)
@@ -360,17 +360,10 @@ def _post_update_kernel(
     query_start = tl.load(query_start_loc_ptr + req_id)
     query_end = tl.load(query_start_loc_ptr + req_id + 1)
     query_len = query_end - query_start
+    num_rejected = tl.load(num_rejected_ptr + req_id)

     num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
-    num_computed += query_len
-    # Consider the rejected tokens in spec decoding.
-    if num_sampled > 0:
-        # NOTE(woosuk): We must skip num_sampled == 0 to account for chunked prefills.
-        logits_start = tl.load(cu_num_logits_ptr + req_id)
-        logits_end = tl.load(cu_num_logits_ptr + req_id + 1)
-        num_logits = logits_end - logits_start
-        num_rejected = num_logits - num_sampled
-        num_computed -= num_rejected
+    num_computed += query_len - num_rejected
     tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)


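For reference, the bookkeeping the simplified kernel performs per request can be sketched in plain PyTorch. This is a minimal CPU sketch of the arithmetic above, not the Triton kernel from the commit; the helper name, toy shapes, and values are hypothetical.

import torch

def post_update_reference(
    idx_mapping: torch.Tensor,          # [num_reqs] -> request state slot
    num_computed_tokens: torch.Tensor,  # [max_num_reqs], updated in place
    query_start_loc: torch.Tensor,      # [num_reqs + 1], cumulative query lengths
    num_rejected: torch.Tensor,         # [num_reqs], rejected draft tokens per request
) -> None:
    # Same bookkeeping as the kernel: each request advances by the tokens it
    # actually processed, minus the draft tokens that were rejected.
    query_len = query_start_loc[1:] - query_start_loc[:-1]
    num_computed_tokens[idx_mapping] += query_len - num_rejected

# Hypothetical batch: two requests occupying state slots 3 and 0.
idx_mapping = torch.tensor([3, 0])
num_computed_tokens = torch.zeros(8, dtype=torch.int64)
query_start_loc = torch.tensor([0, 4, 9])   # query lengths 4 and 5
num_rejected = torch.tensor([1, 0])         # one rejected draft token in the first request
post_update_reference(idx_mapping, num_computed_tokens, query_start_loc, num_rejected)
print(num_computed_tokens[3].item(), num_computed_tokens[0].item())  # 3 5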
@@ -385,10 +378,10 @@ def post_update(
     sampled_tokens: torch.Tensor,
     # [num_reqs]
     num_sampled: torch.Tensor,
+    # [num_reqs]
+    num_rejected: torch.Tensor,
     # [num_reqs + 1]
     query_start_loc: torch.Tensor,
-    # [num_reqs + 1]
-    cu_num_logits: torch.Tensor,
 ) -> None:
     num_reqs = idx_mapping.shape[0]
     _post_update_kernel[(num_reqs,)](
@@ -398,7 +391,7 @@ def post_update(
         sampled_tokens,
         sampled_tokens.stride(0),
         num_sampled,
+        num_rejected,
         query_start_loc,
-        cu_num_logits,
         num_warps=1,
     )

@@ -46,7 +46,10 @@ from vllm.v1.worker.gpu.input_batch import (
 )
 from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
 from vllm.v1.worker.gpu.spec_decode import init_speculator
-from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
+from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
+    get_num_rejected,
+    rejection_sample,
+)
 from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
 from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -311,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 device=self.device,
             )
             num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
+            num_rejected = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
             self.propose_draft(
                 input_batch=input_batch,
                 sampling_metadata=sampling_metadata,
                 last_hidden_states=hidden_states,
                 aux_hidden_states=aux_hidden_states,
                 num_sampled=num_sampled,
+                num_rejected=num_rejected,
             )

     @torch.inference_mode()
@@ -606,7 +611,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         input_batch: InputBatch,
         sampling_metadata: SamplingMetadata,
         grammar_output: GrammarOutput | None,
-    ) -> tuple[SamplerOutput, torch.Tensor]:
+    ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
         sample_hidden_states = hidden_states[input_batch.logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
         if grammar_output is not None:
@@ -632,6 +637,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # No draft tokens (common case).
             # 0 if chunked-prefilling, 1 if not.
             num_sampled = (~is_chunked_prefilling).int()
+            num_rejected = torch.zeros_like(num_sampled)
         else:
             # Draft tokens for spec decoding.
             input_ids = input_batch.input_ids[input_batch.logits_indices]
@@ -642,9 +648,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.num_speculative_steps,
             )
             num_sampled *= ~is_chunked_prefilling
+            num_rejected = get_num_rejected(
+                input_batch.cu_num_logits,
+                num_sampled,
+            )
             sampler_output.sampled_token_ids = sampled_tokens
         # TODO(woosuk): Support logprobs with spec decoding.
-        return sampler_output, num_sampled
+        return sampler_output, num_sampled, num_rejected

     def compute_prompt_logprobs(
         self,
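As a hedged illustration of the two branches above (the tensors and accepted-token counts below are made up, not taken from the commit): the no-draft branch samples exactly one token per request unless the request is an in-progress chunked prefill, and the draft branch masks the rejection-sampled counts the same way before num_rejected is derived by get_num_rejected (exercised in the example after the final hunk below).

import torch

# Hypothetical three-request batch; the middle request is an in-progress chunked prefill.
is_chunked_prefilling = torch.tensor([False, True, False])

# No-draft branch: 1 token sampled per request, 0 for chunked prefills, nothing rejected.
num_sampled = (~is_chunked_prefilling).int()
num_rejected = torch.zeros_like(num_sampled)
print(num_sampled.tolist(), num_rejected.tolist())  # [1, 0, 1] [0, 0, 0]

# Draft branch: hypothetical accepted-token counts from rejection sampling,
# zeroed for chunked prefills before get_num_rejected is called.
num_sampled = torch.tensor([3, 2, 1], dtype=torch.int32)
num_sampled *= ~is_chunked_prefilling
print(num_sampled.tolist())  # [3, 0, 1]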
@@ -750,6 +760,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         input_batch: InputBatch,
         sampled_tokens: torch.Tensor,
         num_sampled: torch.Tensor,
+        num_rejected: torch.Tensor,
     ) -> None:
         # Update the number of computed tokens.
         post_update(
@@ -758,8 +769,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.req_states.last_sampled_tokens,
             sampled_tokens,
             num_sampled,
+            num_rejected,
             input_batch.query_start_loc,
-            input_batch.cu_num_logits,
         )

         # Update the number of computed prefill tokens.
@@ -779,6 +790,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         last_hidden_states: torch.Tensor,
         aux_hidden_states: list[torch.Tensor] | None,
         num_sampled: torch.Tensor,
+        num_rejected: torch.Tensor,
     ) -> torch.Tensor:
         num_reqs = input_batch.num_reqs
         idx_mapping_np = input_batch.idx_mapping_np
@@ -800,6 +812,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             last_hidden_states,
             aux_hidden_states,
             num_sampled,
+            num_rejected,
             self.req_states.last_sampled_tokens,
             next_prefill_tokens,
         )
@@ -958,7 +971,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.execute_model_state = None  # type: ignore
         assert sampling_metadata is not None

-        sampler_output, num_sampled_tokens = self.sample(
+        sampler_output, num_sampled, num_rejected = self.sample(
             hidden_states, input_batch, sampling_metadata, grammar_output
         )
         prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
@@ -979,7 +992,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         async_output = AsyncOutput(
             model_runner_output=model_runner_output,
             sampler_output=sampler_output,
-            num_sampled_tokens=num_sampled_tokens,
+            num_sampled_tokens=num_sampled,
             copy_stream=self.output_copy_stream,
             copy_event=self.output_copy_event,
         )
@@ -990,7 +1003,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # This sequencing may slightly reduce latency as async D2H copy does not
         # need to wait for the postprocess to finish.
         self.postprocess(
-            input_batch, sampler_output.sampled_token_ids, num_sampled_tokens
+            input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
         )
         if self.do_spec_decode:
             _ = self.propose_draft(
@@ -998,7 +1011,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 sampling_metadata,
                 hidden_states,
                 None,  # aux_hidden_states
-                num_sampled_tokens,
+                num_sampled,
+                num_rejected,
             )

         if self.use_async_scheduling:

@@ -60,6 +60,8 @@ class EagleSpeculator:
         aux_hidden_states: list[torch.Tensor] | None,
         # [num_reqs]
         num_sampled: torch.Tensor,
+        # [num_reqs]
+        num_rejected: torch.Tensor,
         # [max_num_reqs, 1]
         last_sampled: torch.Tensor,
         # [num_reqs]
@@ -84,6 +86,7 @@ class EagleSpeculator:
             self.input_ids,
             input_batch,
             num_sampled,
+            num_rejected,
             last_sampled,
             next_prefill_tokens,
         )
@@ -139,8 +142,8 @@ def _prepare_eagle_inputs_kernel(
     last_sampled_ptr,
     next_prefill_tokens_ptr,
     num_sampled_ptr,
+    num_rejected_ptr,
     query_start_loc_ptr,
-    cu_num_logits_ptr,
     BLOCK_SIZE: tl.constexpr,
 ):
     batch_idx = tl.program_id(0)
@@ -149,17 +152,13 @@ def _prepare_eagle_inputs_kernel(
     query_len = query_end - query_start

+    # Get the true query length and next token after accounting for rejected tokens.
+    num_rejected = tl.load(num_rejected_ptr + batch_idx)
+    query_len -= num_rejected

     num_sampled = tl.load(num_sampled_ptr + batch_idx)
     if num_sampled > 0:
         req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
         next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
-
-        logits_start = tl.load(cu_num_logits_ptr + batch_idx)
-        logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
-        num_logits = logits_end - logits_start
-
-        num_rejected = num_logits - num_sampled
-        query_len -= num_rejected
     else:
         # Chunked prefilling.
         # Get the next prefill token.
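The per-request arithmetic of the rewritten kernel can be sketched in plain PyTorch as follows. This is an illustrative sketch only: the helper name and toy tensors are hypothetical, the flat last_sampled layout is simplified (the real kernel indexes it through idx_mapping), and the real kernel additionally fills the speculator's input_ids buffer and computes last_token_indices.

import torch

def effective_query_and_next_token(
    query_start_loc: torch.Tensor,     # [num_reqs + 1]
    num_rejected: torch.Tensor,        # [num_reqs]
    num_sampled: torch.Tensor,         # [num_reqs]
    last_sampled: torch.Tensor,        # [num_reqs] last accepted token per request (toy layout)
    next_prefill_tokens: torch.Tensor, # [num_reqs] next prompt token for chunked prefills
) -> tuple[torch.Tensor, torch.Tensor]:
    # True query length once rejected draft tokens are discarded.
    query_len = query_start_loc[1:] - query_start_loc[:-1] - num_rejected
    # Requests that sampled something continue from their last sampled token;
    # chunked prefills continue from the next prompt token instead.
    next_token = torch.where(num_sampled > 0, last_sampled, next_prefill_tokens)
    return query_len, next_token

query_start_loc = torch.tensor([0, 4, 9, 12])
num_rejected = torch.tensor([1, 0, 2])
num_sampled = torch.tensor([2, 0, 1])
last_sampled = torch.tensor([101, -1, 202])
next_prefill_tokens = torch.tensor([-1, 55, -1])
print(effective_query_and_next_token(
    query_start_loc, num_rejected, num_sampled, last_sampled, next_prefill_tokens))
# (tensor([3, 5, 1]), tensor([101,  55, 202]))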
@@ -182,6 +181,8 @@ def prepare_eagle_inputs(
     input_batch: InputBatch,
     # [num_reqs]
     num_sampled: torch.Tensor,
+    # [num_reqs]
+    num_rejected: torch.Tensor,
     # [max_num_reqs, 1]
     last_sampled: torch.Tensor,
     # [max_num_reqs]
@@ -201,8 +202,8 @@ def prepare_eagle_inputs(
         last_sampled,
         next_prefill_tokens,
         num_sampled,
+        num_rejected,
         input_batch.query_start_loc,
-        input_batch.cu_num_logits,
         BLOCK_SIZE=1024,
     )
     return last_token_indices

@@ -69,3 +69,15 @@ def rejection_sample(
         num_warps=1,
     )
     return sampled, num_sampled
+
+
+@torch.compile(dynamic=True)
+def get_num_rejected(
+    cu_num_logits: torch.Tensor,
+    num_sampled: torch.Tensor,
+) -> torch.Tensor:
+    num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
+    num_rejected = num_logits - num_sampled
+    # No token is rejected for chunked prefills.
+    num_rejected *= num_sampled > 0
+    return num_rejected
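A quick numeric check of the helper above, assuming a vLLM checkout that contains this commit (the import path is the one added at the top of the model runner diff; the toy values are hypothetical):

import torch
from vllm.v1.worker.gpu.spec_decode.rejection_sample import get_num_rejected

# Three requests with 3, 3, and 1 logit positions; the second is a chunked prefill
# (num_sampled == 0), so its rejected count is forced to zero.
cu_num_logits = torch.tensor([0, 3, 6, 7])
num_sampled = torch.tensor([2, 0, 1])
print(get_num_rejected(cu_num_logits, num_sampled).tolist())  # [1, 0, 0]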