From c68698b3264cbb05011a020ae654d38d155d9dcd Mon Sep 17 00:00:00 2001
From: qizixi <22851944+zixi-qi@users.noreply.github.com>
Date: Thu, 12 Jun 2025 20:09:19 -0700
Subject: [PATCH] [Bugfix] Fix EAGLE vocab embedding for multimodal target
 model (#19570)

Signed-off-by: qizixi
---
 vllm/v1/spec_decode/eagle.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 4b5c9b7ec640e..f7179385ebb74 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -329,16 +329,24 @@ class EagleProposer:
 
         self.attn_layer_names = list(draft_attn_layer_names)
 
+        if supports_multimodal(target_model):
+            # handle multimodality
+            self.model.config.image_token_index = (
+                target_model.config.image_token_index)
+            target_language_model = target_model.get_language_model()
+        else:
+            target_language_model = target_model
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1 \
             and self.model.model.embed_tokens.weight.shape \
-                == target_model.model.embed_tokens.weight.shape:
+                == target_language_model.model.embed_tokens.weight.shape:
             logger.info(
                 "Assuming the EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
             del self.model.model.embed_tokens
-            self.model.model.embed_tokens = target_model.model.embed_tokens
+            self.model.model.embed_tokens = (
+                target_language_model.model.embed_tokens)
         else:
             logger.info(
                 "The EAGLE head's vocab embedding will be loaded separately" \
@@ -349,12 +357,9 @@ class EagleProposer:
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
         if self.vllm_config.speculative_config.method != "eagle3" and \
-                hasattr(target_model, "lm_head"):
+                hasattr(target_language_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
-            if supports_multimodal(target_model):
-                self.model.lm_head = target_model.get_language_model().lm_head
-            else:
-                self.model.lm_head = target_model.lm_head
+            self.model.lm_head = target_language_model.lm_head
 
     @torch.inference_mode()
     def dummy_run(
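
The pattern in this patch is to resolve the target's text backbone once, up front (via get_language_model() for multimodal targets), and route every later attribute access (embed_tokens for weight sharing, lm_head for reuse) through that single reference. Before the fix, the embedding-shape comparison read target_model.model.embed_tokens directly, which breaks when the target is a multimodal wrapper around a language model. The following standalone sketch illustrates the resolution pattern only; the stub classes are hypothetical stand-ins so the logic runs without vLLM, and are not the actual vLLM types:

```python
# Standalone sketch of the patch's "resolve the language model once" pattern.
# All classes here are illustrative stubs, NOT vLLM's real model classes.
from dataclasses import dataclass


@dataclass
class Embedding:
    """Stands in for nn.Embedding; only the weight shape matters here."""
    shape: tuple[int, int]


@dataclass
class LanguageModel:
    embed_tokens: Embedding
    lm_head: Embedding


@dataclass
class MultimodalTarget:
    """A multimodal target hides its text backbone behind get_language_model()."""
    language_model: LanguageModel

    def get_language_model(self) -> LanguageModel:
        return self.language_model


def resolve_language_model(target):
    # Mirrors the patch: unwrap a multimodal target exactly once, so every
    # later attribute access hits the text backbone, not the wrapper.
    if isinstance(target, MultimodalTarget):  # stand-in for supports_multimodal()
        return target.get_language_model()
    return target


def share_weights(draft: LanguageModel, target) -> None:
    lm = resolve_language_model(target)
    # Share the vocab embedding only when the shapes match, as the patch checks.
    if draft.embed_tokens.shape == lm.embed_tokens.shape:
        draft.embed_tokens = lm.embed_tokens
    # Reuse the backbone's lm_head; before the fix, the hasattr() check ran
    # against the multimodal wrapper, which has no top-level lm_head.
    if hasattr(lm, "lm_head"):
        draft.lm_head = lm.lm_head


backbone = LanguageModel(Embedding((32000, 4096)), Embedding((32000, 4096)))
draft = LanguageModel(Embedding((32000, 4096)), Embedding((32000, 4096)))
share_weights(draft, MultimodalTarget(backbone))
assert draft.embed_tokens is backbone.embed_tokens  # shared, not copied
```

Resolving target_language_model once also lets the second hunk collapse the old four-line supports_multimodal() branch for lm_head into a single assignment, so the multimodal special-casing lives in exactly one place.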