mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
[Minor][Spec Decode] Add use_eagle to SpeculativeConfig (#17213)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
537d5ee025
commit
1cf0719ebd
@ -2566,6 +2566,9 @@ class SpeculativeConfig:
|
||||
"""
|
||||
return self.num_speculative_tokens
|
||||
|
||||
def use_eagle(self) -> bool:
|
||||
return self.method in ("eagle", "eagle3")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
model = None if method == "ngram" else self.draft_model_config.model
|
||||
|
||||
@ -126,7 +126,7 @@ class Scheduler(SchedulerInterface):
|
||||
self.num_spec_tokens = self.num_lookahead_tokens = 0
|
||||
if speculative_config:
|
||||
self.num_spec_tokens = speculative_config.num_speculative_tokens
|
||||
if speculative_config.method in ("eagle", "eagle3"):
|
||||
if speculative_config.use_eagle():
|
||||
self.num_lookahead_tokens = self.num_spec_tokens
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
|
||||
@ -171,8 +171,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if get_pp_group().is_last_rank:
|
||||
if self.speculative_config.method == "ngram":
|
||||
self.drafter = NgramProposer(self.vllm_config)
|
||||
elif self.speculative_config.method == "eagle" or \
|
||||
self.speculative_config.method == "eagle3":
|
||||
elif self.speculative_config.use_eagle():
|
||||
self.drafter = EagleProposer(self.vllm_config,
|
||||
self.device) # type: ignore
|
||||
if self.speculative_config.method == "eagle3":
|
||||
@ -1192,8 +1191,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert isinstance(self.drafter, NgramProposer)
|
||||
spec_token_ids = self.generate_draft_token_ids(
|
||||
valid_sampled_token_ids, sampling_metadata)
|
||||
elif self.speculative_config.method == "eagle" or \
|
||||
self.speculative_config.method == "eagle3":
|
||||
elif self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
# TODO(woosuk): Refactor the loop.
|
||||
next_token_ids: list[int] = []
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user