[Model][VLM] Add Qwen2-VL model support (#7905)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent cea95dfb94
commit 3b7fea770f
@@ -252,6 +252,11 @@ Multimodal Language Models
     - Image\ :sup:`E`
     - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
     -
+  * - :code:`Qwen2VLForConditionalGeneration`
+    - Qwen2-VL (see note)
+    - Image\ :sup:`+` / Video\ :sup:`+`
+    - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
+    -
   * - :code:`UltravoxModel`
     - Ultravox
     - Audio\ :sup:`E+`
@@ -265,15 +270,14 @@ Multimodal Language Models
   For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
   For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630

-  For :code:`LLaVA-NeXT-Video`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
+.. note::
+  For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
   This can be installed by running the following command:

   .. code-block:: bash

      pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830

----

If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
@@ -179,6 +179,23 @@ def run_qwen_vl(question):
     return llm, prompt, stop_token_ids


+# Qwen2-VL
+def run_qwen2_vl(question):
+    model_name = "Qwen/Qwen2-VL-7B-Instruct"
+
+    llm = LLM(
+        model=model_name,
+        max_num_seqs=5,
+    )
+
+    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+              f"{question}<|im_end|>\n"
+              "<|im_start|>assistant\n")
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 model_example_map = {
     "llava": run_llava,
     "llava-next": run_llava_next,
@@ -191,6 +208,7 @@ model_example_map = {
     "blip-2": run_blip2,
     "internvl_chat": run_internvl,
     "qwen_vl": run_qwen_vl,
+    "qwen2_vl": run_qwen2_vl,
 }
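For context (not part of this commit), below is a minimal sketch of how the new "qwen2_vl" entry is consumed. The question and image URL are placeholders, and the generate call mirrors the pattern used elsewhere in this PR for single-image prompts.

from vllm import SamplingParams
from vllm.multimodal.utils import fetch_image

# Hypothetical driver, mirroring how the example script feeds one image.
llm, prompt, stop_token_ids = run_qwen2_vl("What is shown in this image?")
image = fetch_image("https://example.com/demo.jpg")  # placeholder URL

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=SamplingParams(temperature=0.0,
                                   max_tokens=128,
                                   stop_token_ids=stop_token_ids),
)
print(outputs[0].outputs[0].text)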
@@ -6,7 +6,7 @@ by the model.
 from argparse import Namespace
 from typing import List

-from transformers import AutoTokenizer
+from transformers import AutoProcessor, AutoTokenizer

 from vllm import LLM, SamplingParams
 from vllm.multimodal.utils import fetch_image
@@ -30,7 +30,7 @@ def load_phi3v(question, image_urls: List[str]):
                              for i, _ in enumerate(image_urls, start=1))
     prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompt, stop_token_ids, None


 def load_internvl(question, image_urls: List[str]):
@@ -60,18 +60,72 @@ def load_internvl(question, image_urls: List[str]):
     # https://huggingface.co/OpenGVLab/InternVL2-2B#service
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
-    return llm, prompt, stop_token_ids
+    return llm, prompt, stop_token_ids, None
+
+
+def load_qwen2_vl(question, image_urls: List[str]):
+    try:
+        from qwen_vl_utils import process_vision_info
+    except ModuleNotFoundError:
+        print('WARNING: `qwen-vl-utils` not installed, input images will not '
+              'be automatically resized. You can enable this functionality by '
+              '`pip install qwen-vl-utils`.')
+        process_vision_info = None
+
+    model_name = "Qwen/Qwen2-VL-7B-Instruct"
+
+    llm = LLM(
+        model=model_name,
+        max_num_seqs=5,
+        max_model_len=32768 if process_vision_info is None else 4096,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = [{"type": "image", "image": url} for url in image_urls]
+    messages = [{
+        "role": "system",
+        "content": "You are a helpful assistant."
+    }, {
+        "role":
+        "user",
+        "content": [
+            *placeholders,
+            {
+                "type": "text",
+                "text": question
+            },
+        ],
+    }]
+
+    processor = AutoProcessor.from_pretrained(model_name)
+
+    prompt = processor.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+
+    stop_token_ids = None
+
+    if process_vision_info is None:
+        image_data = [fetch_image(url) for url in image_urls]
+    else:
+        image_data, _ = process_vision_info(messages)
+
+    return llm, prompt, stop_token_ids, image_data
+
+
 model_example_map = {
     "phi3_v": load_phi3v,
     "internvl_chat": load_internvl,
+    "qwen2_vl": load_qwen2_vl,
 }


 def run_generate(model, question: str, image_urls: List[str]):
-    llm, prompt, stop_token_ids = model_example_map[model](question,
-                                                           image_urls)
+    llm, prompt, stop_token_ids, image_data = model_example_map[model](
+        question, image_urls)
+    if image_data is None:
+        image_data = [fetch_image(url) for url in image_urls]
+
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
@@ -81,7 +135,7 @@ def run_generate(model, question: str, image_urls: List[str]):
         {
             "prompt": prompt,
             "multi_modal_data": {
-                "image": [fetch_image(url) for url in image_urls]
+                "image": image_data
             },
         },
         sampling_params=sampling_params)
@@ -92,7 +146,7 @@ def run_generate(model, question: str, image_urls: List[str]):


 def run_chat(model: str, question: str, image_urls: List[str]):
-    llm, _, stop_token_ids = model_example_map[model](question, image_urls)
+    llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls)

     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
@@ -28,3 +28,4 @@ importlib_metadata
 mistral_common >= 1.3.4
 pyyaml
 six>=1.16.0; python_version > '3.11'  # transitive dependency of pandas that needs to be the latest version for python 3.12
+einops  # Required for Qwen2-VL.
@@ -1,9 +1,14 @@
 import pytest
+import transformers

 from vllm.model_executor.models import _MODELS, ModelRegistry


 @pytest.mark.parametrize("model_cls", _MODELS)
 def test_registry_imports(model_cls):
+    if (model_cls == "Qwen2VLForConditionalGeneration"
+            and transformers.__version__ < "4.45"):
+        pytest.skip("Waiting for next transformers release")
+
     # Ensure all model classes can be imported successfully
     ModelRegistry.resolve_model_cls([model_cls])
@@ -1733,6 +1733,9 @@ def _get_and_verify_max_len(
                 "with rope_scaling. Please raise an issue so we can "
                 "investigate.")

-        assert "factor" in rope_scaling
-        scaling_factor = rope_scaling["factor"]
+        if rope_type == "mrope":
+            scaling_factor = 1
+        else:
+            assert "factor" in rope_scaling
+            scaling_factor = rope_scaling["factor"]
         if rope_type == "yarn":
@@ -108,7 +108,7 @@ class ConversationMessage(TypedDict, total=False):
     """The tool calls generated by the model, such as function calls."""


-ModalityStr = Literal["image", "audio"]
+ModalityStr = Literal["image", "audio", "video"]
 _T = TypeVar("_T")
@@ -158,12 +158,18 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                                              hf_config.image_token_index)
             if model_type in ("chameleon", "internvl_chat"):
                 return "<image>"
+            if model_type == "qwen2_vl":
+                return "<|vision_start|><|image_pad|><|vision_end|>"

             raise TypeError(f"Unknown model type: {model_type}")
         elif modality == "audio":
             if model_type == "ultravox":
                 return "<|reserved_special_token_0|>"
             raise TypeError(f"Unknown model type: {model_type}")
+        elif modality == "video":
+            if model_type == "qwen2_vl":
+                return "<|vision_start|><|video_pad|><|vision_end|>"
+            raise TypeError(f"Unknown model type: {model_type}")
         else:
             raise TypeError(f"Unknown modality: {modality}")
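For illustration only (not part of the diff): with the mapping above, an OpenAI-style chat request that attaches an image to a Qwen2-VL model has that image item contribute the <|vision_start|><|image_pad|><|vision_end|> placeholder to the text before the model's chat template and tokenizer are applied; a video item would contribute <|video_pad|> instead. Schematically:

# Hypothetical request body; the URL is a placeholder.
messages = [{
    "role": "user",
    "content": [
        {"type": "image_url",
         "image_url": {"url": "https://example.com/demo.jpg"}},
        {"type": "text", "text": "What is in this image?"},
    ],
}]
# The image part is replaced by "<|vision_start|><|image_pad|><|vision_end|>"
# in the text handed to the chat template; the actual pixels travel separately
# as multi-modal data.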
@@ -712,6 +712,179 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
         return new_freqs


+class MRotaryEmbedding(RotaryEmbedding):
+    """Rotary Embedding with Multimodal Sections."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        mrope_section: Optional[List[int]] = None,
+    ) -> None:
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+        self.mrope_section = mrope_section
+        if self.mrope_section:
+            assert sum(self.mrope_section) == rotary_dim // 2
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward().
+
+        Args:
+            positions:
+                [num_tokens,] (text only) or
+                [3, num_tokens] (T/H/W positions with multimodal inputs)
+            query: [num_tokens, num_heads * head_size]
+            key: [num_tokens, num_kv_heads * head_size]
+        """
+        assert positions.ndim == 1 or positions.ndim == 2
+
+        num_tokens = positions.shape[-1]
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if positions.ndim == 2:
+            assert self.mrope_section
+
+            cos = torch.cat([
+                m[i]
+                for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
+            ],
+                            dim=-1)
+            sin = torch.cat([
+                m[i]
+                for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
+            ],
+                            dim=-1)
+
+        query_shape = query.shape
+        query = query.view(num_tokens, -1, self.head_size)
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+        key_shape = key.shape
+        key = key.view(num_tokens, -1, self.head_size)
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
+
+    @staticmethod
+    def get_input_positions(
+        input_tokens: List[int],
+        image_grid_thw: Union[List[List[int]], torch.Tensor],
+        video_grid_thw: Union[List[List[int]], torch.Tensor],
+        image_token_id: int,
+        video_token_id: int,
+        vision_start_token_id: int,
+        vision_end_token_id: int,
+        spatial_merge_size: int,
+        context_len: int = 0,
+    ) -> Tuple[List[List[int]], int]:
+        """Get mrope input positions and delta value."""
+
+        if isinstance(image_grid_thw, torch.Tensor):
+            image_grid_thw = image_grid_thw.tolist()
+        if isinstance(video_grid_thw, torch.Tensor):
+            video_grid_thw = video_grid_thw.tolist()
+
+        input_tokens_tensor = torch.tensor(input_tokens)
+        vision_start_indices = torch.argwhere(
+            input_tokens_tensor == vision_start_token_id).squeeze(1)
+        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
+        image_nums = (vision_tokens == image_token_id).sum()
+        video_nums = (vision_tokens == video_token_id).sum()
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_videos = image_nums, video_nums
+
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + video_nums):
+            if image_token_id in input_tokens and remain_images > 0:
+                ed_image = input_tokens.index(image_token_id, st)
+            else:
+                ed_image = len(input_tokens) + 1
+            if video_token_id in input_tokens and remain_videos > 0:
+                ed_video = input_tokens.index(video_token_id, st)
+            else:
+                ed_video = len(input_tokens) + 1
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_index += 1
+                remain_videos -= 1
+                ed = ed_video
+            llm_grid_t, llm_grid_h, llm_grid_w = \
+                t, h // spatial_merge_size, w // spatial_merge_size
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
+                -1, llm_grid_h * llm_grid_w).flatten()
+            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                llm_grid_t, -1, llm_grid_w).flatten()
+            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                llm_grid_t, llm_grid_h, -1).flatten()
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = llm_positions[:, context_len:]
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+
+        return llm_positions.tolist(), mrope_position_delta
+
+    @staticmethod
+    def get_next_input_positions(
+        mrope_position_delta: int,
+        context_len: int,
+        seq_len: int,
+    ) -> List[List[int]]:
+        return [
+            list(
+                range(context_len + mrope_position_delta,
+                      seq_len + mrope_position_delta)) for _ in range(3)
+        ]
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
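To make the two static helpers above concrete, here is a small walk-through (not part of the commit; the token IDs and grid sizes are invented, whereas in practice they come from the Qwen2-VL processor outputs):

from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding

VISION_START, IMAGE_PAD, VISION_END, VIDEO_PAD = 100, 101, 102, 103

# "text text <|vision_start|> <4 image patches> <|vision_end|> text".
# One image with grid (t=1, h=4, w=4); spatial_merge_size=2 collapses it to
# 1 x 2 x 2 = 4 placeholder tokens in the prompt.
tokens = [1, 2, VISION_START, IMAGE_PAD, IMAGE_PAD, IMAGE_PAD, IMAGE_PAD,
          VISION_END, 3]

positions, delta = MRotaryEmbedding.get_input_positions(
    tokens,
    image_grid_thw=[[1, 4, 4]],
    video_grid_thw=[],
    image_token_id=IMAGE_PAD,
    video_token_id=VIDEO_PAD,
    vision_start_token_id=VISION_START,
    vision_end_token_id=VISION_END,
    spatial_merge_size=2,
)
# positions[0] (T): [0, 1, 2, 3, 3, 3, 3, 5, 6]
# positions[1] (H): [0, 1, 2, 3, 3, 4, 4, 5, 6]
# positions[2] (W): [0, 1, 2, 3, 4, 3, 4, 5, 6]
# delta == -2: the largest position (6) plus one, minus the prompt length (9).

# During decoding the same delta keeps the three position streams in sync:
# the token at sequence index 9 gets position 9 + (-2) = 7 in all three dims.
print(MRotaryEmbedding.get_next_input_positions(delta, context_len=9,
                                                seq_len=10))  # [[7], [7], [7]]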
@@ -752,7 +925,7 @@ def get_rope(
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
         if scaling_type not in {"su", "longrope"}:
-            scaling_factor = rope_scaling["factor"]
+            scaling_factor = rope_scaling.get("factor", 1.0)
         if scaling_type == "llama3":
             low_freq_factor = rope_scaling["low_freq_factor"]
             high_freq_factor = rope_scaling["high_freq_factor"]
@@ -816,6 +989,16 @@ def get_rope(
                 head_size, rotary_dim, max_position, original_max_position,
                 base, is_neox_style, dtype, short_factor, long_factor,
                 **extra_kwargs)
+        elif scaling_type == "mrope":
+            return MRotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position,
+                base,
+                is_neox_style,
+                dtype,
+                mrope_section=rope_scaling["mrope_section"],
+            )
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
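For reference, Qwen2-VL checkpoints reach this branch through the rope_scaling entry of their HF config. The numbers below only illustrate the expected shape (the assertion in MRotaryEmbedding.__init__ requires the sections to sum to half the rotary dimension); they are not copied from a specific checkpoint:

import torch
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding

# Illustrative config entry: type "mrope" plus a temporal/height/width split
# of the rotary dimensions; 16 + 24 + 24 == 64 == rotary_dim // 2 here.
rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}

rotary_emb = MRotaryEmbedding(
    head_size=128,
    rotary_dim=128,
    max_position_embeddings=32768,
    base=1000000,
    is_neox_style=True,
    dtype=torch.float32,
    mrope_section=rope_scaling["mrope_section"],
)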
@@ -53,6 +53,8 @@ _GENERATION_MODELS = {
     "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
+    "Qwen2VLForConditionalGeneration":
+    ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
     "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
@@ -90,6 +92,8 @@ _MULTIMODAL_MODELS = {
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "UltravoxModel": ("ultravox", "UltravoxModel"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
+    "Qwen2VLForConditionalGeneration": ("qwen2_vl",
+                                        "Qwen2VLForConditionalGeneration"),
 }
 _CONDITIONAL_GENERATION_MODELS = {
     "BartModel": ("bart", "BartForConditionalGeneration"),
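As a quick sanity check (not part of the commit), the new registry entries can be resolved the same way the registry test earlier in this diff does; this needs the development transformers build called out in the docs change above:

from vllm.model_executor.models import ModelRegistry

# Resolving the architecture name imports the new Qwen2-VL implementation.
ModelRegistry.resolve_model_cls(["Qwen2VLForConditionalGeneration"])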
vllm/model_executor/models/qwen2_vl.py (new file, 1088 lines)
File diff suppressed because it is too large.
@@ -79,14 +79,12 @@ class MultiModalInputs(_MultiModalInputsBase):
         if len(inputs_list) == 0:
             return {}

-        keys = inputs_list[0].keys()
-
         item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)

         for inputs in inputs_list:
-            if inputs.keys() != keys:
-                msg = f"Inputs do not share the same keys ({keys})"
-                raise ValueError(msg)
-
+            # For models that support multiple modalities (e.g. Qwen2-VL),
+            # different modalities will return different data keys,
+            # so batch() should skip the same-key check.
             for k, v in inputs.items():
                 item_lists[k].append(v)
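A sketch of the behaviour this enables (not part of the diff; the import path and the keyword names are assumptions based on how Qwen2-VL image and video inputs are keyed elsewhere in this PR):

import torch
from vllm.multimodal.base import MultiModalInputs

# One image item and one video item produce different keyword-argument keys.
image_item = MultiModalInputs({"pixel_values": torch.rand(4, 3),
                               "image_grid_thw": torch.tensor([[1, 2, 2]])})
video_item = MultiModalInputs({"pixel_values_videos": torch.rand(8, 3),
                               "video_grid_thw": torch.tensor([[2, 2, 2]])})

# Before this change, batch() raised ValueError because the key sets differ;
# now it simply merges whatever keys are present.
batched = MultiModalInputs.batch([image_item, video_item])
print(sorted(batched.keys()))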
@@ -165,6 +165,9 @@ class SequenceData(msgspec.Struct,
     # is called.
     _new_appended_tokens: List[int] = msgspec.field(default_factory=list)

+    # It is used to compute mrope_position_ids.
+    _mrope_position_delta: Optional[int] = None
+
     def __post_init__(self) -> None:
         assert self._prompt_token_ids.typecode == "l"
         assert self._output_token_ids.typecode == "l"
@@ -219,6 +222,14 @@ class SequenceData(msgspec.Struct,
         assert isinstance(self._output_token_ids, array)
         return self._output_token_ids

+    @property
+    def mrope_position_delta(self) -> Optional[int]:
+        return self._mrope_position_delta
+
+    @mrope_position_delta.setter
+    def mrope_position_delta(self, new_mrope_position_delta):
+        self._mrope_position_delta = new_mrope_position_delta
+
     def append_token_id(self, token_id: int, logprob: float) -> None:
         self._output_token_ids.append(token_id)
         self._new_appended_tokens.append(token_id)
vllm/transformers_utils/processor.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+from typing import cast
+
+
+def get_processor(
+    processor_name: str,
+    *args,
+    trust_remote_code: bool = False,
+    **kwargs,
+):
+    """Gets a processor for the given model name via HuggingFace."""
+    # don't put this import at the top level
+    # it will call torch.cuda.device_count()
+    from transformers import AutoProcessor
+    from transformers.processing_utils import ProcessorMixin
+
+    try:
+        processor = AutoProcessor.from_pretrained(
+            processor_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            **kwargs)
+    except ValueError as e:
+        # If the error pertains to the processor class not existing or not
+        # currently being imported, suggest using the --trust-remote-code flag.
+        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
+        if not trust_remote_code:
+            err_msg = (
+                "Failed to load the processor. If the processor is "
+                "a custom processor not yet available in the HuggingFace "
+                "transformers library, consider setting "
+                "`trust_remote_code=True` in LLM or using the "
+                "`--trust-remote-code` flag in the CLI.")
+            raise RuntimeError(err_msg) from e
+        else:
+            raise e
+
+    return cast(ProcessorMixin, processor)
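A minimal usage sketch for the new helper (not part of the commit; the checkpoint name is only an example and loading it requires the development transformers build mentioned in the docs change, with extra args and kwargs forwarded to AutoProcessor.from_pretrained):

from vllm.transformers_utils.processor import get_processor

processor = get_processor("Qwen/Qwen2-VL-7B-Instruct")
prompt = processor.apply_chat_template(
    [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
    tokenize=False,
    add_generation_prompt=True)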
@@ -30,6 +30,7 @@ from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -181,6 +182,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         def simple_reinit(self):
             self.input_tokens[0].clear()  # type: ignore
             self.input_positions[0].clear()  # type: ignore
+            self.mrope_input_positions = None  # type: ignore
             self.seq_lens[0] = 0  # type: ignore
             self.orig_seq_lens[0] = 0  # type: ignore
             self.query_lens[0] = 0  # type: ignore
@@ -206,6 +208,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            # Input tokens and positions.
            input_tokens: Optional[List[List[int]]] = None,
            input_positions: Optional[List[List[int]]] = None,
+           mrope_input_positions: Optional[List[List[List[int]]]] = None,

            # The sequence length (may be capped to the sliding window).
            seq_lens: Optional[List[int]] = None,
@@ -266,6 +269,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                 for seq_id in range(len(self.seq_ids)):
                     self.input_positions[seq_id].clear()

+                self.mrope_input_positions = None
+
                 if seq_lens:
                     self.seq_lens = seq_lens
                 else:
@@ -327,6 +332,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             else:
                 self.input_tokens = input_tokens or []
                 self.input_positions = input_positions or []
+                self.mrope_input_positions = mrope_input_positions or None
                 self.seq_lens = seq_lens or []
                 self.orig_seq_lens = orig_seq_lens or []
                 self.query_lens = query_lens or []
@@ -357,6 +363,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):

         self.input_tokens = [[] for _ in range(self.n_seqs)]
         self.input_positions = [[] for _ in range(self.n_seqs)]
+        self.mrope_input_positions = None
         self.seq_lens = [0] * self.n_seqs
         self.orig_seq_lens = [0] * self.n_seqs
         self.query_lens = [0] * self.n_seqs
@@ -493,6 +500,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         inter_data.query_lens[
             seq_idx] = seq_len - context_len if inter_data.is_prompt else 1

+        if seq_data.mrope_position_delta is not None:
+            if inter_data.mrope_input_positions is None:
+                inter_data.mrope_input_positions = [None] * inter_data.n_seqs
+
+            inter_data.mrope_input_positions[
+                seq_idx] = MRotaryEmbedding.get_next_input_positions(
+                    seq_data.mrope_position_delta,
+                    context_len,
+                    seq_len,
+                )
+
     def _compute_for_prefix_cache_hit(
             self, inter_data: InterDataForSeqGroup, seq_idx: int,
             seq_group_metadata: SequenceGroupMetadata):
@@ -636,6 +654,40 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         mm_kwargs = self.multi_modal_input_mapper(mm_data)
         inter_data.multi_modal_inputs = mm_kwargs

+        # special processing for mrope position deltas.
+        if self.runner.model_is_mrope:
+            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
+            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
+            assert image_grid_thw is not None or video_grid_thw is not None, (
+                "mrope embedding type requires the multi-modal input mapper "
+                "to return 'image_grid_thw' or 'video_grid_thw'.")
+
+            hf_config = self.runner.model_config.hf_config
+
+            inter_data.mrope_input_positions = [None] * inter_data.n_seqs
+            for seq_idx in range(inter_data.n_seqs):
+                seq_data = seq_group_metadata.seq_data[
+                    inter_data.seq_ids[seq_idx]]
+                token_ids = seq_data.get_token_ids()
+
+                mrope_input_positions, mrope_position_delta = \
+                    MRotaryEmbedding.get_input_positions(
+                        token_ids,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        image_token_id=hf_config.image_token_id,
+                        video_token_id=hf_config.video_token_id,
+                        vision_start_token_id=hf_config.vision_start_token_id,
+                        vision_end_token_id=hf_config.vision_end_token_id,
+                        spatial_merge_size=hf_config.vision_config.
+                        spatial_merge_size,
+                        context_len=inter_data.context_lens[seq_idx],
+                    )
+
+                seq_data.mrope_position_delta = mrope_position_delta
+                inter_data.mrope_input_positions[
+                    seq_idx] = mrope_input_positions
+
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         """Add a sequence group to the builder."""
         seq_ids = seq_group_metadata.seq_data.keys()
@@ -684,6 +736,23 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             # prefix caching and there is no decode request.
             return self.model_input_cls()

-        input_positions = []
-        for inter_data in self.inter_data_list:
-            for cur_input_positions in inter_data.input_positions:
+        mrope_input_positions: Optional[List[List[int]]] = None
+        if any(inter_data.mrope_input_positions is not None
+               for inter_data in self.inter_data_list):
+            mrope_input_positions = [[] for _ in range(3)]
+            for idx in range(3):
+                for inter_data in self.inter_data_list:
+                    msections = inter_data.mrope_input_positions
+                    if msections is None:
+                        for _seq_input_positions in inter_data.input_positions:
+                            mrope_input_positions[idx].extend(
+                                _seq_input_positions)
+                    else:
+                        for _seq_mrope_input_positions in msections:
+                            mrope_input_positions[idx].extend(
+                                _seq_mrope_input_positions[idx])
+            input_positions = None
+        else:
+            input_positions = []
+            for inter_data in self.inter_data_list:
+                for cur_input_positions in inter_data.input_positions:
@@ -724,15 +793,24 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         # Tokens and positions.
         if cuda_graph_pad_size:
             input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
-            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
         assert self.runner.device is not None
         input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                                self.runner.device,
                                                self.runner.pin_memory)
-        input_positions_tensor = async_tensor_h2d(input_positions, torch.long,
-                                                  self.runner.device,
-                                                  self.runner.pin_memory)
+        if mrope_input_positions is not None:
+            for idx in range(3):
+                mrope_input_positions[idx].extend(
+                    itertools.repeat(0, cuda_graph_pad_size))
+            input_positions_tensor = async_tensor_h2d(mrope_input_positions,
+                                                      torch.long,
+                                                      self.runner.device,
+                                                      self.runner.pin_memory)
+        else:
+            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
+            input_positions_tensor = async_tensor_h2d(input_positions,
+                                                      torch.long,
+                                                      self.runner.device,
+                                                      self.runner.pin_memory)

         # Sequence and query lengths.
         if cuda_graph_pad_size:
             seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
@@ -1199,6 +1277,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             raise RuntimeError("PromptAdapter is not enabled.")
         return self.prompt_adapter_manager.list_adapters()

+    @property
+    def model_is_mrope(self) -> bool:
+        """Detect if the model has "mrope" rope_scaling type.
+        mrope requires keeping "rope_deltas" between prompt and decoding phases."""
+        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
+        if rope_scaling is None:
+            return False
+        return rope_scaling.get("type", None) == "mrope"
+
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         """Cuda graph capture a model.
@@ -1229,7 +1316,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         max_batch_size = self.max_batchsize_to_capture
         input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
+        if self.model_is_mrope:
+            input_positions = torch.tile(input_positions, (3, 1))
         # Prepare dummy previous_hidden_states only if needed by the model.
         # This is used by draft models such as EAGLE.
         previous_hidden_states = None
@@ -1293,7 +1381,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             "input_ids":
             input_tokens[:batch_size],
             "positions":
-            input_positions[:batch_size],
+            input_positions[..., :batch_size],
             "hidden_or_intermediate_states":
             hidden_or_intermediate_states[
                 virtual_engine]  # type: ignore
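Not part of the commit, but a tiny illustration of why the capture buffers are now sliced with [..., :batch_size]: for regular RoPE the dummy positions are a 1-D buffer, while for mrope they are tiled to three rows, and the ellipsis keeps that leading dimension while trimming to the batch size.

import torch

positions = torch.zeros(8, dtype=torch.long)      # regular rope capture buffer
mrope_positions = torch.tile(positions, (3, 1))   # mrope buffer, shape [3, 8]

assert positions[..., :4].shape == (4,)
assert mrope_positions[..., :4].shape == (3, 4)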