Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 21:35:01 +08:00)
[Model][VLM] Add Qwen2-VL model support (#7905)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent cea95dfb94
commit 3b7fea770f
@@ -252,6 +252,11 @@ Multimodal Language Models
      - Image\ :sup:`E`
      - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
      -
    * - :code:`Qwen2VLForConditionalGeneration`
      - Qwen2-VL (see note)
      - Image\ :sup:`+` / Video\ :sup:`+`
      - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
      -
    * - :code:`UltravoxModel`
      - Ultravox
      - Audio\ :sup:`E+`
@@ -265,15 +270,14 @@ Multimodal Language Models
    For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
    For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630

    For :code:`LLaVA-NeXT-Video`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
.. note::
    For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
    This can be installed by running the following command:

    .. code-block:: bash

        pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830

----

If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.

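As a quick sanity check after installing the development transformers build from the note above, here is a short sketch (using only names that appear in this PR's registry test) to confirm the new architecture resolves through vLLM's model registry; it is illustrative, not part of the commit:

import transformers

from vllm.model_executor.models import ModelRegistry

# Qwen2-VL needs a transformers development build; releases before 4.45 skip it.
print(transformers.__version__)
ModelRegistry.resolve_model_cls(["Qwen2VLForConditionalGeneration"])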
@@ -179,6 +179,23 @@ def run_qwen_vl(question):
    return llm, prompt, stop_token_ids


# Qwen2-VL
def run_qwen2_vl(question):
    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_num_seqs=5,
    )

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


model_example_map = {
    "llava": run_llava,
    "llava-next": run_llava_next,
@@ -191,6 +208,7 @@ model_example_map = {
    "blip-2": run_blip2,
    "internvl_chat": run_internvl,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
}

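For context, a hypothetical driver for the new helper above, mirroring how this example script consumes the (llm, prompt, stop_token_ids) tuple elsewhere; the blank PIL image is only a stand-in for a real input:

from PIL import Image

from vllm import SamplingParams

llm, prompt, stop_token_ids = run_qwen2_vl("Describe this image.")
outputs = llm.generate(
    {
        "prompt": prompt,
        # A real application would pass a user-provided image here.
        "multi_modal_data": {"image": Image.new("RGB", (448, 448))},
    },
    sampling_params=SamplingParams(temperature=0.0,
                                   max_tokens=64,
                                   stop_token_ids=stop_token_ids),
)
print(outputs[0].outputs[0].text)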
@@ -6,7 +6,7 @@ by the model.
from argparse import Namespace
from typing import List

from transformers import AutoTokenizer
from transformers import AutoProcessor, AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image
@@ -30,7 +30,7 @@ def load_phi3v(question, image_urls: List[str]):
                         for i, _ in enumerate(image_urls, start=1))
    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
    stop_token_ids = None
    return llm, prompt, stop_token_ids
    return llm, prompt, stop_token_ids, None

def load_internvl(question, image_urls: List[str]):
@@ -60,18 +60,72 @@ def load_internvl(question, image_urls: List[str]):
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids

    return llm, prompt, stop_token_ids, None


def load_qwen2_vl(question, image_urls: List[str]):
    try:
        from qwen_vl_utils import process_vision_info
    except ModuleNotFoundError:
        print('WARNING: `qwen-vl-utils` not installed, input images will not '
              'be automatically resized. You can enable this functionality by '
              '`pip install qwen-vl-utils`.')
        process_vision_info = None

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_num_seqs=5,
        max_model_len=32768 if process_vision_info is None else 4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role": "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    stop_token_ids = None

    if process_vision_info is None:
        image_data = [fetch_image(url) for url in image_urls]
    else:
        image_data, _ = process_vision_info(messages)

    return llm, prompt, stop_token_ids, image_data


model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
    "qwen2_vl": load_qwen2_vl,
}

def run_generate(model, question: str, image_urls: List[str]):
    llm, prompt, stop_token_ids = model_example_map[model](question,
                                                           image_urls)
    llm, prompt, stop_token_ids, image_data = model_example_map[model](
        question, image_urls)
    if image_data is None:
        image_data = [fetch_image(url) for url in image_urls]

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
@@ -81,7 +135,7 @@ def run_generate(model, question: str, image_urls: List[str]):
        {
            "prompt": prompt,
            "multi_modal_data": {
                "image": [fetch_image(url) for url in image_urls]
                "image": image_data
            },
        },
        sampling_params=sampling_params)
@@ -92,7 +146,7 @@ def run_generate(model, question: str, image_urls: List[str]):


def run_chat(model: str, question: str, image_urls: List[str]):
    llm, _, stop_token_ids = model_example_map[model](question, image_urls)
    llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
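To make the new contract explicit: every per-model loader in this example now returns a 4-tuple whose last element is optional pre-processed image data, and None means "fetch the raw URLs yourself", which is exactly what run_generate above does. A minimal sketch reusing the helpers in this file (the URLs are placeholders, not taken from the PR):

image_urls = [
    "https://example.com/a.jpg",  # placeholder URL for illustration
    "https://example.com/b.jpg",  # placeholder URL for illustration
]
llm, prompt, stop_token_ids, image_data = load_qwen2_vl(
    "What are the differences between these images?", image_urls)
if image_data is None:
    # qwen-vl-utils was unavailable, so load the raw images directly.
    image_data = [fetch_image(url) for url in image_urls]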
@@ -28,3 +28,4 @@ importlib_metadata
mistral_common >= 1.3.4
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
einops # Required for Qwen2-VL.

@@ -1,9 +1,14 @@
import pytest
import transformers

from vllm.model_executor.models import _MODELS, ModelRegistry


@pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls):
    if (model_cls == "Qwen2VLForConditionalGeneration"
            and transformers.__version__ < "4.45"):
        pytest.skip("Waiting for next transformers release")

    # Ensure all model classes can be imported successfully
    ModelRegistry.resolve_model_cls([model_cls])

@@ -773,7 +773,7 @@ class LoadConfig:
        ignore_patterns: The list of patterns to ignore when loading the model.
            Default to "original/**/*" to avoid repeated loading of llama's
            checkpoints.

    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
@@ -1733,8 +1733,11 @@ def _get_and_verify_max_len(
                "with rope_scaling. Please raise an issue so we can "
                "investigate.")

        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if rope_type == "mrope":
            scaling_factor = 1
        else:
            assert "factor" in rope_scaling
            scaling_factor = rope_scaling["factor"]
        if rope_type == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]

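For reference, the new branch exists because mrope-style rope_scaling carries section sizes rather than a stretch factor, so "factor" can no longer be assumed. A toy illustration (the dict values are illustrative, not read from a real checkpoint):

rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}

rope_type = rope_scaling["type"]
if rope_type == "mrope":
    scaling_factor = 1  # mrope does not stretch the position range
else:
    scaling_factor = rope_scaling["factor"]
print(scaling_factor)  # 1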
@@ -108,7 +108,7 @@ class ConversationMessage(TypedDict, total=False):
    """The tool calls generated by the model, such as function calls."""


ModalityStr = Literal["image", "audio"]
ModalityStr = Literal["image", "audio", "video"]
_T = TypeVar("_T")


@@ -158,12 +158,18 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                                 hf_config.image_token_index)
            if model_type in ("chameleon", "internvl_chat"):
                return "<image>"
            if model_type == "qwen2_vl":
                return "<|vision_start|><|image_pad|><|vision_end|>"

            raise TypeError(f"Unknown model type: {model_type}")
        elif modality == "audio":
            if model_type == "ultravox":
                return "<|reserved_special_token_0|>"
            raise TypeError(f"Unknown model type: {model_type}")
        elif modality == "video":
            if model_type == "qwen2_vl":
                return "<|vision_start|><|video_pad|><|vision_end|>"
            raise TypeError(f"Unknown model type: {model_type}")
        else:
            raise TypeError(f"Unknown modality: {modality}")

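To show what these placeholder strings amount to, here is a toy prompt assembled by hand for a mixed image-plus-video request; the chat markup matches the Qwen2-VL template used elsewhere in this PR, and the question text is an assumption:

image_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"
video_placeholder = "<|vision_start|><|video_pad|><|vision_end|>"

prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
          "<|im_start|>user\n"
          f"{image_placeholder}{video_placeholder}"
          "How does the video relate to the image?<|im_end|>\n"
          "<|im_start|>assistant\n")
print(prompt)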
@@ -712,6 +712,179 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
        return new_freqs


class MRotaryEmbedding(RotaryEmbedding):
    """Rotary Embedding with Multimodal Sections."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: Optional[List[int]] = None,
    ) -> None:
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

        self.mrope_section = mrope_section
        if self.mrope_section:
            assert sum(self.mrope_section) == rotary_dim // 2

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            positions:
                [num_tokens,] (text only) or
                [3, num_tokens] (T/H/W positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
        """
        assert positions.ndim == 1 or positions.ndim == 2

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if positions.ndim == 2:
            assert self.mrope_section

            cos = torch.cat([
                m[i]
                for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
            ],
                            dim=-1)
            sin = torch.cat([
                m[i]
                for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
            ],
                            dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    @staticmethod
    def get_input_positions(
        input_tokens: List[int],
        image_grid_thw: Union[List[List[int]], torch.Tensor],
        video_grid_thw: Union[List[List[int]], torch.Tensor],
        image_token_id: int,
        video_token_id: int,
        vision_start_token_id: int,
        vision_end_token_id: int,
        spatial_merge_size: int,
        context_len: int = 0,
    ) -> Tuple[List[List[int]], int]:
        """Get mrope input positions and delta value."""

        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()
        if isinstance(video_grid_thw, torch.Tensor):
            video_grid_thw = video_grid_thw.tolist()

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_index += 1
                remain_videos -= 1
                ed = ed_video
            llm_grid_t, llm_grid_h, llm_grid_w = \
                t, h // spatial_merge_size, w // spatial_merge_size
            text_len = ed - st

            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w).flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        llm_positions = llm_positions[:, context_len:]
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()

        return llm_positions.tolist(), mrope_position_delta

    @staticmethod
    def get_next_input_positions(
        mrope_position_delta: int,
        context_len: int,
        seq_len: int,
    ) -> List[List[int]]:
        return [
            list(
                range(context_len + mrope_position_delta,
                      seq_len + mrope_position_delta)) for _ in range(3)
        ]

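A standalone sketch (toy sizes, not taken from a real model config) of the section-wise selection in forward() above: with positions of shape [3, num_tokens], the cached cos values carry one row per temporal/height/width coordinate, and each token's final cos vector is stitched together from those rows according to mrope_section:

import torch

mrope_section = [2, 3, 3]  # must sum to rotary_dim // 2
num_tokens = 5
# One row of cos values per T/H/W position row, as cos_sin_cache would produce.
cos = torch.randn(3, num_tokens, sum(mrope_section))

merged = torch.cat(
    [m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
print(merged.shape)  # torch.Size([5, 8]): per-token cos built from the 3 rows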
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


@@ -752,7 +925,7 @@ def get_rope(
        # The correct one should be "longrope" but keep "su" here
        # for backward compatible
        if scaling_type not in {"su", "longrope"}:
            scaling_factor = rope_scaling["factor"]
            scaling_factor = rope_scaling.get("factor", 1.0)
        if scaling_type == "llama3":
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
@@ -816,6 +989,16 @@ def get_rope(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        elif scaling_type == "mrope":
            return MRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                dtype,
                mrope_section=rope_scaling["mrope_section"],
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb

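A small worked example of MRotaryEmbedding.get_input_positions (the token ids and grid values are made up for illustration): two text tokens, a vision_start marker, one image whose 1x4x4 patch grid is merged 2x2 into four placeholder tokens, then vision_end and two more text tokens.

import torch

from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding

# Toy vocabulary: 3 = vision_start, 4 = vision_end, 5 = image_pad, 6 = video_pad.
input_tokens = [1, 2, 3, 5, 5, 5, 5, 4, 7, 8]
positions, delta = MRotaryEmbedding.get_input_positions(
    input_tokens,
    image_grid_thw=[[1, 4, 4]],  # one image: T=1, H=4, W=4 patches
    video_grid_thw=[],
    image_token_id=5,
    video_token_id=6,
    vision_start_token_id=3,
    vision_end_token_id=4,
    spatial_merge_size=2,  # 4x4 patches collapse to a 2x2 token grid
)
# Text tokens advance all three rows together; the four image tokens share
# T=3 while H/W enumerate the 2x2 grid; text after the image resumes at 5.
print(positions[0])  # [0, 1, 2, 3, 3, 3, 3, 5, 6, 7]
print(delta)         # -2 (max position + 1 - sequence length)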
@@ -53,6 +53,8 @@ _GENERATION_MODELS = {
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen2VLForConditionalGeneration":
    ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
@@ -90,6 +92,8 @@ _MULTIMODAL_MODELS = {
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl",
                                        "Qwen2VLForConditionalGeneration"),
}
_CONDITIONAL_GENERATION_MODELS = {
    "BartModel": ("bart", "BartForConditionalGeneration"),

vllm/model_executor/models/qwen2_vl.py (new file, 1088 lines)
File diff suppressed because it is too large
@@ -79,14 +79,12 @@ class MultiModalInputs(_MultiModalInputsBase):
        if len(inputs_list) == 0:
            return {}

        keys = inputs_list[0].keys()

        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)

        for inputs in inputs_list:
            if inputs.keys() != keys:
                msg = f"Inputs do not share the same keys ({keys})"
                raise ValueError(msg)
            # For models that support multiple modalities (e.g. Qwen2-VL),
            # different modalities will return different data keys,
            # so batch() should skip the same key check.

            for k, v in inputs.items():
                item_lists[k].append(v)

@@ -165,6 +165,9 @@ class SequenceData(msgspec.Struct,
    # is called.
    _new_appended_tokens: List[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    def __post_init__(self) -> None:
        assert self._prompt_token_ids.typecode == "l"
        assert self._output_token_ids.typecode == "l"
@@ -219,6 +222,14 @@
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self, token_id: int, logprob: float) -> None:
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)

vllm/transformers_utils/processor.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from typing import cast


def get_processor(
    processor_name: str,
    *args,
    trust_remote_code: bool = False,
    **kwargs,
):
    """Gets a processor for the given model name via HuggingFace."""
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor
    from transformers.processing_utils import ProcessorMixin

    try:
        processor = AutoProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the processor. If the processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e

    return cast(ProcessorMixin, processor)

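A hypothetical use of the helper added above; the checkpoint name is one that appears elsewhere in this PR, and the chat-template call mirrors the examples:

from vllm.transformers_utils.processor import get_processor

processor = get_processor("Qwen/Qwen2-VL-7B-Instruct")
prompt = processor.apply_chat_template(
    [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}],
    tokenize=False,
    add_generation_prompt=True)
print(prompt)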
@@ -30,6 +30,7 @@ from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -181,6 +182,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            def simple_reinit(self):
                self.input_tokens[0].clear()  # type: ignore
                self.input_positions[0].clear()  # type: ignore
                self.mrope_input_positions = None  # type: ignore
                self.seq_lens[0] = 0  # type: ignore
                self.orig_seq_lens[0] = 0  # type: ignore
                self.query_lens[0] = 0  # type: ignore
@@ -206,6 +208,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            # Input tokens and positions.
            input_tokens: Optional[List[List[int]]] = None,
            input_positions: Optional[List[List[int]]] = None,
            mrope_input_positions: Optional[List[List[List[int]]]] = None,

            # The sequence length (may be capped to the sliding window).
            seq_lens: Optional[List[int]] = None,
@@ -266,6 +269,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                        for seq_id in range(len(self.seq_ids)):
                            self.input_positions[seq_id].clear()

                        self.mrope_input_positions = None

                        if seq_lens:
                            self.seq_lens = seq_lens
                        else:
@@ -327,6 +332,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            else:
                self.input_tokens = input_tokens or []
                self.input_positions = input_positions or []
                self.mrope_input_positions = mrope_input_positions or None
                self.seq_lens = seq_lens or []
                self.orig_seq_lens = orig_seq_lens or []
                self.query_lens = query_lens or []
@@ -357,6 +363,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):

            self.input_tokens = [[] for _ in range(self.n_seqs)]
            self.input_positions = [[] for _ in range(self.n_seqs)]
            self.mrope_input_positions = None
            self.seq_lens = [0] * self.n_seqs
            self.orig_seq_lens = [0] * self.n_seqs
            self.query_lens = [0] * self.n_seqs
@@ -493,6 +500,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        inter_data.query_lens[
            seq_idx] = seq_len - context_len if inter_data.is_prompt else 1

        if seq_data.mrope_position_delta is not None:
            if inter_data.mrope_input_positions is None:
                inter_data.mrope_input_positions = [None] * inter_data.n_seqs

            inter_data.mrope_input_positions[
                seq_idx] = MRotaryEmbedding.get_next_input_positions(
                    seq_data.mrope_position_delta,
                    context_len,
                    seq_len,
                )

    def _compute_for_prefix_cache_hit(
            self, inter_data: InterDataForSeqGroup, seq_idx: int,
            seq_group_metadata: SequenceGroupMetadata):
@@ -636,6 +654,40 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        mm_kwargs = self.multi_modal_input_mapper(mm_data)
        inter_data.multi_modal_inputs = mm_kwargs

        # special processing for mrope position deltas.
        if self.runner.model_is_mrope:
            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            assert image_grid_thw is not None or video_grid_thw is not None, (
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw'.")

            hf_config = self.runner.model_config.hf_config

            inter_data.mrope_input_positions = [None] * inter_data.n_seqs
            for seq_idx in range(inter_data.n_seqs):
                seq_data = seq_group_metadata.seq_data[
                    inter_data.seq_ids[seq_idx]]
                token_ids = seq_data.get_token_ids()

                mrope_input_positions, mrope_position_delta = \
                    MRotaryEmbedding.get_input_positions(
                        token_ids,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        image_token_id=hf_config.image_token_id,
                        video_token_id=hf_config.video_token_id,
                        vision_start_token_id=hf_config.vision_start_token_id,
                        vision_end_token_id=hf_config.vision_end_token_id,
                        spatial_merge_size=hf_config.vision_config.
                        spatial_merge_size,
                        context_len=inter_data.context_lens[seq_idx],
                    )

                seq_data.mrope_position_delta = mrope_position_delta
                inter_data.mrope_input_positions[
                    seq_idx] = mrope_input_positions

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Add a sequence group to the builder."""
        seq_ids = seq_group_metadata.seq_data.keys()
@@ -684,10 +736,27 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        input_positions = []
        for inter_data in self.inter_data_list:
            for cur_input_positions in inter_data.input_positions:
                input_positions.extend(cur_input_positions)
        mrope_input_positions: Optional[List[List[int]]] = None
        if any(inter_data.mrope_input_positions is not None
               for inter_data in self.inter_data_list):
            mrope_input_positions = [[] for _ in range(3)]
            for idx in range(3):
                for inter_data in self.inter_data_list:
                    msections = inter_data.mrope_input_positions
                    if msections is None:
                        for _seq_input_positions in inter_data.input_positions:
                            mrope_input_positions[idx].extend(
                                _seq_input_positions)
                    else:
                        for _seq_mrope_input_positions in msections:
                            mrope_input_positions[idx].extend(
                                _seq_mrope_input_positions[idx])
            input_positions = None
        else:
            input_positions = []
            for inter_data in self.inter_data_list:
                for cur_input_positions in inter_data.input_positions:
                    input_positions.extend(cur_input_positions)

        seq_lens = []
        max_decode_seq_len = 0
@@ -724,15 +793,24 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        # Tokens and positions.
        if cuda_graph_pad_size:
            input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
        assert self.runner.device is not None
        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                               self.runner.device,
                                               self.runner.pin_memory)
        input_positions_tensor = async_tensor_h2d(input_positions, torch.long,
                                                  self.runner.device,
                                                  self.runner.pin_memory)

        if mrope_input_positions is not None:
            for idx in range(3):
                mrope_input_positions[idx].extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            input_positions_tensor = async_tensor_h2d(mrope_input_positions,
                                                      torch.long,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
        else:
            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
            input_positions_tensor = async_tensor_h2d(input_positions,
                                                      torch.long,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
        # Sequence and query lengths.
        if cuda_graph_pad_size:
            seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))

@@ -1199,6 +1277,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            raise RuntimeError("PromptAdapter is not enabled.")
        return self.prompt_adapter_manager.list_adapters()

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keeping "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Cuda graph capture a model.
@@ -1229,7 +1316,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        max_batch_size = self.max_batchsize_to_capture
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()

        if self.model_is_mrope:
            input_positions = torch.tile(input_positions, (3, 1))
        # Prepare dummy previous_hidden_states only if needed by the model.
        # This is used by draft models such as EAGLE.
        previous_hidden_states = None
@@ -1293,7 +1381,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                "input_ids":
                input_tokens[:batch_size],
                "positions":
                input_positions[:batch_size],
                input_positions[..., :batch_size],
                "hidden_or_intermediate_states":
                hidden_or_intermediate_states[
                    virtual_engine]  # type: ignore

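Finally, a toy illustration (standalone, with made-up sizes) of the two position layouts the builder above can hand to the model: a flat [num_tokens] tensor for ordinary RoPE models versus the [3, num_tokens] T/H/W layout that model_is_mrope triggers, which is also why capture_model tiles its dummy positions to three rows:

import torch

num_tokens = 8
flat_positions = torch.zeros(num_tokens, dtype=torch.long)
mrope_positions = torch.tile(flat_positions, (3, 1))
print(flat_positions.shape)   # torch.Size([8])
print(mrope_positions.shape)  # torch.Size([3, 8]) -> one row each for T/H/W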