mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 06:15:02 +08:00
[Bugfix] Fix EAGLE3 broken logits (#18909)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
This commit is contained in:
parent
bbfa0c61d1
commit
1bc86a3da1
@ -215,6 +215,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
|||||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
sampling_metadata)
|
sampling_metadata)
|
||||||
if self.draft_id_to_target_id is None:
|
if self.draft_id_to_target_id is None:
|
||||||
|
assert logits.shape[1] == self.config.vocab_size, \
|
||||||
|
"Expected logits to have shape " \
|
||||||
|
f"(*, {self.config.vocab_size}), but got {logits.shape}"
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
base = torch.arange(self.config.draft_vocab_size, device=logits.device)
|
base = torch.arange(self.config.draft_vocab_size, device=logits.device)
|
||||||
@ -234,24 +237,22 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
|||||||
return self.model.fc(hidden_states)
|
return self.model.fc(hidden_states)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
loader = AutoWeightsLoader(
|
|
||||||
self,
|
|
||||||
skip_prefixes=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_weights = {}
|
model_weights = {}
|
||||||
|
includes_draft_id_mapping = False
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "t2d" in name:
|
if "t2d" in name:
|
||||||
continue
|
continue
|
||||||
if "d2t" in name:
|
if "d2t" in name:
|
||||||
name = name.replace("d2t", "draft_id_to_target_id")
|
name = name.replace("d2t", "draft_id_to_target_id")
|
||||||
|
includes_draft_id_mapping = True
|
||||||
elif "lm_head" not in name:
|
elif "lm_head" not in name:
|
||||||
name = "model." + name
|
name = "model." + name
|
||||||
model_weights[name] = loaded_weight
|
model_weights[name] = loaded_weight
|
||||||
|
|
||||||
loaded_weights = loader.load_weights(model_weights.items())
|
loader = AutoWeightsLoader(
|
||||||
|
self,
|
||||||
if 'd2t' not in loaded_weights:
|
skip_prefixes=None,
|
||||||
self.draft_id_to_target_id = None
|
skip_substrs=["draft_id_to_target_id"] \
|
||||||
|
if not includes_draft_id_mapping else None,
|
||||||
return loaded_weights
|
)
|
||||||
|
loader.load_weights(model_weights.items())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user