Add support for Eagle with separate lm-head and embed_tokens layers (#28549)
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
parent 085a525332
commit e439c784fa
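At a high level, the change gives EAGLE draft models two flags, has_own_lm_head and has_own_embed_tokens, that are set while the draft checkpoint is loaded; EagleProposer then shares the target model's embed_tokens and lm_head with the draft only when the draft did not bring its own trained copies (or when the weights are identical). A minimal sketch of that behavior, assuming hypothetical stand-in names (DraftStub, TargetStub, share_with_target are not from the diff):

# Illustrative sketch only: DraftStub, TargetStub and share_with_target are
# hypothetical stand-ins for the real vLLM modules and EagleProposer logic.
class TargetStub:
    def __init__(self):
        self.embed_tokens = object()  # placeholders for nn.Module layers
        self.lm_head = object()

class DraftStub(TargetStub):
    has_own_embed_tokens = False  # flipped to True during loading if the checkpoint has one
    has_own_lm_head = True        # e.g. the checkpoint ships its own trained lm_head

def share_with_target(draft, target):
    # Mirror of the new sharing rule: reuse the target's layer only when the
    # draft checkpoint did not provide its own.
    if not draft.has_own_embed_tokens:
        draft.embed_tokens = target.embed_tokens
    if not draft.has_own_lm_head:
        draft.lm_head = target.lm_head

draft, target = DraftStub(), TargetStub()
share_with_target(draft, target)
assert draft.embed_tokens is target.embed_tokens  # shared with the target
assert draft.lm_head is not target.lm_head        # draft keeps its own lm_head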
@@ -324,6 +324,7 @@ def test_prepare_inputs_padded():
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
@@ -335,6 +336,7 @@ def test_load_model(
    attn_backend,
    pp_size,
    use_distinct_embed_tokens,
    use_distinct_lm_head,
    monkeypatch,
):
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
@@ -350,12 +352,13 @@ def test_load_model(

    # Setup draft model mock
    mock_model = mock.MagicMock()
    mock_model.model = mock.MagicMock()
    mock_model.has_own_embed_tokens = use_distinct_embed_tokens
    if use_distinct_embed_tokens:
        # Some models can have a different hidden size than the target model,
        # so we test that their embed_tokens doesn't get overwritten
        mock_model.model.embed_tokens.weight.shape = (131072, 2048)
    else:
        mock_model.model.embed_tokens.weight.shape = (131072, 4096)
    mock_model.model.embed_tokens = mock.MagicMock()
    mock_model.has_own_lm_head = use_distinct_lm_head
    if use_distinct_lm_head:
        mock_model.lm_head = mock.MagicMock()

    mock_get_model.return_value = mock_model

@@ -391,15 +394,13 @@ def test_load_model(

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.model.embed_tokens.weight.shape = (131072, 4096)
    target_model.lm_head = mock.MagicMock()
    target_model.model.embed_tokens = mock.MagicMock()

    from vllm.model_executor.models import SupportsMultiModal

    assert not isinstance(target_model, SupportsMultiModal)

    if method == "eagle":
        target_model.lm_head = mock.MagicMock()

    # Create proposer using the helper function
    proposer = _create_proposer(method, num_speculative_tokens=8)

@@ -409,18 +410,18 @@ def test_load_model(
    # Verify common interactions
    mock_get_model.assert_called_once()

    # Verify that EAGLE models gain the lm head from the target model
    if method == "eagle":
        assert proposer.model.lm_head == target_model.lm_head
    # Verify that the lm head is set correctly
    if use_distinct_lm_head:
        assert proposer.model.lm_head is not target_model.lm_head
    else:
        assert proposer.model.lm_head is target_model.lm_head

    # Verify that the embed tokens are set correctly
    # If pp_size is > 1, the embed tokens should be distinct
    if pp_size > 1 or use_distinct_embed_tokens:
        assert proposer.model.model.embed_tokens != target_model.model.embed_tokens
        assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens
    else:
        # When pp_size is 1 and the draft and target models have
        # embed_tokens of the same shape, they should be shared.
        assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
        assert proposer.model.model.embed_tokens is target_model.model.embed_tokens


@pytest.mark.parametrize("method", ["eagle", "eagle3"])

@@ -67,6 +67,10 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_gro
    mock_model = mock.MagicMock()
    mock_model.model.embed_tokens.weight.shape = (131072, 4096)
    mock_get_model.return_value = mock_model
    # MTP does not have its own embed_tokens or lm_head
    # so it should share them with the target model
    mock_model.has_own_embed_tokens = False
    mock_model.has_own_lm_head = False

    target_attn_layers = {"target_attn_1": mock.MagicMock()}
    all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}

@@ -26,7 +26,7 @@ from vllm.model_executor.models.deepseek_v2 import (
)
from vllm.utils import init_logger

from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight

logger = init_logger(__name__)

@@ -250,6 +250,7 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
            name, loaded_weight = inputs
            if "lm_head" not in name:
                name = "model." + name
            process_eagle_weight(self, name)
            return name, loaded_weight

        loader = AutoWeightsLoader(

@@ -85,7 +85,7 @@ from vllm.v1.attention.backends.mla.indexer import (
)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec

from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .utils import (
    PPMissingLayer,
    is_pp_missing_parameter,
@@ -1311,7 +1311,7 @@ class DeepseekV2MixtureOfExperts(MixtureOfExperts):


class DeepseekV2ForCausalLM(
    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
):
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],

@@ -932,13 +932,73 @@ def supports_transcription(


@runtime_checkable
class SupportsEagle3(Protocol):
class SupportsEagleBase(Protocol):
    """Base interface for models that support EAGLE-based speculative decoding."""

    has_own_lm_head: bool = False
    """
    A flag that indicates this model has trained its own lm_head.
    """

    has_own_embed_tokens: bool = False
    """
    A flag that indicates this model has trained its own input embeddings.
    """


@overload
def supports_any_eagle(model: type[object]) -> TypeIs[type[SupportsEagleBase]]: ...


@overload
def supports_any_eagle(model: object) -> TypeIs[SupportsEagleBase]: ...


def supports_any_eagle(
    model: type[object] | object,
) -> TypeIs[type[SupportsEagleBase]] | TypeIs[SupportsEagleBase]:
    """Check if model supports any EAGLE variant (1, 2, or 3)."""
    return supports_eagle(model) or supports_eagle3(model)


@runtime_checkable
class SupportsEagle(SupportsEagleBase, Protocol):
    """The interface required for models that support
    EAGLE3 speculative decoding."""
    EAGLE-1 and EAGLE-2 speculative decoding."""

    supports_eagle: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports EAGLE-1 and EAGLE-2
    speculative decoding.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """


@overload
def supports_eagle(model: type[object]) -> TypeIs[type[SupportsEagle]]: ...


@overload
def supports_eagle(model: object) -> TypeIs[SupportsEagle]: ...


def supports_eagle(
    model: type[object] | object,
) -> TypeIs[type[SupportsEagle]] | TypeIs[SupportsEagle]:
    return isinstance(model, SupportsEagle)


@runtime_checkable
class SupportsEagle3(SupportsEagleBase, Protocol):
    """The interface required for models that support
    EAGLE-3 speculative decoding."""

    supports_eagle3: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports EAGLE3
    A flag that indicates this model supports EAGLE-3
    speculative decoding.

    Note:
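For orientation, a hedged sketch of how a draft-model class opts in to the new interface. MyEagleDraft is a hypothetical class name; SupportsEagle, supports_any_eagle and the two has_own_* flags come from the hunk above:

from vllm.model_executor.models.interfaces import SupportsEagle, supports_any_eagle

class MyEagleDraft(SupportsEagle):  # hypothetical draft-model class
    pass

draft = MyEagleDraft()
assert supports_any_eagle(draft)            # satisfied via the SupportsEagle protocol
assert draft.has_own_lm_head is False       # defaults stay False until weight loading
assert draft.has_own_embed_tokens is False  # marks the checkpoint's own copies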
@@ -949,7 +1009,7 @@ class SupportsEagle3(Protocol):
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """
        Set which layers should output auxiliary
        hidden states for EAGLE3.
        hidden states for EAGLE-3.

        Args:
            layers: Tuple of layer indices that should output auxiliary
@@ -960,7 +1020,7 @@ class SupportsEagle3(Protocol):
    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """
        Get the layer indices that should output auxiliary hidden states
        for EAGLE3.
        for EAGLE-3.

        Returns:
            Tuple of layer indices for auxiliary hidden state outputs.

@@ -58,7 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
@@ -529,7 +529,9 @@ class LlamaModel(nn.Module):
        return loaded_params


class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
class LlamaForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],

@@ -35,7 +35,7 @@ from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausa
from vllm.model_executor.models.utils import extract_layer_index

from .interfaces import SupportsMultiModal
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight

logger = init_logger(__name__)

@@ -212,6 +212,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
            name, weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
            if "lm_head" not in name:
                name = "model." + name
            process_eagle_weight(self, name)
            return name, weight

        loader = AutoWeightsLoader(

@@ -17,7 +17,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM

from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight

logger = init_logger(__name__)

@@ -179,6 +179,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
            name, loaded_weight = inputs
            if "lm_head" not in name:
                name = "model." + name
            process_eagle_weight(self, name)
            return name, loaded_weight

        loader = AutoWeightsLoader(

@@ -23,7 +23,7 @@ from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors

from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight

logger = init_logger(__name__)

@@ -324,6 +324,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)

        skip_substrs = []
        if not includes_draft_id_mapping:

@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle, SupportsLoRA, SupportsPP
from .minicpm import MiniCPMAttention as EagleMiniCPMAttention
from .minicpm import MiniCPMMLP as EagleMiniCPMMLP
from .minicpm import MiniCPMMoE as EagleMiniCPMMoE
@@ -52,6 +52,7 @@ from .utils import (
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    maybe_prefix,
    process_eagle_weight,
)


@@ -289,7 +290,7 @@ class EagleMiniCPMModel(nn.Module):
        return loaded_params


class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
@@ -376,8 +377,13 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        def transform(inputs):
            name, loaded_weight = inputs
            process_eagle_weight(self, name)
            return name, loaded_weight

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
        return loader.load_weights(map(transform, weights))

@@ -19,6 +19,7 @@ from vllm.distributed import (
)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import supports_any_eagle
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
@@ -825,3 +826,25 @@ direct_register_custom_op(
    fake_impl=sequence_parallel_chunk_impl_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)


def process_eagle_weight(
    model: nn.Module,
    name: str,
) -> None:
    """
    Update EAGLE model flags based on loaded weight name.
    This should be called during weight loading to detect if a model
    has its own lm_head or embed_tokens weight.
    Args:
        model: The model instance (must support EAGLE)
        name: The name of the weight to process
    """
    if not supports_any_eagle(model):
        return

    # To prevent overriding with target model's layers
    if "lm_head" in name:
        model.has_own_lm_head = True
    if "embed_tokens" in name:
        model.has_own_embed_tokens = True

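A small usage sketch of the helper above, assuming process_eagle_weight lives in vllm.model_executor.models.utils (the relative ".utils" imports in the model hunks suggest this) and using a hypothetical SupportsEagle subclass; in the real code the call happens inside each draft model's load_weights transform:

from vllm.model_executor.models.interfaces import SupportsEagle
from vllm.model_executor.models.utils import process_eagle_weight  # assumed location

class TinyDraft(SupportsEagle):  # hypothetical draft-model stand-in
    pass

draft = TinyDraft()
# Feed the checkpoint's weight names through the helper, as the load_weights
# transforms above do for each (name, tensor) pair.
for name in ["model.layers.0.mlp.gate_proj.weight", "lm_head.weight"]:
    process_eagle_weight(draft, name)

assert draft.has_own_lm_head is True        # the checkpoint shipped an lm_head
assert draft.has_own_embed_tokens is False  # but no embed_tokens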
@@ -991,6 +991,7 @@ class EagleProposer:
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        if get_pp_group().world_size == 1:
            if hasattr(target_language_model.model, "embed_tokens"):
@@ -1002,52 +1003,92 @@ class EagleProposer:
                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                )

            # Check if shapes match and we found the embedding
            eagle_shape = self.model.model.embed_tokens.weight.shape
            target_shape = target_embed_tokens.weight.shape
            if eagle_shape == target_shape:
                logger.info(
                    "Assuming the EAGLE head shares the same vocab embedding"
                    " with the target model."
                )
                del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens
            share_embeddings = False
            if hasattr(self.model, "has_own_embed_tokens"):
                # EAGLE model
                if not self.model.has_own_embed_tokens:
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model without its own embed_tokens in the"
                        " checkpoint. Sharing target model embedding weights with the"
                        " draft model."
                    )
                elif (
                    isinstance(target_embed_tokens.weight, torch.Tensor)
                    and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                    and torch.equal(
                        target_embed_tokens.weight, self.model.model.embed_tokens.weight
                    )
                ):
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model with embed_tokens identical to the target"
                        " model. Sharing target model embedding weights with the draft"
                        " model."
                    )
                else:
                    logger.info(
                        "Detected EAGLE model with distinct embed_tokens weights. "
                        "Keeping separate embedding weights from the target model."
                    )
            else:
                # MTP model
                share_embeddings = True
                logger.info(
                    "The EAGLE head's vocab embedding will be loaded separately"
                    " from the target model."
                    "Detected MTP model. "
                    "Sharing target model embedding weights with the draft model."
                )

            if share_embeddings:
                if hasattr(self.model.model, "embed_tokens"):
                    del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens
        else:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately"
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )

        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if self.vllm_config.speculative_config.method != "eagle3":
            if hasattr(target_language_model, "lm_head"):
                logger.info("Loading EAGLE LM head weights from the target model.")
                self.model.lm_head = target_language_model.lm_head
        else:
            if (
                hasattr(self.model, "lm_head")
                and hasattr(target_language_model, "lm_head")
                and self.model.lm_head.weight.shape
                == target_language_model.lm_head.weight.shape
            ):
        share_lm_head = False
        if hasattr(self.model, "has_own_lm_head"):
            # EAGLE model
            if not self.model.has_own_lm_head:
                share_lm_head = True
                logger.info(
                    "Assuming the EAGLE head shares the same lm_head"
                    " with the target model."
                    "Detected EAGLE model without its own lm_head in the checkpoint. "
                    "Sharing target model lm_head weights with the draft model."
                )
            elif (
                hasattr(target_language_model, "lm_head")
                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
                and isinstance(self.model.lm_head.weight, torch.Tensor)
                and torch.equal(
                    target_language_model.lm_head.weight, self.model.lm_head.weight
                )
            ):
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model with lm_head identical to the target model. "
                    "Sharing target model lm_head weights with the draft model."
                )
                del self.model.lm_head
                self.model.lm_head = target_language_model.lm_head
            else:
                logger.info(
                    "The EAGLE head's lm_head will be loaded separately"
                    " from the target model."
                    "Detected EAGLE model with distinct lm_head weights. "
                    "Keeping separate lm_head weights from the target model."
                )
        else:
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head

    @torch.inference_mode()
    def dummy_run(
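Reduced to its essentials, the sharing policy in the hunk above can be summarized by the rule sketched below; should_share is a hypothetical helper written only to restate the logic, not code from the change. For lm_head, sharing additionally requires that the target language model actually exposes an lm_head attribute.

def should_share(has_own_flag: bool | None, weights_identical: bool) -> bool:
    # has_own_flag is None for MTP draft models (no has_own_* attribute),
    # otherwise the draft's has_own_embed_tokens / has_own_lm_head value.
    if has_own_flag is None:
        return True                 # MTP: always share the target model's layer
    if not has_own_flag:
        return True                 # EAGLE checkpoint without its own copy
    return weights_identical        # EAGLE checkpoint whose copy matches the target

assert should_share(None, False)
assert should_share(False, False)
assert not should_share(True, False)
assert should_share(True, True)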