diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 421da5241555..805b8c86b080 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -324,6 +324,7 @@ def test_prepare_inputs_padded():
 @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
 @pytest.mark.parametrize("pp_size", [1, 2])
 @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
+@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
 @mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
 @mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
 @mock.patch("vllm.v1.spec_decode.eagle.get_model")
@@ -335,6 +336,7 @@ def test_load_model(
     attn_backend,
     pp_size,
     use_distinct_embed_tokens,
+    use_distinct_lm_head,
     monkeypatch,
 ):
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
@@ -350,12 +352,13 @@ def test_load_model(
 
     # Setup draft model mock
     mock_model = mock.MagicMock()
+    mock_model.model = mock.MagicMock()
+    mock_model.has_own_embed_tokens = use_distinct_embed_tokens
     if use_distinct_embed_tokens:
-        # Some models can have a different hidden size than the target model,
-        # so we test that their embed_tokens doesn't get overwritten
-        mock_model.model.embed_tokens.weight.shape = (131072, 2048)
-    else:
-        mock_model.model.embed_tokens.weight.shape = (131072, 4096)
+        mock_model.model.embed_tokens = mock.MagicMock()
+    mock_model.has_own_lm_head = use_distinct_lm_head
+    if use_distinct_lm_head:
+        mock_model.lm_head = mock.MagicMock()
 
     mock_get_model.return_value = mock_model
 
@@ -391,15 +394,13 @@ def test_load_model(
 
     target_model = mock.create_autospec(_TargetModelStub, instance=True)
     target_model.model = mock.MagicMock()
-    target_model.model.embed_tokens.weight.shape = (131072, 4096)
+    target_model.lm_head = mock.MagicMock()
+    target_model.model.embed_tokens = mock.MagicMock()
 
     from vllm.model_executor.models import SupportsMultiModal
 
     assert not isinstance(target_model, SupportsMultiModal)
 
-    if method == "eagle":
-        target_model.lm_head = mock.MagicMock()
-
     # Create proposer using the helper function
     proposer = _create_proposer(method, num_speculative_tokens=8)
 
@@ -409,18 +410,18 @@ def test_load_model(
     # Verify common interactions
     mock_get_model.assert_called_once()
 
-    # Verify that EAGLE models gain the lm head from the target model
-    if method == "eagle":
-        assert proposer.model.lm_head == target_model.lm_head
+    # Verify that the lm head is set correctly
+    if use_distinct_lm_head:
+        assert proposer.model.lm_head is not target_model.lm_head
+    else:
+        assert proposer.model.lm_head is target_model.lm_head
 
     # Verify that the embed tokens are set correctly
     # If pp_size is > 1, the embed tokens should be distinct
     if pp_size > 1 or use_distinct_embed_tokens:
-        assert proposer.model.model.embed_tokens != target_model.model.embed_tokens
+        assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens
     else:
-        # When pp_size is 1 and the draft and target models have
-        # embed_tokens of the same shape, they should be shared.
-        assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
+        assert proposer.model.model.embed_tokens is target_model.model.embed_tokens
 
 
 @pytest.mark.parametrize("method", ["eagle", "eagle3"])
diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py
index 6d59b58e739e..c5c0491abaf7 100644
--- a/tests/v1/spec_decode/test_mtp.py
+++ b/tests/v1/spec_decode/test_mtp.py
@@ -67,6 +67,10 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_gro
     mock_model = mock.MagicMock()
     mock_model.model.embed_tokens.weight.shape = (131072, 4096)
     mock_get_model.return_value = mock_model
+    # MTP does not have its own embed_tokens or lm_head
+    # so it should share them with the target model
+    mock_model.has_own_embed_tokens = False
+    mock_model.has_own_lm_head = False
 
     target_attn_layers = {"target_attn_1": mock.MagicMock()}
     all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py
index 9e834a73f8e5..3fb04c3b70dd 100644
--- a/vllm/model_executor/models/deepseek_eagle.py
+++ b/vllm/model_executor/models/deepseek_eagle.py
@@ -26,7 +26,7 @@ from vllm.model_executor.models.deepseek_v2 import (
 )
 from vllm.utils import init_logger
 
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
 
 logger = init_logger(__name__)
 
@@ -250,6 +250,7 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
             name, loaded_weight = inputs
             if "lm_head" not in name:
                 name = "model." + name
+            process_eagle_weight(self, name)
             return name, loaded_weight
 
         loader = AutoWeightsLoader(
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 115818d903a6..e8ee9951d611 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -85,7 +85,7 @@ from vllm.v1.attention.backends.mla.indexer import (
 )
 from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
 
-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
 from .utils import (
     PPMissingLayer,
     is_pp_missing_parameter,
@@ -1311,7 +1311,7 @@ class DeepseekV2MixtureOfExperts(MixtureOfExperts):
 
 
 class DeepseekV2ForCausalLM(
-    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
+    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
 ):
     packed_modules_mapping = {
         "gate_up_proj": ["gate_proj", "up_proj"],
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 929bfaaee5cb..dc4caf2f02f9 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -932,13 +932,73 @@ def supports_transcription(
 
 
 @runtime_checkable
-class SupportsEagle3(Protocol):
+class SupportsEagleBase(Protocol):
+    """Base interface for models that support EAGLE-based speculative decoding."""
+
+    has_own_lm_head: bool = False
+    """
+    A flag that indicates this model has trained its own lm_head.
+    """
+
+    has_own_embed_tokens: bool = False
+    """
+    A flag that indicates this model has trained its own input embeddings.
+    """
+
+
+@overload
+def supports_any_eagle(model: type[object]) -> TypeIs[type[SupportsEagleBase]]: ...
+
+
+@overload
+def supports_any_eagle(model: object) -> TypeIs[SupportsEagleBase]: ...
+
+
+def supports_any_eagle(
+    model: type[object] | object,
+) -> TypeIs[type[SupportsEagleBase]] | TypeIs[SupportsEagleBase]:
+    """Check if model supports any EAGLE variant (1, 2, or 3)."""
+    return supports_eagle(model) or supports_eagle3(model)
+
+
+@runtime_checkable
+class SupportsEagle(SupportsEagleBase, Protocol):
     """The interface required for models that support
-    EAGLE3 speculative decoding."""
+    EAGLE-1 and EAGLE-2 speculative decoding."""
+
+    supports_eagle: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports EAGLE-1 and EAGLE-2
+    speculative decoding.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+
+@overload
+def supports_eagle(model: type[object]) -> TypeIs[type[SupportsEagle]]: ...
+
+
+@overload
+def supports_eagle(model: object) -> TypeIs[SupportsEagle]: ...
+
+
+def supports_eagle(
+    model: type[object] | object,
+) -> TypeIs[type[SupportsEagle]] | TypeIs[SupportsEagle]:
+    return isinstance(model, SupportsEagle)
+
+
+@runtime_checkable
+class SupportsEagle3(SupportsEagleBase, Protocol):
+    """The interface required for models that support
+    EAGLE-3 speculative decoding."""
 
     supports_eagle3: ClassVar[Literal[True]] = True
     """
-    A flag that indicates this model supports EAGLE3
+    A flag that indicates this model supports EAGLE-3
     speculative decoding.
 
     Note:
@@ -949,7 +1009,7 @@ class SupportsEagle3(Protocol):
     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         """
         Set which layers should output auxiliary
-        hidden states for EAGLE3.
+        hidden states for EAGLE-3.
 
         Args:
             layers: Tuple of layer indices that should output auxiliary
@@ -960,7 +1020,7 @@ class SupportsEagle3(Protocol):
     def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
         """
         Get the layer indices that should output auxiliary hidden states
-        for EAGLE3.
+        for EAGLE-3.
 
         Returns:
             Tuple of layer indices for auxiliary hidden state outputs.
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c49a1ea817f9..0a3f37c30ab5 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -58,7 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 )
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -529,7 +529,9 @@ class LlamaModel(nn.Module):
         return loaded_params
 
 
-class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class LlamaForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
index e8716d652415..660c8f1bb522 100644
--- a/vllm/model_executor/models/llama4_eagle.py
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -35,7 +35,7 @@ from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausa
 from vllm.model_executor.models.utils import extract_layer_index
 
 from .interfaces import SupportsMultiModal
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
 
 logger = init_logger(__name__)
 
@@ -212,6 +212,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
             name, weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
             if "lm_head" not in name:
                 name = "model." + name
+            process_eagle_weight(self, name)
             return name, weight
 
         loader = AutoWeightsLoader(
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index ab2a9f6f06db..0287132c5637 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
 
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
 
 logger = init_logger(__name__)
 
@@ -179,6 +179,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
             name, loaded_weight = inputs
             if "lm_head" not in name:
                 name = "model." + name
+            process_eagle_weight(self, name)
             return name, loaded_weight
 
         loader = AutoWeightsLoader(
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 6edc9519dfbb..a3bcc5eeb32b 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -23,7 +23,7 @@ from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
 
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
 
 logger = init_logger(__name__)
 
@@ -324,6 +324,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             if "embed_tokens" in name:
                 includes_embed_tokens = True
             model_weights[name] = loaded_weight
+            process_eagle_weight(self, name)
 
         skip_substrs = []
         if not includes_draft_id_mapping:
diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py
index 0ca31913485d..d0cdb70aa857 100644
--- a/vllm/model_executor/models/minicpm_eagle.py
+++ b/vllm/model_executor/models/minicpm_eagle.py
@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle, SupportsLoRA, SupportsPP
 from .minicpm import MiniCPMAttention as EagleMiniCPMAttention
 from .minicpm import MiniCPMMLP as EagleMiniCPMMLP
 from .minicpm import MiniCPMMoE as EagleMiniCPMMoE
@@ -52,6 +52,7 @@ from .utils import (
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
     maybe_prefix,
+    process_eagle_weight,
 )
 
 
@@ -289,7 +290,7 @@ class EagleMiniCPMModel(nn.Module):
         return loaded_params
 
 
-class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -376,8 +377,13 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        def transform(inputs):
+            name, loaded_weight = inputs
+            process_eagle_weight(self, name)
+            return name, loaded_weight
+
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
         )
-        return loader.load_weights(weights)
+        return loader.load_weights(map(transform, weights))
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index e5663c8a057a..0d811fbc7585 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -19,6 +19,7 @@ from vllm.distributed import (
 )
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import supports_any_eagle
 from vllm.multimodal import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils.math_utils import cdiv
@@ -825,3 +826,25 @@ direct_register_custom_op(
     fake_impl=sequence_parallel_chunk_impl_fake,
     tags=(torch.Tag.needs_fixed_stride_order,),
 )
+
+
+def process_eagle_weight(
+    model: nn.Module,
+    name: str,
+) -> None:
+    """
+    Update EAGLE model flags based on loaded weight name.
+    This should be called during weight loading to detect if a model
+    has its own lm_head or embed_tokens weight.
+    Args:
+        model: The model instance (must support EAGLE)
+        name: The name of the weight to process
+    """
+    if not supports_any_eagle(model):
+        return
+
+    # To prevent overriding with target model's layers
+    if "lm_head" in name:
+        model.has_own_lm_head = True
+    if "embed_tokens" in name:
+        model.has_own_embed_tokens = True
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index f3b34544f8d9..ed602f39d0f9 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -991,6 +991,7 @@ class EagleProposer:
             target_language_model = target_model.get_language_model()
         else:
             target_language_model = target_model
+
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
             if hasattr(target_language_model.model, "embed_tokens"):
@@ -1002,52 +1003,92 @@ class EagleProposer:
                     "Target model does not have 'embed_tokens' or 'embedding' attribute"
                 )
 
-            # Check if shapes match and we found the embedding
-            eagle_shape = self.model.model.embed_tokens.weight.shape
-            target_shape = target_embed_tokens.weight.shape
-            if eagle_shape == target_shape:
-                logger.info(
-                    "Assuming the EAGLE head shares the same vocab embedding"
-                    " with the target model."
-                )
-                del self.model.model.embed_tokens
-                self.model.model.embed_tokens = target_embed_tokens
+            share_embeddings = False
+            if hasattr(self.model, "has_own_embed_tokens"):
+                # EAGLE model
+                if not self.model.has_own_embed_tokens:
+                    share_embeddings = True
+                    logger.info(
+                        "Detected EAGLE model without its own embed_tokens in the"
+                        " checkpoint. Sharing target model embedding weights with the"
+                        " draft model."
+                    )
+                elif (
+                    isinstance(target_embed_tokens.weight, torch.Tensor)
+                    and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
+                    and torch.equal(
+                        target_embed_tokens.weight, self.model.model.embed_tokens.weight
+                    )
+                ):
+                    share_embeddings = True
+                    logger.info(
+                        "Detected EAGLE model with embed_tokens identical to the target"
+                        " model. Sharing target model embedding weights with the draft"
+                        " model."
+                    )
+                else:
+                    logger.info(
+                        "Detected EAGLE model with distinct embed_tokens weights. "
+                        "Keeping separate embedding weights from the target model."
+                    )
             else:
+                # MTP model
+                share_embeddings = True
                 logger.info(
-                    "The EAGLE head's vocab embedding will be loaded separately"
-                    " from the target model."
+                    "Detected MTP model. "
+                    "Sharing target model embedding weights with the draft model."
                 )
+
+            if share_embeddings:
+                if hasattr(self.model.model, "embed_tokens"):
+                    del self.model.model.embed_tokens
+                self.model.model.embed_tokens = target_embed_tokens
         else:
             logger.info(
-                "The EAGLE head's vocab embedding will be loaded separately"
+                "The draft model's vocab embedding will be loaded separately"
                 " from the target model."
             )
 
         # share lm_head with the target model if needed
-        # some model definition do not define lm_head explicitly
-        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.vllm_config.speculative_config.method != "eagle3":
-            if hasattr(target_language_model, "lm_head"):
-                logger.info("Loading EAGLE LM head weights from the target model.")
-                self.model.lm_head = target_language_model.lm_head
-        else:
-            if (
-                hasattr(self.model, "lm_head")
-                and hasattr(target_language_model, "lm_head")
-                and self.model.lm_head.weight.shape
-                == target_language_model.lm_head.weight.shape
-            ):
+        share_lm_head = False
+        if hasattr(self.model, "has_own_lm_head"):
+            # EAGLE model
+            if not self.model.has_own_lm_head:
+                share_lm_head = True
                 logger.info(
-                    "Assuming the EAGLE head shares the same lm_head"
-                    " with the target model."
+                    "Detected EAGLE model without its own lm_head in the checkpoint. "
+                    "Sharing target model lm_head weights with the draft model."
+                )
+            elif (
+                hasattr(target_language_model, "lm_head")
+                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
+                and isinstance(self.model.lm_head.weight, torch.Tensor)
+                and torch.equal(
+                    target_language_model.lm_head.weight, self.model.lm_head.weight
+                )
+            ):
+                share_lm_head = True
+                logger.info(
+                    "Detected EAGLE model with lm_head identical to the target model. "
+                    "Sharing target model lm_head weights with the draft model."
                 )
-                del self.model.lm_head
-                self.model.lm_head = target_language_model.lm_head
             else:
                 logger.info(
-                    "The EAGLE head's lm_head will be loaded separately"
-                    " from the target model."
+                    "Detected EAGLE model with distinct lm_head weights. "
+                    "Keeping separate lm_head weights from the target model."
                 )
+        else:
+            # MTP model
+            share_lm_head = True
+            logger.info(
+                "Detected MTP model. "
+                "Sharing target model lm_head weights with the draft model."
+            )
+
+        if share_lm_head and hasattr(target_language_model, "lm_head"):
+            if hasattr(self.model, "lm_head"):
+                del self.model.lm_head
+            self.model.lm_head = target_language_model.lm_head
 
     @torch.inference_mode()
     def dummy_run(