[Chore][Spec Decode] Update check NoneType instead of assigning variables (#18836)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Author: Aaron Pham <contact@aarnphm.xyz>
Date:   2025-05-28 14:57:19 -04:00 (committed by GitHub)
Parent: 0e98964e94
Commit: a09c7ca9f2

vllm/v1/worker/gpu_model_runner.py

@@ -146,31 +146,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
 
-        # Set up speculative decoding.
-        self.use_spec_decode = False
-        self.use_aux_hidden_state_outputs = False
-        if self.speculative_config:
-            self.use_spec_decode = True
-
-            # NOTE(Jiayi): currently we put the entire draft model on
-            # the last PP rank. This is not ideal if there are many
-            # layers in the draft model.
-            if get_pp_group().is_last_rank:
-                if self.speculative_config.method == "ngram":
-                    self.drafter = NgramProposer(self.vllm_config)
-                elif self.speculative_config.use_eagle():
-                    self.drafter = EagleProposer(self.vllm_config, self.device,
-                                                 self)  # type: ignore
-                    if self.speculative_config.method == "eagle3":
-                        self.use_aux_hidden_state_outputs = True
-                elif self.speculative_config.method == "medusa":
-                    self.drafter = MedusaProposer(
-                        vllm_config=self.vllm_config,
-                        device=self.device)  # type: ignore
-                else:
-                    raise ValueError("Unknown speculative decoding method: "
-                                     f"{self.speculative_config.method}")
-                self.rejection_sampler = RejectionSampler()
+        # Set up speculative decoding.
+        # NOTE(Jiayi): currently we put the entire draft model on
+        # the last PP rank. This is not ideal if there are many
+        # layers in the draft model.
+        if self.speculative_config and get_pp_group().is_last_rank:
+            if self.speculative_config.method == "ngram":
+                self.drafter = NgramProposer(self.vllm_config)
+            elif self.speculative_config.use_eagle():
+                self.drafter = EagleProposer(self.vllm_config, self.device,
+                                             self)  # type: ignore
+                if self.speculative_config.method == "eagle3":
+                    self.use_aux_hidden_state_outputs = True
+            elif self.speculative_config.method == "medusa":
+                self.drafter = MedusaProposer(
+                    vllm_config=self.vllm_config,
+                    device=self.device)  # type: ignore
+            else:
+                raise ValueError("Unknown speculative decoding method: "
+                                 f"{self.speculative_config.method}")
+            self.rejection_sampler = RejectionSampler()
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
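
The heart of the change is in the hunk above: instead of mirroring `self.speculative_config` into a `self.use_spec_decode` boolean at construction time and branching on that flag later, call sites now test the config itself, which is None when speculative decoding is disabled. A minimal sketch of the pattern, using simplified stand-in classes rather than the actual vLLM types:

from dataclasses import dataclass
from typing import Optional


@dataclass
class SpeculativeConfig:
    # Stand-in for vLLM's speculative config; only what the sketch needs.
    method: str = "ngram"


class Runner:
    def __init__(self, speculative_config: Optional[SpeculativeConfig]):
        # Before: also stored `self.use_spec_decode = bool(speculative_config)`
        # and branched on that copy. After: the Optional config is the single
        # source of truth for whether speculative decoding is enabled.
        self.speculative_config = speculative_config

    def propose_tokens(self) -> Optional[list[int]]:
        if not self.speculative_config:
            # Speculative decoding is not enabled.
            return None
        # A real drafter would propose draft tokens here.
        return [0]


assert Runner(None).propose_tokens() is None
assert Runner(SpeculativeConfig()).propose_tokens() == [0]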
@@ -1318,7 +1314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         for i in discard_sampled_tokens_req_indices:
             valid_sampled_token_ids[i].clear()
 
-        if not self.use_spec_decode:
+        if not self.speculative_config:
             # Speculative decoding is not enabled.
             spec_token_ids = None
         elif self.speculative_config.method == "ngram":
@@ -1740,7 +1736,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             hidden_states = outputs
 
-        if self.use_spec_decode and self.speculative_config.use_eagle():
+        if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             self.drafter.dummy_run(num_tokens)
@ -1795,7 +1791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"initializing the engine.") from e
else:
raise e
if self.use_spec_decode:
if self.speculative_config:
draft_token_ids = [[0] for _ in range(num_reqs)]
dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids, self.device)
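
A note on the trade-off, as I read this diff: dropping the mirrored flag removes a piece of state that could drift out of sync with the config it duplicated, at the cost of a marginally longer condition at each call site. The consolidated constructor check also leans on short-circuit evaluation: in `if self.speculative_config and get_pp_group().is_last_rank:`, the PP-group lookup only runs when a speculative config is actually set, preserving the old nesting's behavior.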