Add support for Eagle with separate lm-head and embed_tokens layers (#28549)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Eldar Kurtić 2025-11-15 15:12:02 +01:00 committed by GitHub
parent 085a525332
commit e439c784fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 205 additions and 64 deletions

View File

@@ -324,6 +324,7 @@ def test_prepare_inputs_padded():
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
@@ -335,6 +336,7 @@ def test_load_model(
attn_backend,
pp_size,
use_distinct_embed_tokens,
use_distinct_lm_head,
monkeypatch,
):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
@@ -350,12 +352,13 @@ def test_load_model(
# Setup draft model mock
mock_model = mock.MagicMock()
mock_model.model = mock.MagicMock()
mock_model.has_own_embed_tokens = use_distinct_embed_tokens
if use_distinct_embed_tokens:
# Some models can have a different hidden size than the target model,
# so we test that their embed_tokens doesn't get overwritten
mock_model.model.embed_tokens.weight.shape = (131072, 2048)
else:
mock_model.model.embed_tokens.weight.shape = (131072, 4096)
mock_model.model.embed_tokens = mock.MagicMock()
mock_model.has_own_lm_head = use_distinct_lm_head
if use_distinct_lm_head:
mock_model.lm_head = mock.MagicMock()
mock_get_model.return_value = mock_model
@@ -391,15 +394,13 @@ def test_load_model(
target_model = mock.create_autospec(_TargetModelStub, instance=True)
target_model.model = mock.MagicMock()
target_model.model.embed_tokens.weight.shape = (131072, 4096)
target_model.lm_head = mock.MagicMock()
target_model.model.embed_tokens = mock.MagicMock()
from vllm.model_executor.models import SupportsMultiModal
assert not isinstance(target_model, SupportsMultiModal)
if method == "eagle":
target_model.lm_head = mock.MagicMock()
# Create proposer using the helper function
proposer = _create_proposer(method, num_speculative_tokens=8)
@@ -409,18 +410,18 @@ def test_load_model(
# Verify common interactions
mock_get_model.assert_called_once()
# Verify that EAGLE models gain the lm head from the target model
if method == "eagle":
assert proposer.model.lm_head == target_model.lm_head
# Verify that the lm head is set correctly
if use_distinct_lm_head:
assert proposer.model.lm_head is not target_model.lm_head
else:
assert proposer.model.lm_head is target_model.lm_head
# Verify that the embed tokens are set correctly
# If pp_size is > 1, the embed tokens should be distinct
if pp_size > 1 or use_distinct_embed_tokens:
assert proposer.model.model.embed_tokens != target_model.model.embed_tokens
assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens
else:
# When pp_size is 1 and the draft and target models have
# embed_tokens of the same shape, they should be shared.
assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
assert proposer.model.model.embed_tokens is target_model.model.embed_tokens
@pytest.mark.parametrize("method", ["eagle", "eagle3"])

View File

@@ -67,6 +67,10 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_gro
mock_model = mock.MagicMock()
mock_model.model.embed_tokens.weight.shape = (131072, 4096)
mock_get_model.return_value = mock_model
# MTP does not have its own embed_tokens or lm_head
# so it should share them with the target model
mock_model.has_own_embed_tokens = False
mock_model.has_own_lm_head = False
target_attn_layers = {"target_attn_1": mock.MagicMock()}
all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}

View File

@@ -26,7 +26,7 @@ from vllm.model_executor.models.deepseek_v2 import (
)
from vllm.utils import init_logger
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
logger = init_logger(__name__)
@@ -250,6 +250,7 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
name, loaded_weight = inputs
if "lm_head" not in name:
name = "model." + name
process_eagle_weight(self, name)
return name, loaded_weight
loader = AutoWeightsLoader(

View File

@@ -85,7 +85,7 @@ from vllm.v1.attention.backends.mla.indexer import (
)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .utils import (
PPMissingLayer,
is_pp_missing_parameter,
@@ -1311,7 +1311,7 @@ class DeepseekV2MixtureOfExperts(MixtureOfExperts):
class DeepseekV2ForCausalLM(
nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
):
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],

View File

@@ -932,13 +932,73 @@ def supports_transcription(
@runtime_checkable
class SupportsEagle3(Protocol):
class SupportsEagleBase(Protocol):
"""Base interface for models that support EAGLE-based speculative decoding."""
has_own_lm_head: bool = False
"""
A flag that indicates this model has its own trained lm_head.
"""
has_own_embed_tokens: bool = False
"""
A flag that indicates this model has its own trained input embeddings.
"""
@overload
def supports_any_eagle(model: type[object]) -> TypeIs[type[SupportsEagleBase]]: ...
@overload
def supports_any_eagle(model: object) -> TypeIs[SupportsEagleBase]: ...
def supports_any_eagle(
model: type[object] | object,
) -> TypeIs[type[SupportsEagleBase]] | TypeIs[SupportsEagleBase]:
"""Check if model supports any EAGLE variant (1, 2, or 3)."""
return supports_eagle(model) or supports_eagle3(model)
@runtime_checkable
class SupportsEagle(SupportsEagleBase, Protocol):
"""The interface required for models that support
EAGLE3 speculative decoding."""
EAGLE-1 and EAGLE-2 speculative decoding."""
supports_eagle: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports EAGLE-1 and EAGLE-2
speculative decoding.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
@overload
def supports_eagle(model: type[object]) -> TypeIs[type[SupportsEagle]]: ...
@overload
def supports_eagle(model: object) -> TypeIs[SupportsEagle]: ...
def supports_eagle(
model: type[object] | object,
) -> TypeIs[type[SupportsEagle]] | TypeIs[SupportsEagle]:
return isinstance(model, SupportsEagle)
@runtime_checkable
class SupportsEagle3(SupportsEagleBase, Protocol):
"""The interface required for models that support
EAGLE-3 speculative decoding."""
supports_eagle3: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports EAGLE3
A flag that indicates this model supports EAGLE-3
speculative decoding.
Note:
@@ -949,7 +1009,7 @@ class SupportsEagle3(Protocol):
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""
Set which layers should output auxiliary
hidden states for EAGLE3.
hidden states for EAGLE-3.
Args:
layers: Tuple of layer indices that should output auxiliary
@@ -960,7 +1020,7 @@ class SupportsEagle3(Protocol):
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""
Get the layer indices that should output auxiliary hidden states
for EAGLE3.
for EAGLE-3.
Returns:
Tuple of layer indices for auxiliary hidden state outputs.
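
A minimal usage sketch (not part of the diff) of how these interfaces compose; the draft model class below is hypothetical, while the imports and helpers match the definitions above:

from vllm.model_executor.models.interfaces import (
    SupportsEagle,
    supports_any_eagle,
    supports_eagle,
)

# Hypothetical draft model that opts in to the EAGLE-1/2 interface.
# It inherits has_own_lm_head = False and has_own_embed_tokens = False;
# the flags are flipped during weight loading if the checkpoint contains
# "lm_head" / "embed_tokens" tensors (see process_eagle_weight later in this diff).
class MyEagleDraftForCausalLM(SupportsEagle):
    pass

draft = MyEagleDraftForCausalLM()
assert supports_eagle(draft)      # runtime_checkable Protocol isinstance check
assert supports_any_eagle(draft)  # True for EAGLE-1/2 or EAGLE-3 models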

View File

@@ -58,7 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
@@ -529,7 +529,9 @@ class LlamaModel(nn.Module):
return loaded_params
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
class LlamaForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],

View File

@@ -35,7 +35,7 @@ from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausa
from vllm.model_executor.models.utils import extract_layer_index
from .interfaces import SupportsMultiModal
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
logger = init_logger(__name__)
@@ -212,6 +212,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
name, weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
if "lm_head" not in name:
name = "model." + name
process_eagle_weight(self, name)
return name, weight
loader = AutoWeightsLoader(

View File

@@ -17,7 +17,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
logger = init_logger(__name__)
@@ -179,6 +179,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
name, loaded_weight = inputs
if "lm_head" not in name:
name = "model." + name
process_eagle_weight(self, name)
return name, loaded_weight
loader = AutoWeightsLoader(

View File

@@ -23,7 +23,7 @@ from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
logger = init_logger(__name__)
@@ -324,6 +324,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
if "embed_tokens" in name:
includes_embed_tokens = True
model_weights[name] = loaded_weight
process_eagle_weight(self, name)
skip_substrs = []
if not includes_draft_id_mapping:

View File

@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle, SupportsLoRA, SupportsPP
from .minicpm import MiniCPMAttention as EagleMiniCPMAttention
from .minicpm import MiniCPMMLP as EagleMiniCPMMLP
from .minicpm import MiniCPMMoE as EagleMiniCPMMoE
@@ -52,6 +52,7 @@ from .utils import (
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
maybe_prefix,
process_eagle_weight,
)
@@ -289,7 +290,7 @@ class EagleMiniCPMModel(nn.Module):
return loaded_params
class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -376,8 +377,13 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
def transform(inputs):
name, loaded_weight = inputs
process_eagle_weight(self, name)
return name, loaded_weight
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
return loader.load_weights(map(transform, weights))

View File

@@ -19,6 +19,7 @@ from vllm.distributed import (
)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import supports_any_eagle
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
@@ -825,3 +826,25 @@ direct_register_custom_op(
fake_impl=sequence_parallel_chunk_impl_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
def process_eagle_weight(
model: nn.Module,
name: str,
) -> None:
"""
Update EAGLE model flags based on the loaded weight's name.
This should be called during weight loading to detect whether a model
has its own lm_head or embed_tokens weight.
Args:
model: The model instance (must support EAGLE)
name: The name of the weight to process
"""
if not supports_any_eagle(model):
return
# To prevent overriding the draft's own layers with the target model's layers
if "lm_head" in name:
model.has_own_lm_head = True
if "embed_tokens" in name:
model.has_own_embed_tokens = True
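
For reference, a condensed sketch of the call pattern the Eagle model classes above use with this helper; it mirrors the EagleMiniCPMForCausalLM change and assumes AutoWeightsLoader as imported in those files:

def load_weights(self, weights):
    def transform(inputs):
        name, loaded_weight = inputs
        # Flip has_own_lm_head / has_own_embed_tokens when the draft
        # checkpoint actually contains those tensors.
        process_eagle_weight(self, name)
        return name, loaded_weight

    loader = AutoWeightsLoader(self)
    return loader.load_weights(map(transform, weights))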

View File

@@ -991,6 +991,7 @@ class EagleProposer:
target_language_model = target_model.get_language_model()
else:
target_language_model = target_model
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
if hasattr(target_language_model.model, "embed_tokens"):
@@ -1002,52 +1003,92 @@
"Target model does not have 'embed_tokens' or 'embedding' attribute"
)
# Check if shapes match and we found the embedding
eagle_shape = self.model.model.embed_tokens.weight.shape
target_shape = target_embed_tokens.weight.shape
if eagle_shape == target_shape:
logger.info(
"Assuming the EAGLE head shares the same vocab embedding"
" with the target model."
)
del self.model.model.embed_tokens
self.model.model.embed_tokens = target_embed_tokens
share_embeddings = False
if hasattr(self.model, "has_own_embed_tokens"):
# EAGLE model
if not self.model.has_own_embed_tokens:
share_embeddings = True
logger.info(
"Detected EAGLE model without its own embed_tokens in the"
" checkpoint. Sharing target model embedding weights with the"
" draft model."
)
elif (
isinstance(target_embed_tokens.weight, torch.Tensor)
and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
and torch.equal(
target_embed_tokens.weight, self.model.model.embed_tokens.weight
)
):
share_embeddings = True
logger.info(
"Detected EAGLE model with embed_tokens identical to the target"
" model. Sharing target model embedding weights with the draft"
" model."
)
else:
logger.info(
"Detected EAGLE model with distinct embed_tokens weights. "
"Keeping separate embedding weights from the target model."
)
else:
# MTP model
share_embeddings = True
logger.info(
"The EAGLE head's vocab embedding will be loaded separately"
" from the target model."
"Detected MTP model. "
"Sharing target model embedding weights with the draft model."
)
if share_embeddings:
if hasattr(self.model.model, "embed_tokens"):
del self.model.model.embed_tokens
self.model.model.embed_tokens = target_embed_tokens
else:
logger.info(
"The EAGLE head's vocab embedding will be loaded separately"
"The draft model's vocab embedding will be loaded separately"
" from the target model."
)
# share lm_head with the target model if needed
# some model definitions do not define lm_head explicitly
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
if self.vllm_config.speculative_config.method != "eagle3":
if hasattr(target_language_model, "lm_head"):
logger.info("Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_language_model.lm_head
else:
if (
hasattr(self.model, "lm_head")
and hasattr(target_language_model, "lm_head")
and self.model.lm_head.weight.shape
== target_language_model.lm_head.weight.shape
):
share_lm_head = False
if hasattr(self.model, "has_own_lm_head"):
# EAGLE model
if not self.model.has_own_lm_head:
share_lm_head = True
logger.info(
"Assuming the EAGLE head shares the same lm_head"
" with the target model."
"Detected EAGLE model without its own lm_head in the checkpoint. "
"Sharing target model lm_head weights with the draft model."
)
elif (
hasattr(target_language_model, "lm_head")
and isinstance(target_language_model.lm_head.weight, torch.Tensor)
and isinstance(self.model.lm_head.weight, torch.Tensor)
and torch.equal(
target_language_model.lm_head.weight, self.model.lm_head.weight
)
):
share_lm_head = True
logger.info(
"Detected EAGLE model with lm_head identical to the target model. "
"Sharing target model lm_head weights with the draft model."
)
del self.model.lm_head
self.model.lm_head = target_language_model.lm_head
else:
logger.info(
"The EAGLE head's lm_head will be loaded separately"
" from the target model."
"Detected EAGLE model with distinct lm_head weights. "
"Keeping separate lm_head weights from the target model."
)
else:
# MTP model
share_lm_head = True
logger.info(
"Detected MTP model. "
"Sharing target model lm_head weights with the draft model."
)
if share_lm_head and hasattr(target_language_model, "lm_head"):
if hasattr(self.model, "lm_head"):
del self.model.lm_head
self.model.lm_head = target_language_model.lm_head
@torch.inference_mode()
def dummy_run(
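
To summarize the sharing policy implemented above, here is a condensed, self-contained sketch; the helper name and signature are hypothetical and not part of the diff. A draft that exposes the has_own_* flags is an EAGLE model and keeps its own weights only when the flag is set and the tensors genuinely differ from the target's; a draft without the flags is treated as MTP and always shares.

import torch
import torch.nn as nn

def should_share_with_target(
    draft: nn.Module,
    flag: str,  # "has_own_lm_head" or "has_own_embed_tokens"
    draft_weight: torch.Tensor | None,
    target_weight: torch.Tensor | None,
) -> bool:
    if not hasattr(draft, flag):
        return True   # MTP draft: no flags, always reuse the target's layer
    if not getattr(draft, flag):
        return True   # EAGLE draft whose checkpoint lacked this tensor
    if isinstance(draft_weight, torch.Tensor) and isinstance(
        target_weight, torch.Tensor
    ):
        # Checkpoint shipped the tensor; share only if it is identical
        return torch.equal(draft_weight, target_weight)
    return False      # genuinely distinct trained weights: keep them separate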