mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 12:04:27 +08:00
[Bugfix][llama4_eagle] Fix missing 'lm_head' attribute (#29926)
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
parent
e23ca3a0e8
commit
962d703818
@ -402,7 +402,11 @@ def test_eagle_correctness(
|
|||||||
# Scout requires default backend selection
|
# Scout requires default backend selection
|
||||||
# because the vision encoder's head_dim of 88 is incompatible
|
# because the vision encoder's head_dim of 88 is incompatible
|
||||||
# with FLASH_ATTN, so it needs to fall back to Flex Attn
|
# with FLASH_ATTN, so it needs to fall back to Flex Attn
|
||||||
pass
|
|
||||||
|
# pass if not ROCm
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
# TODO: Enable Flex Attn for spec_decode on ROCm
|
||||||
|
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
|
||||||
else:
|
else:
|
||||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||||
|
|||||||
@ -28,7 +28,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
|||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
|
from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausalLM
|
from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausalLM
|
||||||
from vllm.model_executor.models.utils import extract_layer_index
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
@ -182,6 +185,12 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
|||||||
self.config.vocab_size, scale=logit_scale
|
self.config.vocab_size, scale=logit_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
self.config.draft_vocab_size,
|
||||||
|
self.config.hidden_size,
|
||||||
|
prefix=maybe_prefix(prefix, "lm_head"),
|
||||||
|
)
|
||||||
|
|
||||||
# Set MoE hyperparameters
|
# Set MoE hyperparameters
|
||||||
self.set_moe_parameters()
|
self.set_moe_parameters()
|
||||||
|
|
||||||
@ -211,6 +220,6 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
|||||||
loader = AutoWeightsLoader(
|
loader = AutoWeightsLoader(
|
||||||
self,
|
self,
|
||||||
# lm_head is tied with target model (Llama4ForCausalLM)
|
# lm_head is tied with target model (Llama4ForCausalLM)
|
||||||
skip_prefixes=(["lm_head."]),
|
skip_prefixes=([]),
|
||||||
)
|
)
|
||||||
loader.load_weights(map(transform, weights))
|
loader.load_weights(map(transform, weights))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user