From a04720bc36401d831cb048c3917b9e58173d9c1d Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Thu, 22 May 2025 18:17:33 -0400 Subject: [PATCH] [V1][Spec Decode][Bugfix] Load quantize weights for EAGLE (#18290) --- vllm/transformers_utils/configs/eagle.py | 6 ++++-- vllm/v1/spec_decode/eagle.py | 6 +++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 586d5c7f5e54b..377523efefc30 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -52,13 +52,15 @@ class EAGLEConfig(PretrainedConfig): assert self.model is not None, \ "model should not be None when method is eagle" kwargs["architectures"] = [ - f"Eagle{arch}" for arch in self.model.architectures + f"Eagle{arch}" if not arch.startswith("Eagle") \ + else arch for arch in self.model.architectures ] elif method == "eagle3": assert self.model is not None, \ "model should not be None when method is eagle3" kwargs["architectures"] = [ - f"Eagle3{arch}" for arch in self.model.architectures + f"Eagle3{arch}" if not arch.startswith("Eagle3") \ + else arch for arch in self.model.architectures ] else: raise ValueError(f"Invalid method {method}. \ diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5b84bc1f5ec39..19fb2a2af7ddc 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -9,7 +9,8 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model_loader -from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading, set_default_torch_dtype) from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.triton_utils import tl, triton @@ -308,6 +309,9 @@ class EagleProposer: loaded_weights = self.model.load_weights( loader.get_all_weights(draft_model_config, self.model)) + process_weights_after_loading(self.model, draft_model_config, + target_device) + # share embed_tokens with the target model if needed if get_pp_group().world_size == 1: assert "model.embed_tokens.weight" not in loaded_weights, \