[Spec Decode] Refactor spec decoding into a separate function (#20238)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-06-30 08:13:50 -07:00
Committed by: GitHub
parent 1c50e100a9
commit 2062c0723d


@@ -1388,6 +1388,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping micro-batches
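
A minimal sketch of why the new aux_hidden_states = None line matters: the refactored propose_draft_token_ids below takes aux_hidden_states as a parameter, so the variable must be bound on both branches of the unpack. The isinstance check and tensor shapes here are illustrative, not taken from the diff.

import torch

model_output = torch.randn(4, 8)  # pretend the model returned plain hidden states
if isinstance(model_output, tuple):
    hidden_states, aux_hidden_states = model_output
else:
    hidden_states = model_output
    aux_hidden_states = None  # bound on every path, so it can be passed downstream
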
@@ -1510,25 +1512,67 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
elif self.speculative_config.method == "ngram":
else:
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
self.eplb_step()
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits,
)
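
The net effect of the change above, reduced to a sketch: execute_model no longer special-cases the ngram method inline but hands every speculative method to one helper that dispatches internally. The function name and placeholder draft logic below are illustrative only; the real branches call NgramProposer, MedusaProposer, or EagleProposer.

def propose_draft_sketch(method: str, sampled: list[list[int]]) -> list[list[int]]:
    # Placeholder dispatch mirroring the structure of propose_draft_token_ids.
    if method == "ngram":
        return [t[-1:] for t in sampled]       # placeholder ngram-style draft
    if method == "medusa":
        return [t[-1:] * 2 for t in sampled]   # placeholder Medusa-style draft
    return [t[-1:] * 3 for t in sampled]       # placeholder EAGLE-style draft

print(propose_draft_sketch("ngram", [[5, 7], [3]]))  # [[7], [3]]
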
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
aux_hidden_states: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata],
attn_metadata: dict[str, Any],
) -> list[list[int]]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata)
spec_token_ids = self.propose_ngram_draft_token_ids(
sampled_token_ids)
elif self.speculative_config.method == "medusa":
assert isinstance(self.drafter, MedusaProposer)
if max_gen_len == 1:
if sample_hidden_states.shape[0] == len(sampled_token_ids):
# The input to the target model does not include draft tokens.
hidden_states = sample_hidden_states
else:
indices = []
offset = 0
for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens,
valid_sampled_token_ids):
sampled_token_ids):
indices.append(offset + len(tokens) - 1)
offset += num_draft + 1
indices = torch.tensor(indices,
device=sample_hidden_states.device)
indices = torch.tensor(indices, device=self.device)
hidden_states = sample_hidden_states[indices]
spec_token_ids = self.drafter.propose(
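
A worked example of the index computation in the Medusa branch above, with made-up values. Each request contributed num_draft + 1 rows to sample_hidden_states, and the loop keeps the row of the last accepted (or bonus) token per request.

num_draft_tokens = [2, 3]                          # drafts scheduled per request
sampled_token_ids = [[11, 12], [21, 22, 23, 24]]   # accepted tokens, incl. bonus

indices = []
offset = 0
for num_draft, tokens in zip(num_draft_tokens, sampled_token_ids):
    indices.append(offset + len(tokens) - 1)
    offset += num_draft + 1

print(indices)  # [1, 6]: row 1 for request 0, row 3 + 3 = 6 for request 1
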
@@ -1539,7 +1583,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
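
A small illustration of the common case in the loop above, with made-up token ids: the token handed to the EAGLE drafter is simply the last accepted token of each request that produced output this step (the fallback for empty lists lies outside this hunk).

sampled_token_ids = [[101, 102], [7], [42, 43, 44]]
next_token_ids = [ids[-1] for ids in sampled_token_ids if ids]
print(next_token_ids)  # [102, 7, 44]
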
@@ -1569,7 +1613,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
# TODO(woosuk): Support M-RoPE.
target_positions = self.positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
@@ -1582,7 +1627,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens_tensor = async_tensor_h2d(
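
A worked example of the rejection count above, with made-up values. A request scheduled with n draft tokens can accept at most n + 1 tokens (its accepted drafts plus the bonus token), so whatever it did not accept was rejected.

num_draft_tokens = [3, 0, 2]
sampled_token_ids = [[5, 6], [9], [1, 2, 3]]   # accepted tokens per request
num_rejected_tokens = [
    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
print(num_rejected_tokens)  # [2, 0, 0]
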
@@ -1597,7 +1642,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens,
)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
# TODO(woosuk): Support M-RoPE.
target_positions = self.positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
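
A sketch of the gather step above, with illustrative tensor contents: once token_indices has dropped the rejected draft positions, the same index tensor selects the matching rows from every per-token input (input_ids, positions, hidden states).

import torch

input_ids = torch.tensor([10, 11, 12, 13, 14, 15])
positions = torch.tensor([0, 1, 2, 3, 0, 1])
token_indices = torch.tensor([0, 1, 3, 4, 5])   # index 2 was a rejected draft

target_token_ids = input_ids[token_indices]
target_positions = positions[token_indices]
print(target_token_ids.tolist())   # [10, 11, 13, 14, 15]
print(target_positions.tolist())   # [0, 1, 3, 0, 1]
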
@@ -1616,25 +1662,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
self.eplb_step()
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits,
)
return spec_token_ids
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
@@ -1682,10 +1710,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output.finished_req_ids)
return None, None
def generate_draft_token_ids(
def propose_ngram_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
# TODO(woosuk): Optimize.
draft_token_ids: list[list[int]] = []