[V1][Spec Decode][Bugfix] Load quantize weights for EAGLE (#18290)

This commit is contained in:
Ekagra Ranjan 2025-05-22 18:17:33 -04:00 committed by GitHub
parent 7b9d832c80
commit a04720bc36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 3 deletions

View File

@@ -52,13 +52,15 @@ class EAGLEConfig(PretrainedConfig):
assert self.model is not None, \
"model should not be None when method is eagle"
kwargs["architectures"] = [
f"Eagle{arch}" for arch in self.model.architectures
f"Eagle{arch}" if not arch.startswith("Eagle") \
else arch for arch in self.model.architectures
]
elif method == "eagle3":
assert self.model is not None, \
"model should not be None when method is eagle3"
kwargs["architectures"] = [
f"Eagle3{arch}" for arch in self.model.architectures
f"Eagle3{arch}" if not arch.startswith("Eagle3") \
else arch for arch in self.model.architectures
]
else:
raise ValueError(f"Invalid method {method}. \

View File

@@ -9,7 +9,8 @@ from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import tl, triton
@@ -308,6 +309,9 @@ class EagleProposer:
loaded_weights = self.model.load_weights(
loader.get_all_weights(draft_model_config, self.model))
process_weights_after_loading(self.model, draft_model_config,
target_device)
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
assert "model.embed_tokens.weight" not in loaded_weights, \