diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index 377523efefc3..31e3172c61eb 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -70,8 +70,7 @@ class EAGLEConfig(PretrainedConfig):
 
         if self.model is not None:
             for k, v in self.model.to_dict().items():
-                if not hasattr(self, k):
-                    setattr(self, k, v)
+                setattr(self, k, v)
 
     @classmethod
     def from_pretrained(
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 460d645a1a6c..671b98544387 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -9,6 +9,7 @@ from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@@ -310,7 +311,10 @@ class EagleProposer:
         if self.vllm_config.speculative_config.method != "eagle3" and \
                 hasattr(target_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
-            self.model.lm_head = target_model.lm_head
+            if supports_multimodal(target_model):
+                self.model.lm_head = target_model.get_language_model().lm_head
+            else:
+                self.model.lm_head = target_model.lm_head
 
     @torch.inference_mode()
     def dummy_run(
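
The second hunk's behavior can be illustrated with a minimal, self-contained sketch. The toy classes below are not vLLM's real model types, and the `hasattr(target, "get_language_model")` test merely stands in for vLLM's `supports_multimodal()` check; the point is only that a multimodal wrapper keeps its `lm_head` on the inner language model, so the EAGLE draft model must reach it through `get_language_model()` rather than off the wrapper itself.

```python
# Hypothetical sketch of the lm_head-sharing rule in the diff above.
import torch.nn as nn


class ToyLanguageModel(nn.Module):
    """Plain decoder-only LM that owns its own lm_head."""

    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


class ToyMultimodalModel(nn.Module):
    """Multimodal wrapper: no lm_head of its own, only the wrapped LM has one."""

    def __init__(self):
        super().__init__()
        self.language_model = ToyLanguageModel()

    def get_language_model(self) -> nn.Module:
        return self.language_model


def share_lm_head(draft: nn.Module, target: nn.Module) -> None:
    # Mirrors the new branch in EagleProposer.load_model(): for multimodal
    # targets, borrow the lm_head from the inner language model; otherwise
    # take it directly from the target model.
    if hasattr(target, "get_language_model"):  # stand-in for supports_multimodal()
        draft.lm_head = target.get_language_model().lm_head
    else:
        draft.lm_head = target.lm_head


draft_model = ToyLanguageModel()
share_lm_head(draft_model, ToyMultimodalModel())   # goes through get_language_model()
share_lm_head(draft_model, ToyLanguageModel())     # uses target.lm_head directly
```

The first hunk is a smaller, independent behavior change: `EAGLEConfig` now always copies attributes from the nested draft-model config onto itself, instead of skipping keys that `PretrainedConfig` already defines, so the nested values take precedence over inherited defaults.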