[V1][Spec Decode] Share input embedding of target model with EAGLE draft model to free ~1GB for llama 3 model (#17326)

Co-authored-by: root <root@ekagra-8xh100.us-east5-a.c.serving-efficiency-poc.internal>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Ekagra Ranjan 2025-05-14 15:31:46 -04:00 committed by GitHub
parent 964472b966
commit 418d2f8bfb
4 changed files with 59 additions and 19 deletions
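Before the file-by-file diff, here is the idea in isolation: instead of the EAGLE draft model allocating its own vocabulary embedding, it reuses the target model's `embed_tokens` module whenever pipeline parallelism is disabled. A minimal sketch of that sharing, with toy dimensions and plain `nn.Embedding` standing in for vLLM's `VocabParallelEmbedding`:

```python
import torch.nn as nn

# Toy stand-ins: small sizes are used here; the real Llama 3 table is 128256 x 4096.
target_embed = nn.Embedding(1000, 64)   # the target model's embed_tokens
draft = nn.Module()                     # stands in for the EAGLE draft model

# Instead of allocating a second embedding table for the draft, reuse the
# target's module. Both attributes now reference the same parameter storage,
# so no additional GPU memory is consumed.
draft.embed_tokens = target_embed
assert draft.embed_tokens.weight.data_ptr() == target_embed.weight.data_ptr()
```

Because both attributes point at the same module, no second vocab-sized weight matrix is materialized, which is where the ~1GB saving for Llama 3 comes from.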

View File

@@ -105,6 +105,13 @@ def main():
    outputs = llm.generate(prompt_token_ids=prompt_ids,
                           sampling_params=sampling_params)
    # print the generated text
    for output in outputs:
        print("-" * 50)
        print(f"prompt: {output.prompt}")
        print(f"generated text: {output.outputs[0].text}")
        print("-" * 50)
    if not hasattr(outputs, "metrics") or outputs.metrics is None:
        return
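For context, the hunk above is from the EAGLE offline-inference example; `prompt_ids` and `sampling_params` are built earlier in the script. A rough sketch of that setup is below; the model name is a placeholder, and the speculative-decoding configuration that the real example passes to `LLM` is omitted:

```python
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

# Placeholder model name; the real example takes this from CLI arguments
# and also configures EAGLE speculative decoding on the LLM.
target = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(target)
prompts = ["What is speculative decoding?"]
prompt_ids = [tokenizer.encode(p) for p in prompts]

sampling_params = SamplingParams(temperature=0.0, max_tokens=128)

llm = LLM(model=target)
outputs = llm.generate(prompt_token_ids=prompt_ids,
                       sampling_params=sampling_params)
```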

View File

@@ -8,6 +8,7 @@ from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -52,11 +53,15 @@ class LlamaModel(nn.Module):
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        # if PP is disabled, the draft model shares embed_tokens with the target
        if get_pp_group().world_size > 1:
            self.embed_tokens = VocabParallelEmbedding(
                self.config.vocab_size,
                self.config.hidden_size,
                prefix=maybe_prefix(prefix, "embed_tokens"),
            )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                self.config,
@@ -109,6 +114,12 @@ class LlamaModel(nn.Module):
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # if PP is disabled, embed_tokens is shared with the target
                # model, so skip loading it from the draft checkpoint
                if get_pp_group().world_size == 1 and \
                        "embed_tokens." in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
@@ -142,8 +153,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
            skip_prefixes=None,
        )
        model_weights = {}
@@ -151,5 +161,4 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
        loader.load_weights(model_weights.items())
        return loader.load_weights(model_weights.items())
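Taken together, the `llama_eagle.py` changes above have two halves: the draft `LlamaModel` only builds its own `VocabParallelEmbedding` when pipeline parallelism is enabled, and `load_weights` skips `embed_tokens.*` checkpoint entries when it is not, since that module will later be replaced by a reference to the target model's embedding. A condensed, self-contained sketch of the pattern, using plain `nn.Embedding` and a `pp_world_size` argument in place of `get_pp_group().world_size`:

```python
from typing import Dict, Iterable, Tuple

import torch
import torch.nn as nn


class DraftModelSketch(nn.Module):
    """Toy draft model: owns an embedding only when PP is enabled."""

    def __init__(self, pp_world_size: int, vocab_size: int = 1000,
                 hidden_size: int = 64) -> None:
        super().__init__()
        self.pp_world_size = pp_world_size
        if pp_world_size > 1:
            # PP enabled: the draft needs its own copy of the embedding.
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # PP disabled: embed_tokens is assigned later from the target model.

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        loaded = set()
        params: Dict[str, nn.Parameter] = dict(self.named_parameters())
        for name, tensor in weights:
            if self.pp_world_size == 1 and "embed_tokens." in name:
                # shared with the target model, so do not load it here
                continue
            params[name].data.copy_(tensor)
            loaded.add(name)
        return loaded
```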

View File

@@ -8,6 +8,7 @@ from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -91,11 +92,15 @@ class LlamaModel(nn.Module):
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        # if PP is disabled, the draft model shares embed_tokens with the target
        if get_pp_group().world_size > 1:
            self.embed_tokens = VocabParallelEmbedding(
                self.config.vocab_size,
                self.config.hidden_size,
                prefix=maybe_prefix(prefix, "embed_tokens"),
            )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                self.config,
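`llama_eagle3.py` gets the same conditional construction as `llama_eagle.py`. Leaving `embed_tokens` out of `__init__` is safe because PyTorch registers submodules on attribute assignment, so pointing the attribute at the target model's module after construction behaves as if it had been created in the constructor. A small sketch of that behaviour (toy sizes, plain `nn.Embedding` again standing in for `VocabParallelEmbedding`):

```python
import torch
import torch.nn as nn


class TinyDraft(nn.Module):
    def __init__(self, create_embedding: bool) -> None:
        super().__init__()
        if create_embedding:
            self.embed_tokens = nn.Embedding(1000, 64)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)


target_embed = nn.Embedding(1000, 64)      # stands in for the target model's table
draft = TinyDraft(create_embedding=False)  # PP disabled: nothing allocated here
draft.embed_tokens = target_embed          # share the target's module instead

# The shared module is registered like any other submodule and is used in forward.
assert "embed_tokens" in dict(draft.named_modules())
out = draft(torch.tensor([[1, 2, 3]]))
assert out.shape == (1, 3, 64)
```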

View File

@@ -5,6 +5,7 @@ import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
                         get_layers_from_vllm_config, set_current_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
@@ -306,12 +307,30 @@ class EagleProposer:
        self.attn_layer_name = next(iter(draft_attn_layer_names))
        loaded_weights = self.model.load_weights(
            loader.get_all_weights(draft_model_config, self.model))
        if self.vllm_config.speculative_config.method == "eagle3":
            if "model.embed_tokens.weight" not in loaded_weights:
                logger.info(
                    "Loading EAGLE embedding weights from the target model.")
                self.model.model.embed_tokens = target_model.model.embed_tokens
        # share embed_tokens with the target model if needed
        if get_pp_group().world_size == 1:
            assert "model.embed_tokens.weight" not in loaded_weights, \
                "For PP = 1, the EAGLE draft should share embed_tokens with " \
                "the target model"
            logger.info(
                "The EAGLE head shares the same vocab embedding"
                " with the target model.")
            self.model.model.embed_tokens = target_model.model.embed_tokens
        else:
            assert "model.embed_tokens.weight" in loaded_weights, \
                "For PP > 1, the EAGLE draft should have its own copy of " \
                "model.embed_tokens.weight"
            logger.info(
                "Since PP > 1, the EAGLE head loaded its own vocab embedding"
                " weights instead of sharing them with the target model.")
        # share lm_head with the target model if needed
        # some model definitions do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if self.vllm_config.speculative_config.method != "eagle3" and \
                hasattr(target_model, "lm_head"):
            logger.info("Loading EAGLE LM head weights from the target model.")
            self.model.lm_head = target_model.lm_head
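The "~1GB" in the title follows directly from the size of the vocabulary embedding that is no longer duplicated. A back-of-the-envelope check, assuming a Llama 3 8B target (vocab size 128256, hidden size 4096) and bf16 weights:

```python
# Rough estimate of the memory freed by sharing the input embedding,
# assuming Llama 3 8B dimensions and bf16 (2 bytes per parameter).
vocab_size = 128256
hidden_size = 4096
bytes_per_param = 2  # bf16

embedding_bytes = vocab_size * hidden_size * bytes_per_param
print(f"{embedding_bytes / 1024**3:.2f} GiB")  # ~0.98 GiB, i.e. roughly 1 GB
```

For the 70B variant (hidden size 8192), the duplicated table would be roughly twice as large.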