mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 20:25:01 +08:00
[Minor][Spec Decode] Add use_eagle to SpeculativeConfig (#17213)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
537d5ee025
commit
1cf0719ebd
@@ -2566,6 +2566,9 @@ class SpeculativeConfig:
|
|||||||
"""
|
"""
|
||||||
return self.num_speculative_tokens
|
return self.num_speculative_tokens
|
||||||
|
|
||||||
|
def use_eagle(self) -> bool:
|
||||||
|
return self.method in ("eagle", "eagle3")
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
method = self.method
|
method = self.method
|
||||||
model = None if method == "ngram" else self.draft_model_config.model
|
model = None if method == "ngram" else self.draft_model_config.model
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
self.num_spec_tokens = self.num_lookahead_tokens = 0
|
self.num_spec_tokens = self.num_lookahead_tokens = 0
|
||||||
if speculative_config:
|
if speculative_config:
|
||||||
self.num_spec_tokens = speculative_config.num_speculative_tokens
|
self.num_spec_tokens = speculative_config.num_speculative_tokens
|
||||||
if speculative_config.method in ("eagle", "eagle3"):
|
if speculative_config.use_eagle():
|
||||||
self.num_lookahead_tokens = self.num_spec_tokens
|
self.num_lookahead_tokens = self.num_spec_tokens
|
||||||
|
|
||||||
def schedule(self) -> SchedulerOutput:
|
def schedule(self) -> SchedulerOutput:
|
||||||
|
|||||||
@@ -171,8 +171,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
if self.speculative_config.method == "ngram":
|
if self.speculative_config.method == "ngram":
|
||||||
self.drafter = NgramProposer(self.vllm_config)
|
self.drafter = NgramProposer(self.vllm_config)
|
||||||
elif self.speculative_config.method == "eagle" or \
|
elif self.speculative_config.use_eagle():
|
||||||
self.speculative_config.method == "eagle3":
|
|
||||||
self.drafter = EagleProposer(self.vllm_config,
|
self.drafter = EagleProposer(self.vllm_config,
|
||||||
self.device) # type: ignore
|
self.device) # type: ignore
|
||||||
if self.speculative_config.method == "eagle3":
|
if self.speculative_config.method == "eagle3":
|
||||||
@@ -1192,8 +1191,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
assert isinstance(self.drafter, NgramProposer)
|
assert isinstance(self.drafter, NgramProposer)
|
||||||
spec_token_ids = self.generate_draft_token_ids(
|
spec_token_ids = self.generate_draft_token_ids(
|
||||||
valid_sampled_token_ids, sampling_metadata)
|
valid_sampled_token_ids, sampling_metadata)
|
||||||
elif self.speculative_config.method == "eagle" or \
|
elif self.speculative_config.use_eagle():
|
||||||
self.speculative_config.method == "eagle3":
|
|
||||||
assert isinstance(self.drafter, EagleProposer)
|
assert isinstance(self.drafter, EagleProposer)
|
||||||
# TODO(woosuk): Refactor the loop.
|
# TODO(woosuk): Refactor the loop.
|
||||||
next_token_ids: list[int] = []
|
next_token_ids: list[int] = []
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user