[Meta] Official Eagle mm support, first enablement on llama4 (#20788)

Signed-off-by: morgendave <morgendave@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.me>

parent 53c21e492e
commit 9e0726e5bf
examples/offline_inference/spec_decode.py
@@ -13,6 +13,38 @@ except ImportError:
     from argparse import ArgumentParser as FlexibleArgumentParser
 
 
+QUESTION = "What is the content of each image?"
+IMAGE_URLS = [
+    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
+]
+
+
+def get_custom_mm_prompts(num_prompts):
+    prompts = []
+    for url in IMAGE_URLS:
+        prompts.append(
+            [
+                {"type": "image_url", "image_url": {"url": url}},
+                {"type": "text", "text": QUESTION},
+            ]
+        )
+    if num_prompts > len(IMAGE_URLS):
+        prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)
+
+    return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]
+
+
 def parse_args():
     parser = FlexibleArgumentParser()
     add_dataset_parser(parser)
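
For reference, a sketch of what the new helper produces (assuming the QUESTION and IMAGE_URLS definitions above): each prompt is a one-turn chat message pairing one image URL with the fixed question, and the list is cycled when num_prompts exceeds the twelve URLs.

    # Illustrative only: get_custom_mm_prompts(2) evaluates to
    [
        [{"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": IMAGE_URLS[0]}},
            {"type": "text", "text": QUESTION},
        ]}],
        [{"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": IMAGE_URLS[1]}},
            {"type": "text", "text": QUESTION},
        ]}],
    ]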
@@ -35,6 +67,7 @@ def parse_args():
     parser.add_argument("--output-len", type=int, default=256)
     parser.add_argument("--model-dir", type=str, default=None)
     parser.add_argument("--eagle-dir", type=str, default=None)
+    parser.add_argument("--custom-mm-prompts", action="store_true")
     return parser.parse_args()
 
 
@@ -44,14 +77,26 @@ def main():
 
     model_dir = args.model_dir
     if args.model_dir is None:
+        if args.custom_mm_prompts:
+            raise ValueError(
+                "custom_mm_prompts requires mm based models; "
+                "the default llama3.1-8b-instruct is not mm based, "
+                "please specify model_dir to give a mm based model"
+            )
         model_dir = "meta-llama/Llama-3.1-8B-Instruct"
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    args.custom_skip_chat_template = True
 
-    prompts = get_samples(args, tokenizer)
-    # add_special_tokens is False to avoid adding bos twice when using chat templates
-    prompt_ids = [
-        tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
-    ]
+    if not args.custom_mm_prompts:
+        prompts = get_samples(args, tokenizer)
+        # add_special_tokens is False to avoid adding bos twice
+        # when using chat templates
+        prompt_ids = [
+            tokenizer.encode(prompt.prompt, add_special_tokens=False)
+            for prompt in prompts
+        ]
+    else:
+        prompts = get_custom_mm_prompts(args.num_prompts)
 
     if args.method == "eagle" or args.method == "eagle3":
         eagle_dir = args.eagle_dir
@@ -85,10 +130,17 @@ def main():
         speculative_config=speculative_config,
         disable_log_stats=False,
         max_model_len=16384,
+        limit_mm_per_prompt={"image": 5},
+        disable_chunked_mm_input=True,
     )
 
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
-    outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
+    if not args.custom_mm_prompts:
+        outputs = llm.generate(
+            prompt_token_ids=prompt_ids, sampling_params=sampling_params
+        )
+    else:
+        outputs = llm.chat(prompts, sampling_params=sampling_params)
 
     # print the generated text
     if args.print_output:
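
A hypothetical invocation of the updated example script (the flag comes from the hunks above; the model and drafter names are taken from the test parametrization below, and the prompt count is illustrative):

    # python examples/offline_inference/spec_decode.py \
    #     --method eagle \
    #     --model-dir meta-llama/Llama-4-Scout-17B-16E-Instruct \
    #     --eagle-dir morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct \
    #     --custom-mm-prompts --num-prompts 12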
tests/v1/e2e/test_spec_decode.py
@@ -3,29 +3,34 @@
 from __future__ import annotations
 
 import random
-from typing import Any
+from typing import Any, Union
 
 import pytest
 import torch
 
 from vllm import LLM, SamplingParams
+from vllm.assets.base import VLLM_S3_BUCKET_URL
+from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
 
 
-@pytest.fixture
-def test_prompts():
+def get_test_prompts(mm_enabled: bool):
     prompt_types = ["repeat", "sentence"]
+    if mm_enabled:
+        prompt_types.append("mm")
     num_prompts = 100
     prompts = []
 
     random.seed(0)
     random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+    print(f"Prompt types: {random_prompt_type_choices}")
 
     # Generate a mixed batch of prompts, some of which can be easily
     # predicted by n-gram matching and some which likely cannot.
     for kind in random_prompt_type_choices:
         word_choices = ["test", "temp", "hello", "where"]
         word = random.choice(word_choices)
+        prompt: Union[str, list[dict[str, Any]]] = ""
         if kind == "repeat":
             prompt = f"""
             please repeat the word '{word}' 10 times.
@@ -38,6 +43,21 @@ def test_prompts():
             uses the word {word} at least once.
             give no other output than that simple sentence without quotes.
             """
+        elif kind == "mm":
+            placeholders = [{
+                "type": "image_url",
+                "image_url": {
+                    "url":
+                    f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
+                },
+            }]
+            prompt = [
+                *placeholders,
+                {
+                    "type": "text",
+                    "text": "The meaning of the image is"
+                },
+            ]
         else:
             raise ValueError(f"Unknown prompt type: {kind}")
         prompts.append([{"role": "user", "content": prompt}])
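
For orientation, one element produced by get_test_prompts when mm_enabled=True looks like this (structure read off the hunk above; the URL is the S3 test asset):

    [{"role": "user", "content": [
        {"type": "image_url",
         "image_url": {"url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"}},
        {"type": "text", "text": "The meaning of the image is"},
    ]}]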
@@ -57,7 +77,6 @@ def model_name():
 
 def test_ngram_correctness(
     monkeypatch: pytest.MonkeyPatch,
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_name: str,
 ):
@@ -67,6 +86,7 @@ def test_ngram_correctness(
     '''
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        test_prompts = get_test_prompts(mm_enabled=False)
 
         ref_llm = LLM(model=model_name, max_model_len=1024)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
@@ -103,23 +123,32 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
-@pytest.mark.parametrize("model_setup", [
-    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
-    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-],
-                         ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"], [
+        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        pytest.param(
+            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            False,
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+        pytest.param(
+            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            True,
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+    ],
+    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
+    mm_enabled: bool,
 ):
+    # Generate test prompts inside the function instead of using fixture
+    test_prompts = get_test_prompts(mm_enabled)
     '''
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
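
Once the CI OOM skips are lifted, the new multimodal case could be run on its own by its id (a hypothetical command; the test file path is inferred from the hunk context):

    # pytest tests/v1/e2e/test_spec_decode.py -k llama4_eagle_mm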
vllm/model_executor/models/llama4.py
@@ -256,6 +256,7 @@ class Llama4DecoderLayer(nn.Module):
         super().__init__()
 
         self.layer_idx = extract_layer_index(prefix)
+        self.global_layer = config.no_rope_layers[self.layer_idx] == 0
         self.hidden_size = config.hidden_size
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
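
On the new global_layer flag: assuming the Hugging Face Llama 4 config convention, no_rope_layers holds one 0/1 entry per layer, with 0 marking the NoPE layers that use global (non-rotary) attention; a reading sketch:

    # Assumed config convention, not part of the commit:
    # config.no_rope_layers = [1, 1, 1, 0, ...]  # 0 => NoPE / global-attention layer
    # so self.global_layer is True exactly on those layers.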
vllm/model_executor/models/llama4_eagle.py
@@ -37,8 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
                                                Llama4ForCausalLM)
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.multimodal.inputs import NestedTensors
 
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
@@ -78,15 +79,23 @@ class LlamaModel(nn.Module):
         self.norm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        input_embeds = self.embed_tokens(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings(input_ids)
         hidden_states = self.fc(
-            torch.cat((input_embeds, hidden_states), dim=-1))
+            torch.cat((inputs_embeds, hidden_states), dim=-1))
         residual = None
         for layer in self.layers:
             hidden_states, residual = layer(
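
The draft model's fc layer fuses the token embedding with the target model's hidden state at the same position, projecting 2x hidden back down to hidden. A minimal shape sketch with assumed dimensions:

    import torch
    import torch.nn as nn

    hidden = 128                                   # assumed hidden size
    fc = nn.Linear(2 * hidden, hidden, bias=False)
    inputs_embeds = torch.randn(4, hidden)         # token embeddings
    target_hidden = torch.randn(4, hidden)         # target-model hidden states
    fused = fc(torch.cat((inputs_embeds, target_hidden), dim=-1))  # shape (4, hidden)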
@@ -190,8 +199,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.model(input_ids, positions, hidden_states)
+        return self.model(input_ids, positions, hidden_states, inputs_embeds)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> None:
@@ -212,3 +222,20 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
             model_weights[name] = loaded_weight
 
         loader.load_weights(model_weights.items())
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.model.get_input_embeddings(input_ids)
+
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                self.config.image_token_index,
+            )
+
+        return inputs_embeds
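
merge_multimodal_embeddings is reused from the shared model utils; roughly, it replaces the rows of inputs_embeds at image-placeholder positions with the projected multimodal embeddings. A sketch of the semantics (assumes the placeholder count matches the merged rows; not vLLM's actual implementation):

    # mask selects the image placeholder positions
    mask = input_ids == image_token_index
    flat = torch.cat([e.reshape(-1, e.shape[-1]) for e in multimodal_embeddings])
    inputs_embeds[mask] = flat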
vllm/model_executor/models/llama_eagle.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -148,7 +149,12 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not support multimodal inputs yet."
+            )
         return self.model(input_ids, positions, hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
vllm/model_executor/models/llama_eagle3.py
@@ -202,7 +202,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not support multimodal inputs yet."
+            )
         return self.model(input_ids, positions, hidden_states)
 
     def compute_logits(
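
Both the Llama and Llama 3 Eagle heads above accept inputs_embeds only to reject it with NotImplementedError; judging from the EagleProposer changes below, the intent appears to be keeping the drafter's forward signature uniform while multimodal embeddings are wired up only for the Llama 4 Eagle head.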
vllm/v1/spec_decode/eagle.py
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -51,6 +53,9 @@ class EagleProposer:
         # hidden size (e.g., Llama 3.3 70B).
         self.hidden_size = self.draft_model_config.get_hidden_size()
 
+        self.is_multimodal_model = vllm_config.model_config \
+            .is_multimodal_model
+
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
@@ -76,6 +81,11 @@ class EagleProposer:
             device=device,
             dtype=torch.int32)
 
+        self.inputs_embeds = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=device)
+
     def propose(
         self,
         # [num_tokens]
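
The new inputs_embeds buffer follows the same persistent-buffer pattern as input_ids and positions: actual data is written into the front of a preallocated tensor, and a padded slice is handed to the model so CUDA graphs can be replayed. A toy sketch with assumed sizes:

    import torch

    max_num_tokens, hidden_size = 16, 8
    buf = torch.zeros(max_num_tokens, hidden_size)           # allocated once at init
    num_tokens, num_input_tokens = 5, 8                      # actual vs. padded (graph) size
    buf[:num_tokens] = torch.randn(num_tokens, hidden_size)  # copy fresh embeddings in
    model_in = buf[:num_input_tokens]                        # padded slice fed to the model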
@@ -88,6 +98,7 @@ class EagleProposer:
         next_token_ids: torch.Tensor,
         common_attn_metadata: CommonAttentionMetadata,
         sampling_metadata: SamplingMetadata,
+        mm_embeds: Optional[list[torch.Tensor]] = None,
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
@@ -128,14 +139,27 @@ class EagleProposer:
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
+        if self.is_multimodal_model:
+            input_ids = self.input_ids[:num_tokens]
+            inputs_embeds = self.model.get_input_embeddings(
+                input_ids,
+                multimodal_embeddings=mm_embeds or None,
+            )
+            self.inputs_embeds[:num_tokens] = inputs_embeds
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            input_ids = None
+        else:
+            inputs_embeds = None
+            input_ids = self.input_ids[:num_input_tokens]
 
         with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
             ret_hidden_states = self.model(
-                self.input_ids[:num_input_tokens],
-                self.positions[:num_input_tokens],
-                self.hidden_states[:num_input_tokens],
+                input_ids=input_ids,
+                positions=self.positions[:num_input_tokens],
+                hidden_states=self.hidden_states[:num_input_tokens],
+                inputs_embeds=inputs_embeds,
             )
         if self.method == "deepseek_mtp":
             last_hidden_states = ret_hidden_states
@@ -218,15 +242,24 @@ class EagleProposer:
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
+            if self.is_multimodal_model:
+                inputs_embeds = self.model.get_input_embeddings(input_ids)
+                self.inputs_embeds[:batch_size] = inputs_embeds
+                inputs_embeds = self.inputs_embeds[:input_batch_size]
+                input_ids = None
+            else:
+                inputs_embeds = None
+                input_ids = self.input_ids[:input_batch_size]
 
             # Run the model.
             with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
                 last_hidden_states, hidden_states = self.model(
-                    self.input_ids[:input_batch_size],
-                    self.positions[:input_batch_size],
-                    self.hidden_states[:input_batch_size],
+                    input_ids=input_ids,
+                    positions=self.positions[:input_batch_size],
+                    hidden_states=self.hidden_states[:input_batch_size],
+                    inputs_embeds=inputs_embeds,
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -391,10 +424,18 @@ class EagleProposer:
     ) -> None:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None
+
             self.model(
-                self.input_ids[:num_tokens],
-                self.positions[:num_tokens],
-                self.hidden_states[:num_tokens],
+                input_ids=input_ids,
+                positions=self.positions[:num_tokens],
+                hidden_states=self.hidden_states[:num_tokens],
+                inputs_embeds=inputs_embeds,
             )
 
     def validate_same_kv_cache_group(self,
vllm/v1/worker/gpu_model_runner.py
@@ -1205,13 +1205,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _gather_mm_embeddings(
         self,
         scheduler_output: "SchedulerOutput",
+        shift_computed_tokens: int = 0,
     ) -> list[torch.Tensor]:
         mm_embeds: list[torch.Tensor] = []
         for req_id in self.input_batch.req_ids:
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                 req_id]
             req_state = self.requests[req_id]
-            num_computed_tokens = req_state.num_computed_tokens
+            num_computed_tokens = \
+                req_state.num_computed_tokens + shift_computed_tokens
             mm_positions = req_state.mm_positions
             for i, pos_info in enumerate(mm_positions):
                 start_pos = pos_info.offset
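
A note on shift_computed_tokens: the call site below passes shift_computed_tokens=1, which reads as compensating for the draft model running one token ahead of the target model when intersecting image placeholder ranges with the scheduled-token window; this interpretation is inferred from the diff, not stated in the commit message.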
@@ -1858,6 +1860,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 [h[token_indices] for h in aux_hidden_states], dim=-1)
         else:
             target_hidden_states = hidden_states[token_indices]
+        mm_embeds = None
+        if self.is_multimodal_model:
+            mm_embeds = self._gather_mm_embeddings(scheduler_output,
+                                                   shift_computed_tokens=1)
+
         draft_token_ids = self.drafter.propose(
             target_token_ids=target_token_ids,
             target_positions=target_positions,
@@ -1865,6 +1872,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             next_token_ids=next_token_ids,
             sampling_metadata=sampling_metadata,
             common_attn_metadata=common_attn_metadata,
+            mm_embeds=mm_embeds,
         )
         spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids