diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index 31e3172c61ebf..377523efefc30 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -70,7 +70,8 @@ class EAGLEConfig(PretrainedConfig):
 
         if self.model is not None:
             for k, v in self.model.to_dict().items():
-                setattr(self, k, v)
+                if not hasattr(self, k):
+                    setattr(self, k, v)
 
     @classmethod
     def from_pretrained(
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 671b98544387a..460d645a1a6c7 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -9,7 +9,6 @@ from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@@ -311,10 +310,7 @@ class EagleProposer:
         if self.vllm_config.speculative_config.method != "eagle3" and \
                 hasattr(target_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
-            if supports_multimodal(target_model):
-                self.model.lm_head = target_model.get_language_model().lm_head
-            else:
-                self.model.lm_head = target_model.lm_head
+            self.model.lm_head = target_model.lm_head
 
     @torch.inference_mode()
     def dummy_run(
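
For context, a minimal sketch of what the new hasattr guard in EAGLEConfig changes. The class name Outer and the attribute values below are illustrative only, not vLLM's actual config code: the point is that keys copied from the nested draft-model config no longer clobber attributes the outer config has already set.

    # Hypothetical illustration of the hasattr guard added in the patch above;
    # not vLLM's actual EAGLEConfig.
    class Outer:
        def __init__(self, nested: dict):
            # Attribute set explicitly on the outer config before the copy loop.
            self.architectures = ["EAGLEModel"]
            for k, v in nested.items():
                if not hasattr(self, k):   # the guard added by this patch
                    setattr(self, k, v)    # only fill in missing attributes

    cfg = Outer({"architectures": ["LlamaForCausalLM"], "hidden_size": 4096})
    assert cfg.architectures == ["EAGLEModel"]  # preserved; unguarded setattr would overwrite it
    assert cfg.hidden_size == 4096              # still inherited from the nested config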