[Model][VLM] Add Qwen2-VL model support (#7905)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Yang Fan 2024-09-12 00:31:19 +08:00 committed by GitHub
parent cea95dfb94
commit 3b7fea770f
14 changed files with 1531 additions and 31 deletions

View File

@@ -252,6 +252,11 @@ Multimodal Language Models
- Image\ :sup:`E`
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
-
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL (see note)
- Image\ :sup:`+` / Video\ :sup:`+`
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
-
* - :code:`UltravoxModel`
- Ultravox
- Audio\ :sup:`E+`
@@ -265,15 +270,14 @@ Multimodal Language Models
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
-For :code:`LLaVA-NeXT-Video`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
+.. note::
+For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
This can be installed by running the following command:
.. code-block:: bash
pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830
----
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
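
As a quick sketch of what "seamlessly run" means for the newly added architecture (the model name comes from the table above; the context-length setting is an arbitrary illustrative choice, and the developer transformers build from the note is assumed to be installed):

    from vllm import LLM

    # Minimal sketch only; not part of the committed docs.
    llm = LLM(model="Qwen/Qwen2-VL-2B-Instruct", max_model_len=4096)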

View File

@@ -179,6 +179,23 @@ def run_qwen_vl(question):
return llm, prompt, stop_token_ids
# Qwen2-VL
def run_qwen2_vl(question):
model_name = "Qwen/Qwen2-VL-7B-Instruct"
llm = LLM(
model=model_name,
max_num_seqs=5,
)
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = None
return llm, prompt, stop_token_ids
model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@@ -191,6 +208,7 @@ model_example_map = {
"blip-2": run_blip2,
"internvl_chat": run_internvl,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
}
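
For reference, a minimal sketch of how the values returned by run_qwen2_vl above are consumed, mirroring the existing driver loop of this example script (the image URL is hypothetical; any PIL image works):

    from vllm import SamplingParams
    from vllm.multimodal.utils import fetch_image

    llm, prompt, stop_token_ids = run_qwen2_vl("What is in this image?")
    image = fetch_image("https://example.com/demo.jpg")  # hypothetical URL
    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)
    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": image},
        },
        sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)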

View File

@@ -6,7 +6,7 @@ by the model.
from argparse import Namespace
from typing import List
-from transformers import AutoTokenizer
+from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image
@@ -30,7 +30,7 @@ def load_phi3v(question, image_urls: List[str]):
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
stop_token_ids = None
-return llm, prompt, stop_token_ids
+return llm, prompt, stop_token_ids, None
def load_internvl(question, image_urls: List[str]):
@@ -60,18 +60,72 @@ def load_internvl(question, image_urls: List[str]):
# https://huggingface.co/OpenGVLab/InternVL2-2B#service
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
-return llm, prompt, stop_token_ids
+return llm, prompt, stop_token_ids, None
def load_qwen2_vl(question, image_urls: List[str]):
try:
from qwen_vl_utils import process_vision_info
except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not '
'be automatically resized. You can enable this functionality by '
'`pip install qwen-vl-utils`.')
process_vision_info = None
model_name = "Qwen/Qwen2-VL-7B-Instruct"
llm = LLM(
model=model_name,
max_num_seqs=5,
max_model_len=32768 if process_vision_info is None else 4096,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None
if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls]
else:
image_data, _ = process_vision_info(messages)
return llm, prompt, stop_token_ids, image_data
model_example_map = {
"phi3_v": load_phi3v,
"internvl_chat": load_internvl,
"qwen2_vl": load_qwen2_vl,
}
def run_generate(model, question: str, image_urls: List[str]):
-llm, prompt, stop_token_ids = model_example_map[model](question,
-image_urls)
+llm, prompt, stop_token_ids, image_data = model_example_map[model](
+question, image_urls)
if image_data is None:
image_data = [fetch_image(url) for url in image_urls]
sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,
@@ -81,7 +135,7 @@ def run_generate(model, question: str, image_urls: List[str]):
{
"prompt": prompt,
"multi_modal_data": {
-"image": [fetch_image(url) for url in image_urls]
+"image": image_data
},
},
sampling_params=sampling_params)
@@ -92,7 +146,7 @@ def run_generate(model, question: str, image_urls: List[str]):
def run_chat(model: str, question: str, image_urls: List[str]):
-llm, _, stop_token_ids = model_example_map[model](question, image_urls)
+llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,

View File

@@ -28,3 +28,4 @@ importlib_metadata
mistral_common >= 1.3.4
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
einops # Required for Qwen2-VL.

View File

@@ -1,9 +1,14 @@
import pytest
import transformers
from vllm.model_executor.models import _MODELS, ModelRegistry
@pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls):
if (model_cls == "Qwen2VLForConditionalGeneration"
and transformers.__version__ < "4.45"):
pytest.skip("Waiting for next transformers release")
# Ensure all model classes can be imported successfully
ModelRegistry.resolve_model_cls([model_cls])

View File

@@ -773,7 +773,7 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
@@ -1733,8 +1733,11 @@ def _get_and_verify_max_len(
"with rope_scaling. Please raise an issue so we can "
"investigate.")
-assert "factor" in rope_scaling
-scaling_factor = rope_scaling["factor"]
+if rope_type == "mrope":
+scaling_factor = 1
+else:
+assert "factor" in rope_scaling
+scaling_factor = rope_scaling["factor"]
if rope_type == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]

View File

@@ -108,7 +108,7 @@ class ConversationMessage(TypedDict, total=False):
"""The tool calls generated by the model, such as function calls."""
-ModalityStr = Literal["image", "audio"]
+ModalityStr = Literal["image", "audio", "video"]
_T = TypeVar("_T")
@@ -158,12 +158,18 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
if model_type == "qwen2_vl":
return "<|vision_start|><|image_pad|><|vision_end|>"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio": elif modality == "audio":
if model_type == "ultravox": if model_type == "ultravox":
return "<|reserved_special_token_0|>" return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")

View File

@@ -712,6 +712,179 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return new_freqs
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
) -> None:
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
self.mrope_section = mrope_section
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 1 or positions.ndim == 2
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
cos = torch.cat([
m[i]
for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
],
dim=-1)
sin = torch.cat([
m[i]
for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
],
dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
@staticmethod
def get_input_positions(
input_tokens: List[int],
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
video_grid_thw = video_grid_thw.tolist()
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
return llm_positions.tolist(), mrope_position_delta
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> List[List[int]]:
return [
list(
range(context_len + mrope_position_delta,
seq_len + mrope_position_delta)) for _ in range(3)
]
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
@@ -752,7 +925,7 @@ def get_rope(
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type not in {"su", "longrope"}:
-scaling_factor = rope_scaling["factor"]
+scaling_factor = rope_scaling.get("factor", 1.0)
if scaling_type == "llama3":
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
@@ -816,6 +989,16 @@ def get_rope(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs)
elif scaling_type == "mrope":
return MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
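
A small sketch of how the new class behaves, built directly from the constructor shown above. The head size and section split are hypothetical (chosen so that sum(mrope_section) == rotary_dim // 2); a real model would obtain these values through get_rope from its rope_scaling config:

    import torch
    from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding

    rope = MRotaryEmbedding(
        head_size=128,
        rotary_dim=128,
        max_position_embeddings=32768,
        base=1000000,
        is_neox_style=True,
        dtype=torch.float32,
        mrope_section=[16, 24, 24],  # sums to 64 == rotary_dim // 2
    )

    # Multimodal prefill: positions carry separate temporal/height/width indices.
    num_tokens, num_heads = 8, 4
    positions = torch.zeros(3, num_tokens, dtype=torch.long)  # [3, num_tokens]
    q = torch.randn(num_tokens, num_heads * 128)
    k = torch.randn(num_tokens, num_heads * 128)
    q, k = rope.forward(positions, q, k)

    # Decode: positions are reconstructed from the cached per-sequence delta.
    print(MRotaryEmbedding.get_next_input_positions(
        mrope_position_delta=-4, context_len=10, seq_len=11))
    # -> [[6], [6], [6]]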

View File

@@ -53,6 +53,8 @@ _GENERATION_MODELS = {
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen2VLForConditionalGeneration":
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
@@ -90,6 +92,8 @@ _MULTIMODAL_MODELS = {
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),

File diff suppressed because it is too large.

View File

@@ -79,14 +79,12 @@ class MultiModalInputs(_MultiModalInputsBase):
if len(inputs_list) == 0:
return {}
-keys = inputs_list[0].keys()
item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
for inputs in inputs_list:
-if inputs.keys() != keys:
-msg = f"Inputs do not share the same keys ({keys})"
-raise ValueError(msg)
+# For models that supports multiple modalities (e.g. Qwen2-VL),
+# different modalities will return different data keys,
+# so batch() should skip the same key check.
for k, v in inputs.items():
item_lists[k].append(v)

View File

@@ -165,6 +165,9 @@ class SequenceData(msgspec.Struct,
# is called.
_new_appended_tokens: List[int] = msgspec.field(default_factory=list)
# It is used to compute mrope_position_ids.
_mrope_position_delta: Optional[int] = None
def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l"
assert self._output_token_ids.typecode == "l"
@@ -219,6 +222,14 @@
assert isinstance(self._output_token_ids, array)
return self._output_token_ids
@property
def mrope_position_delta(self) -> Optional[int]:
return self._mrope_position_delta
@mrope_position_delta.setter
def mrope_position_delta(self, new_mrope_position_delta):
self._mrope_position_delta = new_mrope_position_delta
def append_token_id(self, token_id: int, logprob: float) -> None:
self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id)

View File

@@ -0,0 +1,37 @@
from typing import cast
def get_processor(
processor_name: str,
*args,
trust_remote_code: bool = False,
**kwargs,
):
"""Gets a processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor
from transformers.processing_utils import ProcessorMixin
try:
processor = AutoProcessor.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
# Unlike AutoTokenizer, AutoProcessor does not separate such errors
if not trust_remote_code:
err_msg = (
"Failed to load the processor. If the processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
return cast(ProcessorMixin, processor)
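
A quick usage sketch for the new helper. The module path is an assumption (upstream places it under vllm/transformers_utils/processor.py, which is not shown in this excerpt), and the chat-template call mirrors the Qwen2-VL example script above:

    # Assumes get_processor is importable from the module added in this file.
    processor = get_processor("Qwen/Qwen2-VL-7B-Instruct")
    messages = [{"role": "user", "content": "Hello!"}]
    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    print(prompt)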

View File

@@ -30,6 +30,7 @@ from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -181,6 +182,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore
self.input_positions[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore
self.query_lens[0] = 0 # type: ignore
@@ -206,6 +208,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None,
input_positions: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None,
# The sequence length (may be capped to the sliding window).
seq_lens: Optional[List[int]] = None,
@@ -266,6 +269,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)):
self.input_positions[seq_id].clear()
self.mrope_input_positions = None
if seq_lens:
self.seq_lens = seq_lens
else:
@@ -327,6 +332,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else:
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
self.query_lens = query_lens or []
@@ -357,6 +363,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)]
self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs
self.query_lens = [0] * self.n_seqs
@@ -493,6 +500,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
if seq_data.mrope_position_delta is not None:
if inter_data.mrope_input_positions is None:
inter_data.mrope_input_positions = [None] * inter_data.n_seqs
inter_data.mrope_input_positions[
seq_idx] = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta,
context_len,
seq_len,
)
def _compute_for_prefix_cache_hit(
self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
@@ -636,6 +654,40 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
mm_kwargs = self.multi_modal_input_mapper(mm_data)
inter_data.multi_modal_inputs = mm_kwargs
# special processing for mrope position deltas.
if self.runner.model_is_mrope:
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
assert image_grid_thw is not None or video_grid_thw is not None, (
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw'.")
hf_config = self.runner.model_config.hf_config
inter_data.mrope_input_positions = [None] * inter_data.n_seqs
for seq_idx in range(inter_data.n_seqs):
seq_data = seq_group_metadata.seq_data[
inter_data.seq_ids[seq_idx]]
token_ids = seq_data.get_token_ids()
mrope_input_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
context_len=inter_data.context_lens[seq_idx],
)
seq_data.mrope_position_delta = mrope_position_delta
inter_data.mrope_input_positions[
seq_idx] = mrope_input_positions
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
"""Add a sequence group to the builder."""
seq_ids = seq_group_metadata.seq_data.keys()
@@ -684,10 +736,27 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# prefix caching and there is no decode request.
return self.model_input_cls()
-input_positions = []
-for inter_data in self.inter_data_list:
-for cur_input_positions in inter_data.input_positions:
-input_positions.extend(cur_input_positions)
+mrope_input_positions: Optional[List[List[int]]] = None
+if any(inter_data.mrope_input_positions is not None
+for inter_data in self.inter_data_list):
+mrope_input_positions = [[] for _ in range(3)]
for idx in range(3):
for inter_data in self.inter_data_list:
msections = inter_data.mrope_input_positions
if msections is None:
for _seq_input_positions in inter_data.input_positions:
mrope_input_positions[idx].extend(
_seq_input_positions)
else:
for _seq_mrope_input_positions in msections:
mrope_input_positions[idx].extend(
_seq_mrope_input_positions[idx])
input_positions = None
else:
input_positions = []
for inter_data in self.inter_data_list:
for cur_input_positions in inter_data.input_positions:
input_positions.extend(cur_input_positions)
seq_lens = []
max_decode_seq_len = 0
@@ -724,15 +793,24 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Tokens and positions.
if cuda_graph_pad_size:
input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
-input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
assert self.runner.device is not None
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device,
self.runner.pin_memory)
-input_positions_tensor = async_tensor_h2d(input_positions, torch.long,
-self.runner.device,
-self.runner.pin_memory)
+if mrope_input_positions is not None:
+for idx in range(3):
+mrope_input_positions[idx].extend(
itertools.repeat(0, cuda_graph_pad_size))
input_positions_tensor = async_tensor_h2d(mrope_input_positions,
torch.long,
self.runner.device,
self.runner.pin_memory)
else:
input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
input_positions_tensor = async_tensor_h2d(input_positions,
torch.long,
self.runner.device,
self.runner.pin_memory)
# Sequence and query lengths.
if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
@@ -1199,6 +1277,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.list_adapters()
@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
@torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
"""Cuda graph capture a model.
@@ -1229,7 +1316,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_batch_size = self.max_batchsize_to_capture
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
if self.model_is_mrope:
input_positions = torch.tile(input_positions, (3, 1))
# Prepare dummy previous_hidden_states only if needed by the model.
# This is used by draft models such as EAGLE.
previous_hidden_states = None
@@ -1293,7 +1381,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"input_ids":
input_tokens[:batch_size],
"positions":
-input_positions[:batch_size],
+input_positions[..., :batch_size],
"hidden_or_intermediate_states":
hidden_or_intermediate_states[
virtual_engine] # type: ignore