mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 07:35:01 +08:00
[Spec Decode] Make EAGLE3 draft token ID mapping optional (#18488)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
e44d8ce8c7
commit
583507d130
@@ -214,6 +214,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
|||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
sampling_metadata)
|
sampling_metadata)
|
||||||
|
if self.draft_id_to_target_id is None:
|
||||||
|
return logits
|
||||||
|
|
||||||
base = torch.arange(self.config.draft_vocab_size, device=logits.device)
|
base = torch.arange(self.config.draft_vocab_size, device=logits.device)
|
||||||
targets = base + self.draft_id_to_target_id
|
targets = base + self.draft_id_to_target_id
|
||||||
logits_new = logits.new_full((
|
logits_new = logits.new_full((
|
||||||
@@ -246,4 +249,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
|||||||
name = "model." + name
|
name = "model." + name
|
||||||
model_weights[name] = loaded_weight
|
model_weights[name] = loaded_weight
|
||||||
|
|
||||||
return loader.load_weights(model_weights.items())
|
loaded_weights = loader.load_weights(model_weights.items())
|
||||||
|
|
||||||
|
if 'd2t' not in loaded_weights:
|
||||||
|
self.draft_id_to_target_id = None
|
||||||
|
|
||||||
|
return loaded_weights
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user