[Chore][Spec Decode] Update check NoneType instead of assigning variables (#18836)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Commit a09c7ca9f2 (parent 0e98964e94)
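The refactor drops the mirrored boolean `self.use_spec_decode` and instead tests the optional `self.speculative_config` directly: a `None` config is falsy, so the flag carried no extra information and could only drift out of sync. A minimal sketch of the pattern (the `RunnerBefore`/`RunnerAfter` classes and the `SpeculativeConfig` stub below are illustrative stand-ins, not vLLM's actual types):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class SpeculativeConfig:
        """Illustrative stand-in for vLLM's speculative decoding config."""
        method: str = "ngram"


    class RunnerBefore:
        """Old shape: a boolean flag mirrors the config and can drift."""

        def __init__(self, spec_config: Optional[SpeculativeConfig]) -> None:
            self.speculative_config = spec_config
            self.use_spec_decode = False
            if self.speculative_config:
                self.use_spec_decode = True


    class RunnerAfter:
        """New shape: callers test the Optional config itself."""

        def __init__(self, spec_config: Optional[SpeculativeConfig]) -> None:
            self.speculative_config = spec_config


    for cfg in (None, SpeculativeConfig()):
        runner = RunnerAfter(cfg)
        # `if self.speculative_config:` is True exactly when a config exists.
        assert bool(runner.speculative_config) == (cfg is not None)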
@@ -146,31 +146,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
 
-        # Set up speculative decoding.
-        self.use_spec_decode = False
         self.use_aux_hidden_state_outputs = False
-        if self.speculative_config:
-            self.use_spec_decode = True
-
-            # NOTE(Jiayi): currently we put the entire draft model on
-            # the last PP rank. This is not ideal if there are many
-            # layers in the draft model.
-            if get_pp_group().is_last_rank:
-                if self.speculative_config.method == "ngram":
-                    self.drafter = NgramProposer(self.vllm_config)
-                elif self.speculative_config.use_eagle():
-                    self.drafter = EagleProposer(self.vllm_config, self.device,
-                                                 self)  # type: ignore
-                    if self.speculative_config.method == "eagle3":
-                        self.use_aux_hidden_state_outputs = True
-                elif self.speculative_config.method == "medusa":
-                    self.drafter = MedusaProposer(
-                        vllm_config=self.vllm_config,
-                        device=self.device)  # type: ignore
-                else:
-                    raise ValueError("Unknown speculative decoding method: "
-                                     f"{self.speculative_config.method}")
-                self.rejection_sampler = RejectionSampler()
+        # Set up speculative decoding.
+        # NOTE(Jiayi): currently we put the entire draft model on
+        # the last PP rank. This is not ideal if there are many
+        # layers in the draft model.
+        if self.speculative_config and get_pp_group().is_last_rank:
+            if self.speculative_config.method == "ngram":
+                self.drafter = NgramProposer(self.vllm_config)
+            elif self.speculative_config.use_eagle():
+                self.drafter = EagleProposer(self.vllm_config, self.device,
+                                             self)  # type: ignore
+                if self.speculative_config.method == "eagle3":
+                    self.use_aux_hidden_state_outputs = True
+            elif self.speculative_config.method == "medusa":
+                self.drafter = MedusaProposer(
+                    vllm_config=self.vllm_config,
+                    device=self.device)  # type: ignore
+            else:
+                raise ValueError("Unknown speculative decoding method: "
+                                 f"{self.speculative_config.method}")
+            self.rejection_sampler = RejectionSampler()
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
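One detail worth noting in the new guard: `and` short-circuits, so `get_pp_group().is_last_rank` is never evaluated when speculative decoding is disabled, and the behavior of the old nested `if` is preserved exactly. A tiny self-contained sketch (the `is_last_rank` function here is a hypothetical stand-in for the real PP-group call):

    def is_last_rank() -> bool:
        """Hypothetical stand-in for get_pp_group().is_last_rank."""
        print("rank check ran")
        return True


    speculative_config = None
    # `and` short-circuits: with a falsy left operand, is_last_rank() is
    # never called, so no distributed state is touched when spec decode
    # is off.
    if speculative_config and is_last_rank():
        print("set up drafter")
    else:
        print("speculative decoding disabled")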
@@ -1318,7 +1314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         for i in discard_sampled_tokens_req_indices:
             valid_sampled_token_ids[i].clear()
 
-        if not self.use_spec_decode:
+        if not self.speculative_config:
             # Speculative decoding is not enabled.
             spec_token_ids = None
         elif self.speculative_config.method == "ngram":
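With the flag gone, downstream code learns that speculative decoding is off from `spec_token_ids = None` rather than from a runner attribute. A sketch of a consumer that distinguishes `None` (disabled) from an empty proposal list (the `merge_tokens` helper is hypothetical, not part of vLLM):

    from typing import Optional


    def merge_tokens(sampled: list[int],
                     spec_token_ids: Optional[list[list[int]]]) -> list[int]:
        # Hypothetical consumer: None means spec decode is disabled
        # outright, while an empty list would mean "enabled, but no
        # drafts were proposed this step".
        if spec_token_ids is None:
            return sampled
        return sampled + [t for draft in spec_token_ids for t in draft]


    assert merge_tokens([1, 2], None) == [1, 2]
    assert merge_tokens([1, 2], [[3], [4, 5]]) == [1, 2, 3, 4, 5]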
@@ -1740,7 +1736,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             hidden_states = outputs
 
-        if self.use_spec_decode and self.speculative_config.use_eagle():
+        if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             self.drafter.dummy_run(num_tokens)
 
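The unchanged `assert isinstance(self.drafter, EagleProposer)` line matters in this hunk: `self.drafter` can hold any of several proposer types, and the assert narrows the union for static type checkers before the Eagle-only `dummy_run` call. A minimal sketch with stub classes that only share names with vLLM's proposers:

    from typing import Union


    class NgramProposer:
        """Stub sharing only its name with vLLM's n-gram proposer."""


    class EagleProposer:
        """Stub sharing only its name with vLLM's EAGLE proposer."""

        def dummy_run(self, num_tokens: int) -> None:
            print(f"warmup forward pass with {num_tokens} tokens")


    def warm_up(drafter: Union[NgramProposer, EagleProposer],
                num_tokens: int) -> None:
        # The assert narrows `drafter` to EagleProposer, so type checkers
        # accept the Eagle-only dummy_run call; at runtime it also fails
        # fast if the wrong proposer type slipped through.
        assert isinstance(drafter, EagleProposer)
        drafter.dummy_run(num_tokens)


    warm_up(EagleProposer(), num_tokens=8)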
@@ -1795,7 +1791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     "initializing the engine.") from e
             else:
                 raise e
-        if self.use_spec_decode:
+        if self.speculative_config:
             draft_token_ids = [[0] for _ in range(num_reqs)]
             dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                 draft_token_ids, self.device)
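For the sampler warmup guarded by this check, the runner fabricates one zero-valued draft token per request before building the dummy metadata. A small sketch of just that shape (plain Python, no vLLM imports):

    # One single-token draft per request, as in the guarded warmup block.
    num_reqs = 4
    draft_token_ids = [[0] for _ in range(num_reqs)]

    assert len(draft_token_ids) == num_reqs
    assert all(draft == [0] for draft in draft_token_ids)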