mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 12:16:13 +08:00
[V1][Spec Decode][Bugfix] Load quantize weights for EAGLE (#18290)
This commit is contained in:
parent
7b9d832c80
commit
a04720bc36
@ -52,13 +52,15 @@ class EAGLEConfig(PretrainedConfig):
|
|||||||
assert self.model is not None, \
|
assert self.model is not None, \
|
||||||
"model should not be None when method is eagle"
|
"model should not be None when method is eagle"
|
||||||
kwargs["architectures"] = [
|
kwargs["architectures"] = [
|
||||||
f"Eagle{arch}" for arch in self.model.architectures
|
f"Eagle{arch}" if not arch.startswith("Eagle") \
|
||||||
|
else arch for arch in self.model.architectures
|
||||||
]
|
]
|
||||||
elif method == "eagle3":
|
elif method == "eagle3":
|
||||||
assert self.model is not None, \
|
assert self.model is not None, \
|
||||||
"model should not be None when method is eagle3"
|
"model should not be None when method is eagle3"
|
||||||
kwargs["architectures"] = [
|
kwargs["architectures"] = [
|
||||||
f"Eagle3{arch}" for arch in self.model.architectures
|
f"Eagle3{arch}" if not arch.startswith("Eagle3") \
|
||||||
|
else arch for arch in self.model.architectures
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid method {method}. \
|
raise ValueError(f"Invalid method {method}. \
|
||||||
|
|||||||
@ -9,7 +9,8 @@ from vllm.distributed.parallel_state import get_pp_group
|
|||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader import get_model_loader
|
from vllm.model_executor.model_loader import get_model_loader
|
||||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
from vllm.model_executor.model_loader.utils import (
|
||||||
|
process_weights_after_loading, set_default_torch_dtype)
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
@ -308,6 +309,9 @@ class EagleProposer:
|
|||||||
loaded_weights = self.model.load_weights(
|
loaded_weights = self.model.load_weights(
|
||||||
loader.get_all_weights(draft_model_config, self.model))
|
loader.get_all_weights(draft_model_config, self.model))
|
||||||
|
|
||||||
|
process_weights_after_loading(self.model, draft_model_config,
|
||||||
|
target_device)
|
||||||
|
|
||||||
# share embed_tokens with the target model if needed
|
# share embed_tokens with the target model if needed
|
||||||
if get_pp_group().world_size == 1:
|
if get_pp_group().world_size == 1:
|
||||||
assert "model.embed_tokens.weight" not in loaded_weights, \
|
assert "model.embed_tokens.weight" not in loaded_weights, \
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user