[Spec Decode] Refactor spec decoding into a separate function (#20238)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

commit 2062c0723d (parent 1c50e100a9)
@@ -1388,6 +1388,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             hidden_states, aux_hidden_states = model_output
         else:
             hidden_states = model_output
+            aux_hidden_states = None
+
         # Broadcast PP output for external_launcher (torchrun)
         # to make sure we are synced across pp ranks
         # TODO: Support overlapping micro-batches
@@ -1510,25 +1512,67 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if not self.speculative_config:
             # Speculative decoding is not enabled.
             spec_token_ids = None
-        elif self.speculative_config.method == "ngram":
+        else:
+            spec_token_ids = self.propose_draft_token_ids(
+                scheduler_output,
+                valid_sampled_token_ids,
+                sampling_metadata,
+                hidden_states,
+                sample_hidden_states,
+                aux_hidden_states,
+                spec_decode_metadata,
+                attn_metadata,
+            )
+
+        # Clear KVConnector state after all KVs are generated.
+        if has_kv_transfer_group():
+            get_kv_transfer_group().clear_connector_metadata()
+
+        self.eplb_step()
+
+        return ModelRunnerOutput(
+            req_ids=self.input_batch.req_ids,
+            req_id_to_index=self.input_batch.req_id_to_index,
+            sampled_token_ids=valid_sampled_token_ids,
+            spec_token_ids=spec_token_ids,
+            logprobs=logprobs_lists,
+            prompt_logprobs_dict=prompt_logprobs_dict,
+            pooler_output=[],
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
+            num_nans_in_logits=num_nans_in_logits,
+        )
+
+    def propose_draft_token_ids(
+        self,
+        scheduler_output: "SchedulerOutput",
+        sampled_token_ids: list[list[int]],
+        sampling_metadata: SamplingMetadata,
+        hidden_states: torch.Tensor,
+        sample_hidden_states: torch.Tensor,
+        aux_hidden_states: Optional[torch.Tensor],
+        spec_decode_metadata: Optional[SpecDecodeMetadata],
+        attn_metadata: dict[str, Any],
+    ) -> list[list[int]]:
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        if self.speculative_config.method == "ngram":
             assert isinstance(self.drafter, NgramProposer)
-            spec_token_ids = self.generate_draft_token_ids(
-                valid_sampled_token_ids, sampling_metadata)
+            spec_token_ids = self.propose_ngram_draft_token_ids(
+                sampled_token_ids)
         elif self.speculative_config.method == "medusa":
             assert isinstance(self.drafter, MedusaProposer)
-            if max_gen_len == 1:
+            if sample_hidden_states.shape[0] == len(sampled_token_ids):
                 # The input to the target model does not include draft tokens.
                 hidden_states = sample_hidden_states
             else:
                 indices = []
                 offset = 0
                 for num_draft, tokens in zip(
                         spec_decode_metadata.num_draft_tokens,
-                        valid_sampled_token_ids):
+                        sampled_token_ids):
                     indices.append(offset + len(tokens) - 1)
                     offset += num_draft + 1
-
-                indices = torch.tensor(indices,
-                                       device=sample_hidden_states.device)
+                indices = torch.tensor(indices, device=self.device)
                 hidden_states = sample_hidden_states[indices]

             spec_token_ids = self.drafter.propose(
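A note on the Medusa branch above: when the flattened sampler output still contains draft slots, each request owns num_draft + 1 rows of sample_hidden_states, and the loop selects the row of the last accepted token per request. A minimal self-contained sketch of that index computation; the shapes and token values are illustrative, not taken from the commit:

import torch

# Illustrative batch: two requests scheduled with 2 and 3 draft tokens.
num_draft_tokens = [2, 3]
# Accepted tokens per request (accepted drafts plus one bonus token):
# request 0 accepted 1 token, request 1 accepted 3.
sampled_token_ids = [[17], [42, 7, 99]]

# One hidden-state row per (num_draft + 1) slot of each request.
sample_hidden_states = torch.randn(sum(n + 1 for n in num_draft_tokens), 8)

indices = []
offset = 0
for num_draft, tokens in zip(num_draft_tokens, sampled_token_ids):
    # Row of this request's last accepted token within its slot range.
    indices.append(offset + len(tokens) - 1)
    offset += num_draft + 1
indices = torch.tensor(indices)
hidden_states = sample_hidden_states[indices]

print(indices.tolist())     # [0, 5]
print(hidden_states.shape)  # torch.Size([2, 8])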
@@ -1539,7 +1583,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
             next_token_ids: list[int] = []
-            for i, token_ids in enumerate(valid_sampled_token_ids):
+            for i, token_ids in enumerate(sampled_token_ids):
                 if token_ids:
                     # Common case.
                     next_token_id = token_ids[-1]
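The loop shown here picks each request's most recent token to seed the EAGLE drafter. A rough sketch of the common case; the branch for requests that sampled nothing this step is elided in the hunk, so the placeholder below is a hypothetical stand-in, not vLLM's actual fallback:

# Hypothetical stand-in for the elided branch: a request that sampled no
# token this step (e.g. a partial prefill) needs its id backfilled.
PLACEHOLDER_TOKEN_ID = 0

sampled_token_ids = [[5, 9], [], [31]]
next_token_ids: list[int] = []
for token_ids in sampled_token_ids:
    if token_ids:
        # Common case: the last token sampled for this request.
        next_token_ids.append(token_ids[-1])
    else:
        # The real fallback (elided in the hunk) recovers the id elsewhere.
        next_token_ids.append(PLACEHOLDER_TOKEN_ID)
print(next_token_ids)  # [9, 0, 31]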
@@ -1569,7 +1613,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             if spec_decode_metadata is None:
                 # input_ids can be None for multimodal models.
                 target_token_ids = self.input_ids[:num_scheduled_tokens]
-                target_positions = positions[:num_scheduled_tokens]
+                # TODO(woosuk): Support M-RoPE.
+                target_positions = self.positions[:num_scheduled_tokens]
                 if self.use_aux_hidden_state_outputs:
                     target_hidden_states = torch.cat(
                         [h[:num_scheduled_tokens] for h in aux_hidden_states],
@@ -1582,7 +1627,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 # TODO(woosuk): Refactor this.
                 num_draft_tokens = spec_decode_metadata.num_draft_tokens
                 num_rejected_tokens = [
-                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+                    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                     for i, n in enumerate(num_draft_tokens)
                 ]
                 num_rejected_tokens_tensor = async_tensor_h2d(
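The rejected-token count follows from the scheduler's slot accounting: a request with n > 0 draft tokens received n + 1 sampling slots (drafts plus one bonus token), so the unaccepted remainder is n + 1 - len(sampled_token_ids[i]). A quick worked example under that reading, with made-up numbers:

num_draft_tokens = [3, 0, 2]                   # drafts scheduled per request
sampled_token_ids = [[4, 8], [5], [1, 2, 3]]   # tokens accepted per request

num_rejected_tokens = [
    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
# Request 0: 3 drafts -> 4 slots, 2 accepted -> 2 rejected.
# Request 1: no drafts -> 0 by definition.
# Request 2: 2 drafts -> 3 slots, 3 accepted -> 0 rejected.
print(num_rejected_tokens)  # [2, 0, 0]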
@@ -1597,7 +1642,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     num_tokens,
                 )
                 target_token_ids = self.input_ids[token_indices]
-                target_positions = positions[token_indices]
+                # TODO(woosuk): Support M-RoPE.
+                target_positions = self.positions[token_indices]
                 if self.use_aux_hidden_state_outputs:
                     target_hidden_states = torch.cat(
                         [h[token_indices] for h in aux_hidden_states], dim=-1)
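Both M-RoPE TODO hunks sit next to the same gather pattern: the EAGLE target inputs are selected by token index, and the per-layer auxiliary hidden states are concatenated along the feature dimension. A small sketch of that concatenation; the layer count and sizes are invented for illustration:

import torch

num_tokens, hidden_size = 10, 16
# Hypothetical auxiliary hidden states captured from three target layers.
aux_hidden_states = [torch.randn(num_tokens, hidden_size) for _ in range(3)]
token_indices = torch.tensor([0, 3, 7])

# Select the same token rows from every layer, then concatenate features.
target_hidden_states = torch.cat(
    [h[token_indices] for h in aux_hidden_states], dim=-1)
print(target_hidden_states.shape)  # torch.Size([3, 48])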
@@ -1616,25 +1662,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 sampling_metadata=sampling_metadata,
             )
             spec_token_ids = draft_token_ids.tolist()
-
-        # Clear KVConnector state after all KVs are generated.
-        if has_kv_transfer_group():
-            get_kv_transfer_group().clear_connector_metadata()
-
-        self.eplb_step()
-
-        return ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=valid_sampled_token_ids,
-            spec_token_ids=spec_token_ids,
-            logprobs=logprobs_lists,
-            prompt_logprobs_dict=prompt_logprobs_dict,
-            pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
-            num_nans_in_logits=num_nans_in_logits,
-        )
+        return spec_token_ids

     def kv_connector_no_forward(
             self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
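Taken together with the earlier hunk, this is a classic extract-method refactor: drafting moves behind a single call, and the KVConnector/EPLB epilogue that previously trailed the drafting branches now runs once in execute_model on every path. A toy, runnable paraphrase of the new shape, with stubbed stand-ins rather than vLLM's real collaborators:

from dataclasses import dataclass

@dataclass
class StepOutput:
    # Stand-in for ModelRunnerOutput, reduced to the two relevant fields.
    sampled_token_ids: list[list[int]]
    spec_token_ids: list[list[int]] | None

def propose_draft_token_ids(sampled: list[list[int]]) -> list[list[int]]:
    # Stand-in for the extracted helper: any drafting strategy goes here.
    return [ids[-1:] for ids in sampled]

def finish_step(speculative_config: str | None,
                sampled: list[list[int]]) -> StepOutput:
    # Drafting is one call now; the epilogue below runs on every path.
    spec = None if not speculative_config else propose_draft_token_ids(sampled)
    # (KVConnector cleanup and eplb_step would run here, unconditionally.)
    return StepOutput(sampled_token_ids=sampled, spec_token_ids=spec)

print(finish_step("eagle", [[3, 1], [7]]).spec_token_ids)  # [[1], [7]]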
@@ -1682,10 +1710,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 scheduler_output.finished_req_ids)
         return None, None

-    def generate_draft_token_ids(
+    def propose_ngram_draft_token_ids(
         self,
         sampled_token_ids: list[list[int]],
-        sampling_metadata: SamplingMetadata,
     ) -> list[list[int]]:
         # TODO(woosuk): Optimize.
         draft_token_ids: list[list[int]] = []
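For context on the rename: an n-gram proposer drafts tokens by matching the tail of the generated sequence against earlier context and replaying what followed the most recent earlier match, which is why it needs only the sampled token ids and the dropped sampling_metadata parameter was dead weight. A minimal illustrative matcher, assuming a single flat token list; this is not the NgramProposer implementation:

def propose_ngram_drafts(token_ids: list[int], n: int = 2,
                         k: int = 3) -> list[int]:
    # Match the last n tokens against earlier context; if that n-gram
    # occurred before, replay up to k tokens that followed the most
    # recent earlier occurrence.
    if len(token_ids) < n + 1:
        return []
    tail = token_ids[-n:]
    for start in range(len(token_ids) - n - 1, -1, -1):
        if token_ids[start:start + n] == tail:
            return token_ids[start + n:start + n + k]
    return []

# The n-gram [1, 2] appeared at the start, followed by [3, 4, 5].
print(propose_ngram_drafts([1, 2, 3, 4, 5, 1, 2]))  # [3, 4, 5]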