[Bugfix] Fix EAGLE vocab embedding construction for Llama 70B (#19033)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Benjamin Chislett 2025-06-05 22:10:08 -04:00 committed by GitHub
parent c8134bea15
commit 3465b87ef8
7 changed files with 70 additions and 47 deletions


@@ -5,11 +5,11 @@ import copy
import itertools
import torch
import triton
from weight_shapes import WEIGHT_SHAPES
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
from vllm.triton_utils import triton
@triton.testing.perf_report(
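
A note on the import change above: pulling `triton` from `vllm.triton_utils` routes the benchmark through vLLM's import indirection instead of importing Triton directly. A hedged sketch of how a script might guard on that, assuming `vllm.triton_utils` also exposes a `HAS_TRITON` flag (present in recent vLLM, but verify for your version):

```python
# Assumption: vllm.triton_utils exposes HAS_TRITON next to the triton re-export.
import torch
from vllm.triton_utils import HAS_TRITON, triton

if HAS_TRITON and torch.cuda.is_available():
    # Real Triton is importable, so triton.testing helpers are usable.
    ms = triton.testing.do_bench(lambda: torch.ones(1 << 20, device="cuda").sum())
    print(f"elapsed: {ms:.3f} ms")
else:
    print("Triton (or a GPU) unavailable; skipping the Triton-based benchmark.")
```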


@@ -9,6 +9,7 @@ import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.v1.spec_decode.eagle import EagleProposer
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -113,21 +114,26 @@ def test_prepare_inputs():
assert torch.equal(token_indices, expected_token_indices)
@pytest.mark.parametrize(
"method,proposer_helper,draft_model_dir,target_attribute_path", [
("eagle", lambda k: _create_proposer("eagle", k), eagle_dir,
('lm_head', )),
("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
('model', 'embed_tokens')),
])
@pytest.mark.parametrize("method,proposer_helper", [
("eagle", lambda k: _create_proposer("eagle", k)),
("eagle3", lambda k: _create_proposer("eagle3", k)),
])
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
proposer_helper, draft_model_dir, target_attribute_path):
# Setup model mock
proposer_helper, pp_size, use_distinct_embed_tokens):
# Setup draft model mock
mock_model = mock.MagicMock()
if use_distinct_embed_tokens:
# Some models can have a different hidden size than the target model,
# so we test that their embed_tokens doesn't get overwritten
mock_model.model.embed_tokens.weight.shape = (131072, 2048)
else:
mock_model.model.embed_tokens.weight.shape = (131072, 4096)
mock_get_model.return_value = mock_model
# Setup mocks for attention layers
@@ -145,22 +151,24 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
# Setup mock for pp group to return the appropriate value for world size
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 2 if method == "eagle" else 1
mock_pp_group.world_size = pp_size
mock_get_pp_group.return_value = mock_pp_group
# Setup target model with the appropriate attributes
target_model = mock.MagicMock()
# Setup the target model mock with a custom class so that
# isinstance() checks match the expected type.
class _TargetModelStub(LlamaForCausalLM):
model: mock.MagicMock
lm_head: mock.MagicMock
# Create the necessary attributes on the target model
current_obj = target_model
for i, attr in enumerate(target_attribute_path):
if i == len(target_attribute_path) - 1:
# Set the last attribute in the path to a MagicMock
setattr(current_obj, attr, mock.MagicMock())
else:
# Create intermediate objects if needed
setattr(current_obj, attr, mock.MagicMock())
current_obj = getattr(current_obj, attr)
target_model = mock.create_autospec(_TargetModelStub, instance=True)
target_model.model = mock.MagicMock()
target_model.model.embed_tokens.weight.shape = (131072, 4096)
from vllm.model_executor.models import SupportsMultiModal
assert not isinstance(target_model, SupportsMultiModal)
if method == "eagle":
target_model.lm_head = mock.MagicMock()
# Create proposer using the helper function
proposer = proposer_helper(k=8)
@@ -171,10 +179,18 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
# Verify common interactions
mock_get_model.assert_called_once()
# Verify the specific attribute sharing based on the method
# Verify that EAGLE models gain the lm head from the target model
if method == "eagle":
assert proposer.model.lm_head == target_model.lm_head
# Verify that the embed tokens are set correctly
# If pp_size is > 1, the embed tokens should be distinct
if pp_size > 1 or use_distinct_embed_tokens:
assert proposer.model.model.embed_tokens != \
target_model.model.embed_tokens
else:
# When pp_size is 1 and the draft and target models have
# embed_tokens of the same shape, they should be shared.
assert proposer.model.model.embed_tokens == \
target_model.model.embed_tokens
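
A brief note on the mocking technique introduced above: autospeccing a subclass of the real model class lets `isinstance()` checks in `load_model` pass while every attribute stays a mock. A minimal, self-contained sketch of the same pattern (`Base`, `_Stub`, and `weight` are illustrative names, not from this diff):

```python
from unittest import mock

class Base:
    """Stands in for a real model class that production code checks with isinstance()."""

class _Stub(Base):
    pass

stub = mock.create_autospec(_Stub, instance=True)
assert isinstance(stub, Base)  # the spec'd mock reports the real class hierarchy

# Attributes can still be attached freely and behave like ordinary mocks.
stub.weight = mock.MagicMock()
stub.weight.shape = (131072, 4096)
assert stub.weight.shape == (131072, 4096)
```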


@@ -55,13 +55,11 @@ class LlamaModel(nn.Module):
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
# if PP disabled then draft will share embed with target
if get_pp_group().world_size > 1:
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(
@@ -164,4 +162,4 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
return loader.load_weights(model_weights.items())
loader.load_weights(model_weights.items())


@@ -10,7 +10,6 @@ from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -95,13 +94,11 @@ class LlamaModel(nn.Module):
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
# if PP disabled then draft will share embed with target
if get_pp_group().world_size > 1:
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(
@@ -240,6 +237,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
model_weights = {}
includes_draft_id_mapping = False
includes_embed_tokens = False
for name, loaded_weight in weights:
if "t2d" in name:
continue
@@ -248,12 +246,18 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
includes_draft_id_mapping = True
elif "lm_head" not in name:
name = "model." + name
if "embed_tokens" in name:
includes_embed_tokens = True
model_weights[name] = loaded_weight
skip_substrs = []
if not includes_draft_id_mapping:
skip_substrs.append("draft_id_to_target_id")
if not includes_embed_tokens:
skip_substrs.append("embed_tokens")
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
skip_substrs=["draft_id_to_target_id"] \
if not includes_draft_id_mapping else None,
skip_substrs=skip_substrs,
)
loader.load_weights(model_weights.items())
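
The bookkeeping in this hunk follows a general pattern: record which optional tensors the checkpoint actually provides, then tell the loader to skip only the ones that are genuinely absent. A stripped-down sketch of that logic with plain dicts (no vLLM loader types involved):

```python
from typing import Iterable

import torch

def collect_draft_weights(weights: Iterable[tuple[str, torch.Tensor]]):
    """Rename weights for the draft model and decide what the loader may skip."""
    model_weights: dict[str, torch.Tensor] = {}
    includes_embed_tokens = False
    for name, loaded_weight in weights:
        if "lm_head" not in name:
            name = "model." + name
        if "embed_tokens" in name:
            includes_embed_tokens = True
        model_weights[name] = loaded_weight

    # Skip embed_tokens only when the checkpoint really does not ship them;
    # otherwise they must be loaded into the always-constructed embedding.
    skip_substrs = [] if includes_embed_tokens else ["embed_tokens"]
    return model_weights, skip_substrs
```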


@@ -173,6 +173,7 @@ class CudaPlatformBase(Platform):
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None
) -> float:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.max_memory_allocated(device)
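
For context on the added `empty_cache()` call: PyTorch's caching allocator keeps freed blocks reserved, so clearing the cache first makes the subsequent peak reading reflect live tensors rather than cached ones. A standalone illustration of the same measurement sequence:

```python
import torch

def current_cuda_memory_bytes(device=None) -> float:
    """Peak allocated bytes, measured after releasing cached free blocks."""
    torch.cuda.empty_cache()                    # drop cached, unused blocks
    torch.cuda.reset_peak_memory_stats(device)  # restart the peak counter
    return torch.cuda.max_memory_allocated(device)

if torch.cuda.is_available():
    x = torch.empty(1024, 1024, device="cuda")  # ~4 MiB kept alive
    print(f"{current_cuda_memory_bytes():,.0f} bytes currently allocated")
```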


@@ -899,6 +899,7 @@ class DeviceMemoryProfiler:
def current_memory_usage(self) -> float:
# Return the memory usage in bytes.
from vllm.platforms import current_platform
gc.collect()
return current_platform.get_current_memory_usage(self.device)
def __enter__(self):
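
The `gc.collect()` added here matters because Python can still hold unreachable references to device tensors at the moment the sample is taken; collecting first keeps those out of the reading. A simplified stand-in for the profiler pattern (not vLLM's actual `DeviceMemoryProfiler`, which delegates to the platform layer shown above):

```python
import gc

import torch

class SimpleCudaMemoryProfiler:
    """Reports how many CUDA bytes a `with` block leaves allocated."""

    def _usage(self) -> int:
        gc.collect()              # free dead Python references to tensors first
        torch.cuda.empty_cache()  # then release cached blocks before sampling
        return torch.cuda.memory_allocated()

    def __enter__(self):
        self.initial_memory = self._usage()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.consumed_memory = self._usage() - self.initial_memory
        return False  # never swallow exceptions

# Usage (hypothetical workload):
# with SimpleCudaMemoryProfiler() as prof:
#     weights = torch.zeros(1024, 1024, device="cuda")
# print(prof.consumed_memory, "bytes consumed by the block")
```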


@@ -330,16 +330,19 @@ class EagleProposer:
self.attn_layer_names = list(draft_attn_layer_names)
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
if get_pp_group().world_size == 1 \
and self.model.model.embed_tokens.weight.shape \
== target_model.model.embed_tokens.weight.shape:
logger.info(
"The EAGLE head shares the same vocab embedding" \
"Assuming the EAGLE head shares the same vocab embedding" \
" with the target model."
)
del self.model.model.embed_tokens
self.model.model.embed_tokens = target_model.model.embed_tokens
else:
logger.info(
"Since PP > 1, the EAGLE head loaded its own vocab embedding" \
" weights instead of sharing them with the target model."
"The EAGLE head's vocab embedding will be loaded separately" \
" from the target model."
)
# share lm_head with the target model if needed
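
The strengthened condition above is the core of the fix: the draft's `embed_tokens` is aliased to the target's only when pipeline parallelism is off and the two weight tensors have identical shapes; otherwise the draft keeps the embedding it loaded from its own checkpoint. A toy illustration of the same decision with plain `nn.Embedding` layers (the sizes are made up):

```python
import torch.nn as nn

def maybe_share_embedding(draft: nn.Module, target: nn.Module,
                          pp_world_size: int) -> bool:
    """Alias the draft's embed_tokens to the target's only when it is safe."""
    shapes_match = (draft.embed_tokens.weight.shape
                    == target.embed_tokens.weight.shape)
    if pp_world_size == 1 and shapes_match:
        draft.embed_tokens = target.embed_tokens  # share a single weight tensor
        return True
    return False  # the draft keeps its separately loaded embedding

target = nn.Module()
target.embed_tokens = nn.Embedding(1000, 64)

narrow_draft = nn.Module()               # e.g. a head with a smaller hidden size
narrow_draft.embed_tokens = nn.Embedding(1000, 16)

matching_draft = nn.Module()
matching_draft.embed_tokens = nn.Embedding(1000, 64)

assert not maybe_share_embedding(narrow_draft, target, pp_world_size=1)
assert maybe_share_embedding(matching_draft, target, pp_world_size=1)
assert not maybe_share_embedding(matching_draft, target, pp_world_size=2)
```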