diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 5246ea6517f6c..575a6a151f579 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -402,7 +402,11 @@ def test_eagle_correctness(
             # Scout requires default backend selection
             # because vision encoder has head_dim 88 being incompatible
             # with FLASH_ATTN and needs to fall back to Flex Attn
-            pass
+
+            # Keep default backend selection on non-ROCm platforms
+            if current_platform.is_rocm():
+                # TODO: Enable Flex Attn for spec_decode on ROCm
+                pytest.skip("Flex Attn for spec_decode is not currently supported on ROCm")
         else:
             m.setenv("VLLM_MLA_DISABLE", "1")
             m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
index 0146b30579287..02f5b5ff639bd 100644
--- a/vllm/model_executor/models/llama4_eagle.py
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -28,7 +28,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausalLM
 from vllm.model_executor.models.utils import extract_layer_index
@@ -182,6 +185,12 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
             self.config.vocab_size, scale=logit_scale
         )
 
+        self.lm_head = ParallelLMHead(
+            self.config.draft_vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "lm_head"),
+        )
+
         # Set MoE hyperparameters
         self.set_moe_parameters()
 
@@ -211,6 +220,6 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
         loader = AutoWeightsLoader(
             self,
-            # lm_head is tied with target model (Llama4ForCausalLM)
-            skip_prefixes=(["lm_head."]),
+            # lm_head now comes from the draft checkpoint, not the target model
+            skip_prefixes=[],
         )
         loader.load_weights(map(transform, weights))
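
For reviewers, here is a minimal sketch of the control flow the test hunk ends up with. The surrounding harness (parametrization, the `monkeypatch` context) is elided, `select_backend` is a hypothetical helper name rather than a real function in `tests/v1/e2e/test_spec_decode.py`, and the enclosing `if` condition is paraphrased from the hunk's comments; `current_platform` is vLLM's platform probe from `vllm.platforms`.

```python
# Hypothetical distillation of the test hunk, not the actual test code.
import pytest

from vllm.platforms import current_platform


def select_backend(m: pytest.MonkeyPatch, model_name: str, attn_backend: str) -> None:
    if "Scout" in model_name and attn_backend == "FLASH_ATTN":
        # Scout's vision encoder (head_dim 88) cannot use FLASH_ATTN, so
        # default selection falls back to Flex Attn -- which is not yet
        # supported for spec_decode on ROCm, hence the early skip.
        if current_platform.is_rocm():
            pytest.skip("Flex Attn for spec_decode is not currently supported on ROCm")
        # On other platforms: set nothing and let vLLM pick the backend.
    else:
        m.setenv("VLLM_MLA_DISABLE", "1")
        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
```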
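
On the model side, the key point is that the EAGLE draft now builds its own `lm_head` over `config.draft_vocab_size`, so its weights can no longer be skipped as tied to the target model. A toy, self-contained sketch of why (plain `torch`, not vLLM code; `TARGET_VOCAB` and `DRAFT_VOCAB` are made-up sizes):

```python
# Toy illustration: a draft head with a reduced vocabulary cannot reuse
# (tie) the target model's lm_head, so its projection weights must be
# loaded from the draft checkpoint -- which is why the diff drops
# "lm_head." from skip_prefixes.
import torch
import torch.nn as nn

HIDDEN = 16
TARGET_VOCAB = 32  # target model vocabulary size
DRAFT_VOCAB = 8    # stands in for config.draft_vocab_size; smaller by design

target_lm_head = nn.Linear(HIDDEN, TARGET_VOCAB, bias=False)
draft_lm_head = nn.Linear(HIDDEN, DRAFT_VOCAB, bias=False)

h = torch.randn(1, HIDDEN)
assert target_lm_head(h).shape == (1, TARGET_VOCAB)
assert draft_lm_head(h).shape == (1, DRAFT_VOCAB)  # different shape: cannot tie
```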