Add support for Eagle with separate lm-head and embed_tokens layers (#28549)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>

parent 085a525332
commit e439c784fa
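
In short: an EAGLE draft checkpoint can now ship its own lm_head and/or embed_tokens. During weight loading the draft model records this via new has_own_lm_head / has_own_embed_tokens flags, and the proposer then reuses the target model's modules only when the draft did not bring its own copy, or when the two copies are identical. The snippet below is a minimal, self-contained sketch of that rule, not vLLM code; TinyLM and maybe_share are illustrative names only.

import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Stand-in for a model exposing embed_tokens and lm_head."""

    def __init__(self, vocab: int = 32, hidden: int = 8) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


def maybe_share(draft: TinyLM, target: TinyLM, name: str, has_own: bool) -> None:
    # Re-point draft.<name> at the target's module unless the draft trained its
    # own, distinct weights (a simplified mirror of the commit's sharing rule).
    draft_mod, target_mod = getattr(draft, name), getattr(target, name)
    if not has_own or torch.equal(draft_mod.weight, target_mod.weight):
        setattr(draft, name, target_mod)


target, draft = TinyLM(), TinyLM()
maybe_share(draft, target, "lm_head", has_own=False)       # no lm_head in checkpoint
assert draft.lm_head is target.lm_head                      # -> shared
maybe_share(draft, target, "embed_tokens", has_own=True)    # draft trained its own
assert draft.embed_tokens is not target.embed_tokens        # -> kept separate
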
@@ -324,6 +324,7 @@ def test_prepare_inputs_padded():
 @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
 @pytest.mark.parametrize("pp_size", [1, 2])
 @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
+@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
 @mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
 @mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
 @mock.patch("vllm.v1.spec_decode.eagle.get_model")
@@ -335,6 +336,7 @@ def test_load_model(
     attn_backend,
     pp_size,
     use_distinct_embed_tokens,
+    use_distinct_lm_head,
     monkeypatch,
 ):
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
@@ -350,12 +352,13 @@ def test_load_model(

     # Setup draft model mock
     mock_model = mock.MagicMock()
+    mock_model.model = mock.MagicMock()
+    mock_model.has_own_embed_tokens = use_distinct_embed_tokens
     if use_distinct_embed_tokens:
-        # Some models can have a different hidden size than the target model,
-        # so we test that their embed_tokens doesn't get overwritten
-        mock_model.model.embed_tokens.weight.shape = (131072, 2048)
-    else:
-        mock_model.model.embed_tokens.weight.shape = (131072, 4096)
+        mock_model.model.embed_tokens = mock.MagicMock()
+    mock_model.has_own_lm_head = use_distinct_lm_head
+    if use_distinct_lm_head:
+        mock_model.lm_head = mock.MagicMock()

     mock_get_model.return_value = mock_model

@@ -391,15 +394,13 @@ def test_load_model(

     target_model = mock.create_autospec(_TargetModelStub, instance=True)
     target_model.model = mock.MagicMock()
-    target_model.model.embed_tokens.weight.shape = (131072, 4096)
+    target_model.lm_head = mock.MagicMock()
+    target_model.model.embed_tokens = mock.MagicMock()

     from vllm.model_executor.models import SupportsMultiModal

     assert not isinstance(target_model, SupportsMultiModal)

-    if method == "eagle":
-        target_model.lm_head = mock.MagicMock()
-
     # Create proposer using the helper function
     proposer = _create_proposer(method, num_speculative_tokens=8)

@@ -409,18 +410,18 @@ def test_load_model(
     # Verify common interactions
     mock_get_model.assert_called_once()

-    # Verify that EAGLE models gain the lm head from the target model
-    if method == "eagle":
-        assert proposer.model.lm_head == target_model.lm_head
+    # Verify that the lm head is set correctly
+    if use_distinct_lm_head:
+        assert proposer.model.lm_head is not target_model.lm_head
+    else:
+        assert proposer.model.lm_head is target_model.lm_head

     # Verify that the embed tokens are set correctly
     # If pp_size is > 1, the embed tokens should be distinct
     if pp_size > 1 or use_distinct_embed_tokens:
-        assert proposer.model.model.embed_tokens != target_model.model.embed_tokens
+        assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens
     else:
-        # When pp_size is 1 and the draft and target models have
-        # embed_tokens of the same shape, they should be shared.
-        assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
+        assert proposer.model.model.embed_tokens is target_model.model.embed_tokens


 @pytest.mark.parametrize("method", ["eagle", "eagle3"])
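
Note that the updated assertions check object identity (is / is not) rather than equality: sharing means the proposer re-points the draft's attribute at the very same module object. The following illustrative (non-vLLM) snippet makes that distinction explicit.

import torch.nn as nn

target_head = nn.Linear(8, 4, bias=False)

shared = target_head                              # sharing: same object
copy = nn.Linear(8, 4, bias=False)
copy.load_state_dict(target_head.state_dict())    # equal values, different object

assert shared is target_head                       # what the new tests assert
assert copy is not target_head                     # an equal-valued copy is not "shared"
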
@@ -67,6 +67,10 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_gro
     mock_model = mock.MagicMock()
     mock_model.model.embed_tokens.weight.shape = (131072, 4096)
     mock_get_model.return_value = mock_model
+    # MTP does not have its own embed_tokens or lm_head
+    # so it should share them with the target model
+    mock_model.has_own_embed_tokens = False
+    mock_model.has_own_lm_head = False

     target_attn_layers = {"target_attn_1": mock.MagicMock()}
     all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
@@ -26,7 +26,7 @@ from vllm.model_executor.models.deepseek_v2 import (
 )
 from vllm.utils import init_logger

-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight

 logger = init_logger(__name__)

@@ -250,6 +250,7 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
             name, loaded_weight = inputs
             if "lm_head" not in name:
                 name = "model." + name
+            process_eagle_weight(self, name)
             return name, loaded_weight

         loader = AutoWeightsLoader(
@@ -85,7 +85,7 @@ from vllm.v1.attention.backends.mla.indexer import (
 )
 from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec

-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
 from .utils import (
     PPMissingLayer,
     is_pp_missing_parameter,
@@ -1311,7 +1311,7 @@ class DeepseekV2MixtureOfExperts(MixtureOfExperts):


 class DeepseekV2ForCausalLM(
-    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
+    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
 ):
     packed_modules_mapping = {
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -932,13 +932,73 @@ def supports_transcription(


 @runtime_checkable
-class SupportsEagle3(Protocol):
+class SupportsEagleBase(Protocol):
+    """Base interface for models that support EAGLE-based speculative decoding."""
+
+    has_own_lm_head: bool = False
+    """
+    A flag that indicates this model has trained its own lm_head.
+    """
+
+    has_own_embed_tokens: bool = False
+    """
+    A flag that indicates this model has trained its own input embeddings.
+    """
+
+
+@overload
+def supports_any_eagle(model: type[object]) -> TypeIs[type[SupportsEagleBase]]: ...
+
+
+@overload
+def supports_any_eagle(model: object) -> TypeIs[SupportsEagleBase]: ...
+
+
+def supports_any_eagle(
+    model: type[object] | object,
+) -> TypeIs[type[SupportsEagleBase]] | TypeIs[SupportsEagleBase]:
+    """Check if model supports any EAGLE variant (1, 2, or 3)."""
+    return supports_eagle(model) or supports_eagle3(model)
+
+
+@runtime_checkable
+class SupportsEagle(SupportsEagleBase, Protocol):
     """The interface required for models that support
-    EAGLE3 speculative decoding."""
+    EAGLE-1 and EAGLE-2 speculative decoding."""
+
+    supports_eagle: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports EAGLE-1 and EAGLE-2
+    speculative decoding.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+
+@overload
+def supports_eagle(model: type[object]) -> TypeIs[type[SupportsEagle]]: ...
+
+
+@overload
+def supports_eagle(model: object) -> TypeIs[SupportsEagle]: ...
+
+
+def supports_eagle(
+    model: type[object] | object,
+) -> TypeIs[type[SupportsEagle]] | TypeIs[SupportsEagle]:
+    return isinstance(model, SupportsEagle)
+
+
+@runtime_checkable
+class SupportsEagle3(SupportsEagleBase, Protocol):
+    """The interface required for models that support
+    EAGLE-3 speculative decoding."""

     supports_eagle3: ClassVar[Literal[True]] = True
     """
-    A flag that indicates this model supports EAGLE3
+    A flag that indicates this model supports EAGLE-3
     speculative decoding.

     Note:
@@ -949,7 +1009,7 @@ class SupportsEagle3(Protocol):
     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         """
         Set which layers should output auxiliary
-        hidden states for EAGLE3.
+        hidden states for EAGLE-3.

         Args:
             layers: Tuple of layer indices that should output auxiliary
@@ -960,7 +1020,7 @@ class SupportsEagle3(Protocol):
     def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
         """
         Get the layer indices that should output auxiliary hidden states
-        for EAGLE3.
+        for EAGLE-3.

         Returns:
             Tuple of layer indices for auxiliary hidden state outputs.
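
As a usage sketch (assuming the new names are importable from vllm.model_executor.models.interfaces, as the diff suggests), a model opts into EAGLE-1/2 by listing SupportsEagle in its bases, and call sites can probe for it with the new isinstance-based helpers; MyDraftModel below is a hypothetical class.

import torch.nn as nn

from vllm.model_executor.models.interfaces import (
    SupportsEagle,
    supports_any_eagle,
    supports_eagle,
)


class MyDraftModel(nn.Module, SupportsEagle):
    """Hypothetical draft model that opts into EAGLE-1/2 support."""


model = MyDraftModel()
assert supports_eagle(model)            # runtime_checkable Protocol isinstance check
assert supports_any_eagle(model)        # true for EAGLE-1/2 or EAGLE-3 models
assert model.has_own_lm_head is False   # default; flipped by process_eagle_weight()
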
@@ -58,7 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 )
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -529,7 +529,9 @@ class LlamaModel(nn.Module):
         return loaded_params


-class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class LlamaForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -35,7 +35,7 @@ from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausa
 from vllm.model_executor.models.utils import extract_layer_index

 from .interfaces import SupportsMultiModal
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight

 logger = init_logger(__name__)

@@ -212,6 +212,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
             name, weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
             if "lm_head" not in name:
                 name = "model." + name
+            process_eagle_weight(self, name)
             return name, weight

         loader = AutoWeightsLoader(
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM

-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight

 logger = init_logger(__name__)

@@ -179,6 +179,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
             name, loaded_weight = inputs
             if "lm_head" not in name:
                 name = "model." + name
+            process_eagle_weight(self, name)
             return name, loaded_weight

         loader = AutoWeightsLoader(
@@ -23,7 +23,7 @@ from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors

-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight

 logger = init_logger(__name__)

@@ -324,6 +324,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             if "embed_tokens" in name:
                 includes_embed_tokens = True
             model_weights[name] = loaded_weight
+            process_eagle_weight(self, name)

         skip_substrs = []
         if not includes_draft_id_mapping:
@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle, SupportsLoRA, SupportsPP
 from .minicpm import MiniCPMAttention as EagleMiniCPMAttention
 from .minicpm import MiniCPMMLP as EagleMiniCPMMLP
 from .minicpm import MiniCPMMoE as EagleMiniCPMMoE
@@ -52,6 +52,7 @@ from .utils import (
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
     maybe_prefix,
+    process_eagle_weight,
 )


@@ -289,7 +290,7 @@ class EagleMiniCPMModel(nn.Module):
         return loaded_params


-class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -376,8 +377,13 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         return logits

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        def transform(inputs):
+            name, loaded_weight = inputs
+            process_eagle_weight(self, name)
+            return name, loaded_weight
+
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
         )
-        return loader.load_weights(weights)
+        return loader.load_weights(map(transform, weights))
@@ -19,6 +19,7 @@ from vllm.distributed import (
 )
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import supports_any_eagle
 from vllm.multimodal import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils.math_utils import cdiv
@@ -825,3 +826,25 @@ direct_register_custom_op(
     fake_impl=sequence_parallel_chunk_impl_fake,
     tags=(torch.Tag.needs_fixed_stride_order,),
 )
+
+
+def process_eagle_weight(
+    model: nn.Module,
+    name: str,
+) -> None:
+    """
+    Update EAGLE model flags based on loaded weight name.
+    This should be called during weight loading to detect if a model
+    has its own lm_head or embed_tokens weight.
+    Args:
+        model: The model instance (must support EAGLE)
+        name: The name of the weight to process
+    """
+    if not supports_any_eagle(model):
+        return
+
+    # To prevent overriding with target model's layers
+    if "lm_head" in name:
+        model.has_own_lm_head = True
+    if "embed_tokens" in name:
+        model.has_own_embed_tokens = True
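
For illustration, here is how the new helper behaves as a checkpoint's weight names stream through load_weights; this is a sketch with a made-up DummyEagle class, and the import paths follow the diff above.

import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsEagle
from vllm.model_executor.models.utils import process_eagle_weight


class DummyEagle(nn.Module, SupportsEagle):
    """Minimal stand-in for an EAGLE draft model."""


model = DummyEagle()
for name in ("model.layers.0.self_attn.qkv_proj.weight", "lm_head.weight"):
    process_eagle_weight(model, name)

assert model.has_own_lm_head is True         # an lm_head weight was in the checkpoint
assert model.has_own_embed_tokens is False   # no embed_tokens weight seen
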
@@ -991,6 +991,7 @@ class EagleProposer:
             target_language_model = target_model.get_language_model()
         else:
             target_language_model = target_model
+
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
             if hasattr(target_language_model.model, "embed_tokens"):
@@ -1002,52 +1003,92 @@ class EagleProposer:
                     "Target model does not have 'embed_tokens' or 'embedding' attribute"
                 )

-            # Check if shapes match and we found the embedding
-            eagle_shape = self.model.model.embed_tokens.weight.shape
-            target_shape = target_embed_tokens.weight.shape
-            if eagle_shape == target_shape:
-                logger.info(
-                    "Assuming the EAGLE head shares the same vocab embedding"
-                    " with the target model."
-                )
-                del self.model.model.embed_tokens
-                self.model.model.embed_tokens = target_embed_tokens
+            share_embeddings = False
+            if hasattr(self.model, "has_own_embed_tokens"):
+                # EAGLE model
+                if not self.model.has_own_embed_tokens:
+                    share_embeddings = True
+                    logger.info(
+                        "Detected EAGLE model without its own embed_tokens in the"
+                        " checkpoint. Sharing target model embedding weights with the"
+                        " draft model."
+                    )
+                elif (
+                    isinstance(target_embed_tokens.weight, torch.Tensor)
+                    and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
+                    and torch.equal(
+                        target_embed_tokens.weight, self.model.model.embed_tokens.weight
+                    )
+                ):
+                    share_embeddings = True
+                    logger.info(
+                        "Detected EAGLE model with embed_tokens identical to the target"
+                        " model. Sharing target model embedding weights with the draft"
+                        " model."
+                    )
+                else:
+                    logger.info(
+                        "Detected EAGLE model with distinct embed_tokens weights. "
+                        "Keeping separate embedding weights from the target model."
+                    )
             else:
+                # MTP model
+                share_embeddings = True
                 logger.info(
-                    "The EAGLE head's vocab embedding will be loaded separately"
-                    " from the target model."
+                    "Detected MTP model. "
+                    "Sharing target model embedding weights with the draft model."
                 )
+
+            if share_embeddings:
+                if hasattr(self.model.model, "embed_tokens"):
+                    del self.model.model.embed_tokens
+                self.model.model.embed_tokens = target_embed_tokens
         else:
             logger.info(
-                "The EAGLE head's vocab embedding will be loaded separately"
+                "The draft model's vocab embedding will be loaded separately"
                 " from the target model."
            )

         # share lm_head with the target model if needed
-        # some model definition do not define lm_head explicitly
-        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.vllm_config.speculative_config.method != "eagle3":
-            if hasattr(target_language_model, "lm_head"):
-                logger.info("Loading EAGLE LM head weights from the target model.")
-                self.model.lm_head = target_language_model.lm_head
-        else:
-            if (
-                hasattr(self.model, "lm_head")
-                and hasattr(target_language_model, "lm_head")
-                and self.model.lm_head.weight.shape
-                == target_language_model.lm_head.weight.shape
-            ):
-                logger.info(
-                    "Assuming the EAGLE head shares the same lm_head"
-                    " with the target model."
-                )
-                del self.model.lm_head
-                self.model.lm_head = target_language_model.lm_head
-            else:
-                logger.info(
-                    "The EAGLE head's lm_head will be loaded separately"
-                    " from the target model."
-                )
+        share_lm_head = False
+        if hasattr(self.model, "has_own_lm_head"):
+            # EAGLE model
+            if not self.model.has_own_lm_head:
+                share_lm_head = True
+                logger.info(
+                    "Detected EAGLE model without its own lm_head in the checkpoint. "
+                    "Sharing target model lm_head weights with the draft model."
+                )
+            elif (
+                hasattr(target_language_model, "lm_head")
+                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
+                and isinstance(self.model.lm_head.weight, torch.Tensor)
+                and torch.equal(
+                    target_language_model.lm_head.weight, self.model.lm_head.weight
+                )
+            ):
+                share_lm_head = True
+                logger.info(
+                    "Detected EAGLE model with lm_head identical to the target model. "
+                    "Sharing target model lm_head weights with the draft model."
+                )
+            else:
+                logger.info(
+                    "Detected EAGLE model with distinct lm_head weights. "
+                    "Keeping separate lm_head weights from the target model."
+                )
+        else:
+            # MTP model
+            share_lm_head = True
+            logger.info(
+                "Detected MTP model. "
+                "Sharing target model lm_head weights with the draft model."
+            )
+
+        if share_lm_head and hasattr(target_language_model, "lm_head"):
+            if hasattr(self.model, "lm_head"):
+                del self.model.lm_head
+            self.model.lm_head = target_language_model.lm_head

     @torch.inference_mode()
     def dummy_run(
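
One detail worth calling out: even when the draft checkpoint ships its own lm_head or embed_tokens, the proposer above still falls back to sharing if the tensors are bit-identical to the target's, via torch.equal. A tiny, self-contained illustration of that check:

import torch

target_w = torch.randn(16, 8)

identical = target_w.clone()      # same values, different storage
retrained = target_w + 0.01       # genuinely distinct draft weights

assert torch.equal(target_w, identical)      # -> share the target's module anyway
assert not torch.equal(target_w, retrained)  # -> keep the draft's own copy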