diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 910c0e80bb31c..5d5558162ab37 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -146,31 +146,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

-        # Set up speculative decoding.
-        self.use_spec_decode = False
         self.use_aux_hidden_state_outputs = False
-        if self.speculative_config:
-            self.use_spec_decode = True
-
-            # NOTE(Jiayi): currently we put the entire draft model on
-            # the last PP rank. This is not ideal if there are many
-            # layers in the draft model.
-            if get_pp_group().is_last_rank:
-                if self.speculative_config.method == "ngram":
-                    self.drafter = NgramProposer(self.vllm_config)
-                elif self.speculative_config.use_eagle():
-                    self.drafter = EagleProposer(self.vllm_config, self.device,
-                                                 self)  # type: ignore
-                    if self.speculative_config.method == "eagle3":
-                        self.use_aux_hidden_state_outputs = True
-                elif self.speculative_config.method == "medusa":
-                    self.drafter = MedusaProposer(
-                        vllm_config=self.vllm_config,
-                        device=self.device)  # type: ignore
-                else:
-                    raise ValueError("Unknown speculative decoding method: "
-                                     f"{self.speculative_config.method}")
-                self.rejection_sampler = RejectionSampler()
+        # Set up speculative decoding.
+        # NOTE(Jiayi): currently we put the entire draft model on
+        # the last PP rank. This is not ideal if there are many
+        # layers in the draft model.
+        if self.speculative_config and get_pp_group().is_last_rank:
+            if self.speculative_config.method == "ngram":
+                self.drafter = NgramProposer(self.vllm_config)
+            elif self.speculative_config.use_eagle():
+                self.drafter = EagleProposer(self.vllm_config, self.device,
+                                             self)  # type: ignore
+                if self.speculative_config.method == "eagle3":
+                    self.use_aux_hidden_state_outputs = True
+            elif self.speculative_config.method == "medusa":
+                self.drafter = MedusaProposer(
+                    vllm_config=self.vllm_config,
+                    device=self.device)  # type: ignore
+            else:
+                raise ValueError("Unknown speculative decoding method: "
+                                 f"{self.speculative_config.method}")
+            self.rejection_sampler = RejectionSampler()

         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
@@ -1318,7 +1314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         for i in discard_sampled_tokens_req_indices:
             valid_sampled_token_ids[i].clear()

-        if not self.use_spec_decode:
+        if not self.speculative_config:
             # Speculative decoding is not enabled.
             spec_token_ids = None
         elif self.speculative_config.method == "ngram":
@@ -1740,7 +1736,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             hidden_states = outputs

-        if self.use_spec_decode and self.speculative_config.use_eagle():
+        if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             self.drafter.dummy_run(num_tokens)

@@ -1795,7 +1791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     "initializing the engine.") from e
             else:
                 raise e
-        if self.use_spec_decode:
+        if self.speculative_config:
             draft_token_ids = [[0] for _ in range(num_reqs)]
             dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                 draft_token_ids, self.device)
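
The diff drops the redundant `use_spec_decode` flag and treats `self.speculative_config` itself as the single source of truth, so the flag and the config can never drift apart. A minimal sketch of that pattern follows; `SpecConfig` and `Runner` are hypothetical stand-ins for illustration, not the real vLLM classes:

```python
# Sketch of the pattern this diff applies: instead of mirroring the config
# into a separate `use_spec_decode` boolean (two fields that must be kept
# in sync), callers test the config object's truthiness directly.
# `SpecConfig` and `Runner` are hypothetical stand-ins, not vLLM classes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SpecConfig:
    method: str  # e.g. "ngram", "eagle", "medusa"


class Runner:
    def __init__(self, speculative_config: Optional[SpecConfig]) -> None:
        # No `self.use_spec_decode = bool(speculative_config)` shadow flag.
        self.speculative_config = speculative_config

    def propose_draft_tokens(self) -> Optional[list[int]]:
        if not self.speculative_config:
            # Speculative decoding is not enabled.
            return None
        # Dispatch on the configured method, as __init__ does for `drafter`.
        return [0] if self.speculative_config.method == "ngram" else []


print(Runner(None).propose_draft_tokens())                 # None
print(Runner(SpecConfig("ngram")).propose_draft_tokens())  # [0]
```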