mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-25 23:18:46 +08:00
[V1][Spec Decode] Share input embedding of target model with EAGLE draft model to free ~1GB for llama 3 model (#17326)
Co-authored-by: root <root@ekagra-8xh100.us-east5-a.c.serving-efficiency-poc.internal> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
964472b966
commit
418d2f8bfb
@ -105,6 +105,13 @@ def main():
|
||||
outputs = llm.generate(prompt_token_ids=prompt_ids,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
# print the generated text
|
||||
for output in outputs:
|
||||
print("-" * 50)
|
||||
print(f"prompt: {output.prompt}")
|
||||
print(f"generated text: {output.outputs[0].text}")
|
||||
print("-" * 50)
|
||||
|
||||
if not hasattr(outputs, "metrics") or outputs.metrics is None:
|
||||
return
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from transformers import LlamaConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -52,11 +53,15 @@ class LlamaModel(nn.Module):
|
||||
self.config = vllm_config. \
|
||||
speculative_config.draft_model_config.hf_config
|
||||
self.vocab_size = self.config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||
)
|
||||
|
||||
# if PP disabled then draft will share embed with target
|
||||
if get_pp_group().world_size > 1:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(
|
||||
self.config,
|
||||
@ -109,6 +114,12 @@ class LlamaModel(nn.Module):
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
|
||||
# if PP disabled then draft will share embed with target
|
||||
if get_pp_group().world_size == 1 and \
|
||||
"embed_tokens." in name:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
@ -142,8 +153,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
skip_prefixes=None,
|
||||
)
|
||||
|
||||
model_weights = {}
|
||||
@ -151,5 +161,4 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
|
||||
if "lm_head" not in name:
|
||||
name = "model." + name
|
||||
model_weights[name] = loaded_weight
|
||||
|
||||
loader.load_weights(model_weights.items())
|
||||
return loader.load_weights(model_weights.items())
|
||||
|
||||
@ -8,6 +8,7 @@ from transformers import LlamaConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||
@ -91,11 +92,15 @@ class LlamaModel(nn.Module):
|
||||
self.config = vllm_config. \
|
||||
speculative_config.draft_model_config.hf_config
|
||||
self.vocab_size = self.config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||
)
|
||||
|
||||
# if PP disabled then draft will share embed with target
|
||||
if get_pp_group().world_size > 1:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(
|
||||
self.config,
|
||||
|
||||
@ -5,6 +5,7 @@ import torch.nn as nn
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import (CompilationLevel, VllmConfig,
|
||||
get_layers_from_vllm_config, set_current_vllm_config)
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
@ -306,12 +307,30 @@ class EagleProposer:
|
||||
self.attn_layer_name = next(iter(draft_attn_layer_names))
|
||||
loaded_weights = self.model.load_weights(
|
||||
loader.get_all_weights(draft_model_config, self.model))
|
||||
if self.vllm_config.speculative_config.method == "eagle3":
|
||||
if "model.embed_tokens.weight" not in loaded_weights:
|
||||
logger.info(
|
||||
"Loading EAGLE embedding weights from the target model.")
|
||||
self.model.model.embed_tokens = target_model.model.embed_tokens
|
||||
|
||||
# share embed_tokens with the target model if needed
|
||||
if get_pp_group().world_size == 1:
|
||||
assert "model.embed_tokens.weight" not in loaded_weights, \
|
||||
"For PP = 1, Eagle draft should share embed with target model"
|
||||
logger.info(
|
||||
"The EAGLE head shares the same vocab embedding" \
|
||||
" with the target model."
|
||||
)
|
||||
self.model.model.embed_tokens = target_model.model.embed_tokens
|
||||
else:
|
||||
assert "model.embed_tokens.weight" in loaded_weights, \
|
||||
"For PP > 1, Eagle draft checkpoint should its own copy of "
|
||||
" the model.embed_tokens.weight"
|
||||
logger.info(
|
||||
"Since PP > 1, the EAGLE head loaded its own vocab embedding" \
|
||||
" weights instead of sharing them with the target model."
|
||||
)
|
||||
|
||||
# share lm_head with the target model if needed
|
||||
# some model definition do not define lm_head explicitly
|
||||
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
|
||||
if self.vllm_config.speculative_config.method != "eagle3" and \
|
||||
hasattr(target_model, "lm_head"):
|
||||
logger.info("Loading EAGLE LM head weights from the target model.")
|
||||
self.model.lm_head = target_model.lm_head
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user