From 2062c0723d38a8f4a8a7565b61a99e8c81b5cacd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 30 Jun 2025 08:13:50 -0700 Subject: [PATCH] [Spec Decode] Refactor spec decoding into a separate function (#20238) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 93 +++++++++++++++++++----------- 1 file changed, 60 insertions(+), 33 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 290b9a44a80e2..e063e44dabfa1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1388,6 +1388,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): hidden_states, aux_hidden_states = model_output else: hidden_states = model_output + aux_hidden_states = None + # Broadcast PP output for external_launcher (torchrun) # to make sure we are synced across pp ranks # TODO: Support overlapping mirco-batches @@ -1510,25 +1512,67 @@ class GPUModelRunner(LoRAModelRunnerMixin): if not self.speculative_config: # Speculative decoding is not enabled. spec_token_ids = None - elif self.speculative_config.method == "ngram": + else: + spec_token_ids = self.propose_draft_token_ids( + scheduler_output, + valid_sampled_token_ids, + sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + attn_metadata, + ) + + # Clear KVConnector state after all KVs are generated. + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + self.eplb_step() + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + finished_sending=finished_sending, + finished_recving=finished_recving, + num_nans_in_logits=num_nans_in_logits, + ) + + def propose_draft_token_ids( + self, + scheduler_output: "SchedulerOutput", + sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + hidden_states: torch.Tensor, + sample_hidden_states: torch.Tensor, + aux_hidden_states: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata], + attn_metadata: dict[str, Any], + ) -> list[list[int]]: + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if self.speculative_config.method == "ngram": assert isinstance(self.drafter, NgramProposer) - spec_token_ids = self.generate_draft_token_ids( - valid_sampled_token_ids, sampling_metadata) + spec_token_ids = self.propose_ngram_draft_token_ids( + sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) - if max_gen_len == 1: + if sample_hidden_states.shape[0] == len(sampled_token_ids): + # The input to the target model does not include draft tokens. hidden_states = sample_hidden_states else: indices = [] offset = 0 for num_draft, tokens in zip( spec_decode_metadata.num_draft_tokens, - valid_sampled_token_ids): + sampled_token_ids): indices.append(offset + len(tokens) - 1) offset += num_draft + 1 - - indices = torch.tensor(indices, - device=sample_hidden_states.device) + indices = torch.tensor(indices, device=self.device) hidden_states = sample_hidden_states[indices] spec_token_ids = self.drafter.propose( @@ -1539,7 +1583,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): + for i, token_ids in enumerate(sampled_token_ids): if token_ids: # Common case. next_token_id = token_ids[-1] @@ -1569,7 +1613,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): if spec_decode_metadata is None: # input_ids can be None for multimodal models. target_token_ids = self.input_ids[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] + # TODO(woosuk): Support M-RoPE. + target_positions = self.positions[:num_scheduled_tokens] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[:num_scheduled_tokens] for h in aux_hidden_states], @@ -1582,7 +1627,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # TODO(woosuk): Refactor this. num_draft_tokens = spec_decode_metadata.num_draft_tokens num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] num_rejected_tokens_tensor = async_tensor_h2d( @@ -1597,7 +1642,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_tokens, ) target_token_ids = self.input_ids[token_indices] - target_positions = positions[token_indices] + # TODO(woosuk): Support M-RoPE. + target_positions = self.positions[token_indices] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) @@ -1616,25 +1662,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() - - # Clear KVConnector state after all KVs are generated. - if has_kv_transfer_group(): - get_kv_transfer_group().clear_connector_metadata() - - self.eplb_step() - - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - finished_sending=finished_sending, - finished_recving=finished_recving, - num_nans_in_logits=num_nans_in_logits, - ) + return spec_token_ids def kv_connector_no_forward( self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: @@ -1682,10 +1710,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output.finished_req_ids) return None, None - def generate_draft_token_ids( + def propose_ngram_draft_token_ids( self, sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, ) -> list[list[int]]: # TODO(woosuk): Optimize. draft_token_ids: list[list[int]] = []