mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:54:56 +08:00
[V1][Spec Decode] EAGLE-3 Support (#16937)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Co-authored-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
parent
70116459c3
commit
a0e619e62a
@ -52,8 +52,8 @@ def main():
|
|||||||
|
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
|
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
|
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||||
|
|
||||||
max_model_len = 2048
|
max_model_len = 2048
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ def main():
|
|||||||
max_num_seqs=args.max_num_seqs,
|
max_num_seqs=args.max_num_seqs,
|
||||||
gpu_memory_utilization=0.8,
|
gpu_memory_utilization=0.8,
|
||||||
speculative_config={
|
speculative_config={
|
||||||
"method": "eagle",
|
"method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
|
||||||
"model": eagle_dir,
|
"model": eagle_dir,
|
||||||
"num_speculative_tokens": args.num_spec_tokens,
|
"num_speculative_tokens": args.num_spec_tokens,
|
||||||
"draft_tensor_parallel_size": args.draft_tp,
|
"draft_tensor_parallel_size": args.draft_tp,
|
||||||
@ -95,6 +95,9 @@ def main():
|
|||||||
outputs = llm.generate(prompt_token_ids=prompt_ids,
|
outputs = llm.generate(prompt_token_ids=prompt_ids,
|
||||||
sampling_params=sampling_params)
|
sampling_params=sampling_params)
|
||||||
|
|
||||||
|
if not hasattr(outputs, "metrics") or outputs.metrics is None:
|
||||||
|
return
|
||||||
|
|
||||||
# calculate the average number of accepted tokens per forward pass, +1 is
|
# calculate the average number of accepted tokens per forward pass, +1 is
|
||||||
# to account for the token from the target model that's always going to be
|
# to account for the token from the target model that's always going to be
|
||||||
# accepted
|
# accepted
|
||||||
@ -109,6 +112,11 @@ def main():
|
|||||||
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
|
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
|
||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
|
# print acceptance at each token position
|
||||||
|
for i in range(len(acceptance_counts)):
|
||||||
|
print(f"acceptance at token {i}:"
|
||||||
|
f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@ -393,6 +393,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||||
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
|
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
|
||||||
|
"Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501
|
||||||
|
trust_remote_code=True,
|
||||||
|
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||||
|
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
|
||||||
}
|
}
|
||||||
|
|
||||||
_TRANSFORMERS_MODELS = {
|
_TRANSFORMERS_MODELS = {
|
||||||
|
|||||||
@ -50,12 +50,15 @@ def sampling_config():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model_name():
|
def model_name():
|
||||||
return "meta-llama/Meta-Llama-3-8B-Instruct"
|
return "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def eagle_model_name():
|
def eagle_model_name():
|
||||||
return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
|
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||||
|
|
||||||
|
|
||||||
|
def eagle3_model_name():
|
||||||
|
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||||
|
|
||||||
|
|
||||||
def test_ngram_correctness(
|
def test_ngram_correctness(
|
||||||
@ -102,12 +105,13 @@ def test_ngram_correctness(
|
|||||||
del spec_llm
|
del spec_llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
|
||||||
def test_eagle_correctness(
|
def test_eagle_correctness(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
test_prompts: list[list[dict[str, Any]]],
|
test_prompts: list[list[dict[str, Any]]],
|
||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
eagle_model_name: str,
|
use_eagle3: bool,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
Compare the outputs of a original LLM and a speculative LLM
|
Compare the outputs of a original LLM and a speculative LLM
|
||||||
@ -116,18 +120,22 @@ def test_eagle_correctness(
|
|||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
ref_llm = LLM(model=model_name, max_model_len=2048)
|
||||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||||
del ref_llm
|
del ref_llm
|
||||||
|
|
||||||
|
spec_model_name = eagle3_model_name(
|
||||||
|
) if use_eagle3 else eagle_model_name()
|
||||||
spec_llm = LLM(
|
spec_llm = LLM(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
|
trust_remote_code=True,
|
||||||
speculative_config={
|
speculative_config={
|
||||||
"method": "eagle",
|
"method": "eagle3" if use_eagle3 else "eagle",
|
||||||
"model": eagle_model_name,
|
"model": spec_model_name,
|
||||||
"num_speculative_tokens": 3,
|
"num_speculative_tokens": 3,
|
||||||
|
"max_model_len": 2048,
|
||||||
},
|
},
|
||||||
max_model_len=1024,
|
max_model_len=2048,
|
||||||
)
|
)
|
||||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||||
matches = 0
|
matches = 0
|
||||||
|
|||||||
@ -2339,9 +2339,10 @@ class SpeculativeConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Automatically detect the method
|
# Automatically detect the method
|
||||||
if self.method == 'eagle':
|
if self.method in ('eagle', 'eagle3'):
|
||||||
pass
|
pass
|
||||||
elif "eagle-" in self.draft_model_config.model.lower():
|
elif "eagle-" in self.draft_model_config.model.lower() or \
|
||||||
|
"eagle3-" in self.draft_model_config.model.lower():
|
||||||
self.method = "eagle"
|
self.method = "eagle"
|
||||||
elif self.draft_model_config.hf_config.model_type == "medusa":
|
elif self.draft_model_config.hf_config.model_type == "medusa":
|
||||||
self.method = "medusa"
|
self.method = "medusa"
|
||||||
@ -2352,7 +2353,7 @@ class SpeculativeConfig:
|
|||||||
self.method = "draft_model"
|
self.method = "draft_model"
|
||||||
|
|
||||||
# Replace hf_config for EAGLE draft_model
|
# Replace hf_config for EAGLE draft_model
|
||||||
if self.method == "eagle":
|
if self.method in ("eagle", "eagle3"):
|
||||||
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
|
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Chunked prefill and EAGLE are not compatible "
|
"Chunked prefill and EAGLE are not compatible "
|
||||||
@ -2549,6 +2550,12 @@ class SpeculativeConfig:
|
|||||||
"speculative decoding is > 1, but got "
|
"speculative decoding is > 1, but got "
|
||||||
f"{self.disable_by_batch_size=}")
|
f"{self.disable_by_batch_size=}")
|
||||||
|
|
||||||
|
if self.method == "eagle3" and self.target_model_config and \
|
||||||
|
"llama" not in self.target_model_config.hf_text_config.model_type:
|
||||||
|
raise ValueError(
|
||||||
|
"Eagle3 is only supported for Llama models. "
|
||||||
|
f"Got {self.target_model_config.hf_text_config.model_type=}")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_lookahead_slots(self) -> int:
|
def num_lookahead_slots(self) -> int:
|
||||||
"""The number of additional slots the scheduler should allocate per
|
"""The number of additional slots the scheduler should allocate per
|
||||||
|
|||||||
@ -1459,7 +1459,7 @@ class EngineArgs:
|
|||||||
if speculative_method:
|
if speculative_method:
|
||||||
if speculative_method in ("ngram", "[ngram]"):
|
if speculative_method in ("ngram", "[ngram]"):
|
||||||
is_ngram_enabled = True
|
is_ngram_enabled = True
|
||||||
elif speculative_method == "eagle":
|
elif speculative_method in ("eagle", "eagle3"):
|
||||||
is_eagle_enabled = True
|
is_eagle_enabled = True
|
||||||
else:
|
else:
|
||||||
speculative_model = self.speculative_config.get("model")
|
speculative_model = self.speculative_config.get("model")
|
||||||
|
|||||||
@ -330,6 +330,8 @@ class LlamaModel(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.norm = PPMissingLayer()
|
self.norm = PPMissingLayer()
|
||||||
|
|
||||||
|
self.aux_hidden_state_layers: tuple[int] = tuple()
|
||||||
|
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
make_empty_intermediate_tensors_factory(
|
make_empty_intermediate_tensors_factory(
|
||||||
["hidden_states", "residual"], config.hidden_size))
|
["hidden_states", "residual"], config.hidden_size))
|
||||||
@ -355,7 +357,11 @@ class LlamaModel(nn.Module):
|
|||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
aux_hidden_states = []
|
||||||
|
for idx, layer in enumerate(
|
||||||
|
self.layers[self.start_layer:self.end_layer]):
|
||||||
|
if idx in self.aux_hidden_state_layers:
|
||||||
|
aux_hidden_states.append(hidden_states + residual)
|
||||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||||
|
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
@ -365,6 +371,9 @@ class LlamaModel(nn.Module):
|
|||||||
})
|
})
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
|
if len(aux_hidden_states) > 0:
|
||||||
|
return hidden_states, aux_hidden_states
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
@ -517,6 +526,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.model.make_empty_intermediate_tensors)
|
self.model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
|
def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
|
||||||
|
self.model.aux_hidden_state_layers = layers
|
||||||
|
|
||||||
|
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
|
||||||
|
num_layers = len(self.model.layers)
|
||||||
|
return (2, num_layers // 2, num_layers - 3)
|
||||||
|
|
||||||
def _init_model(self,
|
def _init_model(self,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
|||||||
@ -82,7 +82,8 @@ class LlamaModel(nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
residual,
|
residual,
|
||||||
)
|
)
|
||||||
return hidden_states + residual
|
hidden_states = hidden_states + residual
|
||||||
|
return hidden_states, hidden_states
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
|||||||
232
vllm/model_executor/models/llama_eagle3.py
Normal file
232
vllm/model_executor/models/llama_eagle3.py
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Iterable, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
|
||||||
|
LlamaForCausalLM)
|
||||||
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
|
from .utils import AutoWeightsLoader, maybe_prefix
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlamaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__(config, quant_config=quant_config, prefix=prefix)
|
||||||
|
|
||||||
|
# override qkv
|
||||||
|
self.self_attn.qkv_proj = QKVParallelLinear(
|
||||||
|
2 * self.hidden_size,
|
||||||
|
self.self_attn.head_dim,
|
||||||
|
self.self_attn.total_num_heads,
|
||||||
|
self.self_attn.total_num_kv_heads,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(prefix, "qkv_proj"),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
embeds: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
embeds = self.input_layernorm(embeds)
|
||||||
|
hidden_states = self.hidden_norm(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
|
||||||
|
# Self Attention
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
|
hidden_states, residual)
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
start_layer_id: int = 0,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = model_config.hf_config
|
||||||
|
self.vocab_size = self.config.vocab_size
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
self.config.vocab_size,
|
||||||
|
self.config.hidden_size,
|
||||||
|
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
LlamaDecoderLayer(
|
||||||
|
self.config,
|
||||||
|
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
|
||||||
|
)
|
||||||
|
])
|
||||||
|
if hasattr(self.config, "target_hidden_size"):
|
||||||
|
self.fc = torch.nn.Linear(self.config.target_hidden_size * 3,
|
||||||
|
self.config.hidden_size,
|
||||||
|
bias=False)
|
||||||
|
else:
|
||||||
|
self.fc = torch.nn.Linear(self.config.hidden_size * 3,
|
||||||
|
self.config.hidden_size,
|
||||||
|
bias=False)
|
||||||
|
self.norm = RMSNorm(
|
||||||
|
self.config.hidden_size,
|
||||||
|
eps=self.config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
input_embeds = self.embed_tokens(input_ids)
|
||||||
|
if (hidden_states.shape[-1] != input_embeds.shape[-1]):
|
||||||
|
hidden_states = self.fc(hidden_states)
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
hidden_states, residual = self.layers[0](
|
||||||
|
positions,
|
||||||
|
input_embeds,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
|
||||||
|
return hidden_states, hidden_prenorm
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
|
(".gate_up_proj", ".gate_proj", 0),
|
||||||
|
(".gate_up_proj", ".up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if 'midlayer.' in name:
|
||||||
|
name = name.replace('midlayer.', 'layers.0.')
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||||
|
|
||||||
|
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.config = model_config.hf_config
|
||||||
|
self.model = LlamaModel(model_config=model_config,
|
||||||
|
start_layer_id=start_layer_id,
|
||||||
|
prefix="model")
|
||||||
|
|
||||||
|
logit_scale = getattr(self.config, "logit_scale", 1.0)
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
self.config.draft_vocab_size,
|
||||||
|
self.config.hidden_size,
|
||||||
|
org_num_embeddings=self.config.draft_vocab_size,
|
||||||
|
padding_size=(DEFAULT_VOCAB_PADDING_SIZE),
|
||||||
|
prefix="")
|
||||||
|
self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
|
||||||
|
scale=logit_scale)
|
||||||
|
self.draft_id_to_target_id = nn.Parameter(
|
||||||
|
torch.zeros((self.config.draft_vocab_size),
|
||||||
|
dtype=torch.long).type(torch.LongTensor),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.model(input_ids, positions, hidden_states)
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
base = torch.arange(self.config.draft_vocab_size, device=logits.device)
|
||||||
|
targets = base + self.draft_id_to_target_id
|
||||||
|
logits_new = logits.new_full((
|
||||||
|
logits.shape[0],
|
||||||
|
self.config.vocab_size,
|
||||||
|
), float('-inf'))
|
||||||
|
logits_new[:, targets] = logits
|
||||||
|
return logits_new
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
loader = AutoWeightsLoader(
|
||||||
|
self,
|
||||||
|
skip_prefixes=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_weights = {}
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "t2d" in name:
|
||||||
|
continue
|
||||||
|
if "d2t" in name:
|
||||||
|
name = name.replace("d2t", "draft_id_to_target_id")
|
||||||
|
elif "lm_head" not in name:
|
||||||
|
name = "model." + name
|
||||||
|
model_weights[name] = loaded_weight
|
||||||
|
|
||||||
|
return loader.load_weights(model_weights.items())
|
||||||
@ -214,6 +214,7 @@ _MULTIMODAL_MODELS = {
|
|||||||
_SPECULATIVE_DECODING_MODELS = {
|
_SPECULATIVE_DECODING_MODELS = {
|
||||||
"EAGLEModel": ("eagle", "EAGLE"),
|
"EAGLEModel": ("eagle", "EAGLE"),
|
||||||
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
||||||
|
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||||
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
||||||
"MedusaModel": ("medusa", "Medusa"),
|
"MedusaModel": ("medusa", "Medusa"),
|
||||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||||
|
|||||||
@ -126,7 +126,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
self.num_spec_tokens = self.num_lookahead_tokens = 0
|
self.num_spec_tokens = self.num_lookahead_tokens = 0
|
||||||
if speculative_config:
|
if speculative_config:
|
||||||
self.num_spec_tokens = speculative_config.num_speculative_tokens
|
self.num_spec_tokens = speculative_config.num_speculative_tokens
|
||||||
if speculative_config.method == "eagle":
|
if speculative_config.method in ("eagle", "eagle3"):
|
||||||
self.num_lookahead_tokens = self.num_spec_tokens
|
self.num_lookahead_tokens = self.num_spec_tokens
|
||||||
|
|
||||||
def schedule(self) -> SchedulerOutput:
|
def schedule(self) -> SchedulerOutput:
|
||||||
|
|||||||
@ -6,12 +6,16 @@ import triton.language as tl
|
|||||||
|
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader.loader import get_model_loader
|
from vllm.model_executor.model_loader.loader import get_model_loader
|
||||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
|
||||||
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
PADDING_SLOT_ID = -1
|
PADDING_SLOT_ID = -1
|
||||||
|
|
||||||
|
|
||||||
@ -87,12 +91,12 @@ class EagleProposer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
hidden_states = self.model(
|
hidden_states_logits, hidden_states_fwd = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
hidden_states=target_hidden_states,
|
hidden_states=target_hidden_states,
|
||||||
positions=target_positions,
|
positions=target_positions,
|
||||||
)
|
)
|
||||||
sample_hidden_states = hidden_states[last_token_indices]
|
sample_hidden_states = hidden_states_logits[last_token_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
draft_token_ids = logits.argmax(dim=-1)
|
draft_token_ids = logits.argmax(dim=-1)
|
||||||
|
|
||||||
@ -105,7 +109,7 @@ class EagleProposer:
|
|||||||
draft_token_ids_list = [draft_token_ids]
|
draft_token_ids_list = [draft_token_ids]
|
||||||
|
|
||||||
positions = target_positions[last_token_indices]
|
positions = target_positions[last_token_indices]
|
||||||
hidden_states = sample_hidden_states
|
hidden_states = hidden_states_fwd[last_token_indices]
|
||||||
attn_metadata.num_actual_tokens = batch_size
|
attn_metadata.num_actual_tokens = batch_size
|
||||||
attn_metadata.max_query_len = 1
|
attn_metadata.max_query_len = 1
|
||||||
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||||
@ -151,12 +155,12 @@ class EagleProposer:
|
|||||||
|
|
||||||
# Run the model.
|
# Run the model.
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
hidden_states = self.model(
|
hidden_states_logits, hidden_states = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
positions=clamped_positions,
|
positions=clamped_positions,
|
||||||
)
|
)
|
||||||
logits = self.model.compute_logits(hidden_states, None)
|
logits = self.model.compute_logits(hidden_states_logits, None)
|
||||||
draft_token_ids = logits.argmax(dim=-1)
|
draft_token_ids = logits.argmax(dim=-1)
|
||||||
draft_token_ids_list.append(draft_token_ids)
|
draft_token_ids_list.append(draft_token_ids)
|
||||||
|
|
||||||
@ -221,15 +225,28 @@ class EagleProposer:
|
|||||||
with set_default_torch_dtype(
|
with set_default_torch_dtype(
|
||||||
draft_model_config.dtype), set_current_vllm_config(
|
draft_model_config.dtype), set_current_vllm_config(
|
||||||
self.vllm_config):
|
self.vllm_config):
|
||||||
self.model = EagleLlamaForCausalLM(
|
if self.vllm_config.speculative_config.method == "eagle":
|
||||||
model_config=draft_model_config,
|
self.model = EagleLlamaForCausalLM(
|
||||||
start_layer_id=target_layer_num).to(target_device)
|
model_config=draft_model_config,
|
||||||
|
start_layer_id=target_layer_num).to(target_device)
|
||||||
|
else:
|
||||||
|
assert self.vllm_config.speculative_config.method == "eagle3"
|
||||||
|
self.model = Eagle3LlamaForCausalLM(
|
||||||
|
model_config=draft_model_config,
|
||||||
|
start_layer_id=target_layer_num).to(target_device)
|
||||||
|
|
||||||
self.model.load_weights(
|
loaded_weights = self.model.load_weights(
|
||||||
loader.get_all_weights(
|
loader.get_all_weights(
|
||||||
self.vllm_config.speculative_config.draft_model_config,
|
self.vllm_config.speculative_config.draft_model_config,
|
||||||
self.model))
|
self.model))
|
||||||
self.model.lm_head = target_model.lm_head
|
if self.vllm_config.speculative_config.method == "eagle3":
|
||||||
|
if "model.embed_tokens.weight" not in loaded_weights:
|
||||||
|
logger.info(
|
||||||
|
"Loading EAGLE embedding weights from the target model.")
|
||||||
|
self.model.model.embed_tokens = target_model.model.embed_tokens
|
||||||
|
else:
|
||||||
|
logger.info("Loading EAGLE LM head weights from the target model.")
|
||||||
|
self.model.lm_head = target_model.lm_head
|
||||||
|
|
||||||
|
|
||||||
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
|
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
|
||||||
|
|||||||
@ -165,14 +165,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# Set up speculative decoding.
|
# Set up speculative decoding.
|
||||||
self.use_spec_decode = False
|
self.use_spec_decode = False
|
||||||
|
self.use_aux_hidden_state_outputs = False
|
||||||
if self.speculative_config:
|
if self.speculative_config:
|
||||||
self.use_spec_decode = True
|
self.use_spec_decode = True
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
if self.speculative_config.method == "ngram":
|
if self.speculative_config.method == "ngram":
|
||||||
self.drafter = NgramProposer(self.vllm_config)
|
self.drafter = NgramProposer(self.vllm_config)
|
||||||
elif self.speculative_config.method == "eagle":
|
elif self.speculative_config.method == "eagle" or \
|
||||||
|
self.speculative_config.method == "eagle3":
|
||||||
self.drafter = EagleProposer(self.vllm_config,
|
self.drafter = EagleProposer(self.vllm_config,
|
||||||
self.device) # type: ignore
|
self.device) # type: ignore
|
||||||
|
if self.speculative_config.method == "eagle3":
|
||||||
|
self.use_aux_hidden_state_outputs = True
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown speculative decoding method: "
|
raise ValueError("Unknown speculative decoding method: "
|
||||||
f"{self.speculative_config.method}")
|
f"{self.speculative_config.method}")
|
||||||
@ -1079,12 +1083,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Run the decoder.
|
# Run the decoder.
|
||||||
# Use persistent buffers for CUDA graphs.
|
# Use persistent buffers for CUDA graphs.
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
hidden_states = self.model(
|
output = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.use_aux_hidden_state_outputs:
|
||||||
|
hidden_states, aux_hidden_states = output
|
||||||
|
else:
|
||||||
|
hidden_states = output
|
||||||
|
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
# For mid-pipeline stages, return the hidden states.
|
# For mid-pipeline stages, return the hidden states.
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@ -1182,7 +1192,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
assert isinstance(self.drafter, NgramProposer)
|
assert isinstance(self.drafter, NgramProposer)
|
||||||
spec_token_ids = self.generate_draft_token_ids(
|
spec_token_ids = self.generate_draft_token_ids(
|
||||||
valid_sampled_token_ids, sampling_metadata)
|
valid_sampled_token_ids, sampling_metadata)
|
||||||
elif self.speculative_config.method == "eagle":
|
elif self.speculative_config.method == "eagle" or \
|
||||||
|
self.speculative_config.method == "eagle3":
|
||||||
assert isinstance(self.drafter, EagleProposer)
|
assert isinstance(self.drafter, EagleProposer)
|
||||||
# TODO(woosuk): Refactor the loop.
|
# TODO(woosuk): Refactor the loop.
|
||||||
next_token_ids: list[int] = []
|
next_token_ids: list[int] = []
|
||||||
@ -1210,7 +1221,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# not include padding.
|
# not include padding.
|
||||||
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
||||||
target_positions = positions[:num_scheduled_tokens]
|
target_positions = positions[:num_scheduled_tokens]
|
||||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
if self.use_aux_hidden_state_outputs:
|
||||||
|
target_hidden_states = [
|
||||||
|
h[:num_scheduled_tokens] for h in aux_hidden_states
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||||
target_slot_mapping = attn_metadata.slot_mapping
|
target_slot_mapping = attn_metadata.slot_mapping
|
||||||
cu_num_tokens = attn_metadata.query_start_loc
|
cu_num_tokens = attn_metadata.query_start_loc
|
||||||
else:
|
else:
|
||||||
@ -1231,9 +1247,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
target_token_ids = self.input_ids[token_indices]
|
target_token_ids = self.input_ids[token_indices]
|
||||||
target_positions = positions[token_indices]
|
target_positions = positions[token_indices]
|
||||||
target_hidden_states = hidden_states[token_indices]
|
if self.use_aux_hidden_state_outputs:
|
||||||
|
target_hidden_states = [
|
||||||
|
h[token_indices] for h in aux_hidden_states
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
target_hidden_states = hidden_states[token_indices]
|
||||||
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
|
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
|
||||||
|
|
||||||
|
if self.use_aux_hidden_state_outputs:
|
||||||
|
target_hidden_states = torch.cat(target_hidden_states, dim=-1)
|
||||||
draft_token_ids = self.drafter.propose(
|
draft_token_ids = self.drafter.propose(
|
||||||
target_token_ids=target_token_ids,
|
target_token_ids=target_token_ids,
|
||||||
target_positions=target_positions,
|
target_positions=target_positions,
|
||||||
@ -1311,6 +1334,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if hasattr(self, "drafter"):
|
if hasattr(self, "drafter"):
|
||||||
logger.info("Loading drafter model...")
|
logger.info("Loading drafter model...")
|
||||||
self.drafter.load_model(self.model)
|
self.drafter.load_model(self.model)
|
||||||
|
if self.use_aux_hidden_state_outputs:
|
||||||
|
self.model.set_aux_hidden_state_layers(
|
||||||
|
self.model.get_eagle3_aux_hidden_state_layers())
|
||||||
time_after_load = time.perf_counter()
|
time_after_load = time.perf_counter()
|
||||||
self.model_memory_usage = m.consumed_memory
|
self.model_memory_usage = m.consumed_memory
|
||||||
logger.info("Model loading took %.4f GiB and %.6f seconds",
|
logger.info("Model loading took %.4f GiB and %.6f seconds",
|
||||||
@ -1463,12 +1489,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
with set_forward_context(None,
|
with set_forward_context(None,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
num_tokens=num_tokens):
|
num_tokens=num_tokens):
|
||||||
hidden_states = model(
|
outputs = model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
|
if self.use_aux_hidden_state_outputs:
|
||||||
|
hidden_states, _ = outputs
|
||||||
|
else:
|
||||||
|
hidden_states = outputs
|
||||||
|
|
||||||
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
||||||
return hidden_states[logit_indices]
|
return hidden_states[logit_indices]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user