[Model] Support Qwen3-VL Model Series (#24727)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Huang Jie <92386084+JJJYmmm@users.noreply.github.com>
Co-authored-by: 松灵 <26085463+wulipc@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent 5801e49776
commit 0f7acdd73c
@@ -661,6 +661,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | ✅︎ |
+| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ |
 | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
 | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `HuggingFaceTB/SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
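For readers who want to try the newly supported architecture directly, here is a minimal offline-inference sketch (not part of the commit; the image path and generation settings are placeholders, and the chat markup mirrors the run_qwen3_vl() example added below):

from PIL import Image

from vllm import LLM, SamplingParams

# A minimal sketch, assuming a local image "demo.jpg" and a GPU that fits
# the 4B dense checkpoint added by this commit.
llm = LLM(model="Qwen/Qwen3-VL-4B-Instruct", max_model_len=4096)
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "Describe this image.<|im_end|>\n"
    "<|im_start|>assistant\n"
)
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": Image.open("demo.jpg")}},
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)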
@@ -1437,6 +1437,80 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
    )


# Qwen3-VL-Dense
def run_qwen3_vl(questions: list[str], modality: str) -> ModelRequestData:
    model_name = "Qwen/Qwen3-VL-4B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
        limit_mm_per_prompt={modality: 1},
    )

    if modality == "image":
        placeholder = "<|image_pad|>"
    elif modality == "video":
        placeholder = "<|video_pad|>"

    prompts = [
        (
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
            f"{question}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


# Qwen3-VL-MOE
def run_qwen3_vl_moe(questions: list[str], modality: str) -> ModelRequestData:
    model_name = "Qwen/Qwen3-VL-30B-A3B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
        limit_mm_per_prompt={modality: 1},
    )

    if modality == "image":
        placeholder = "<|image_pad|>"
    elif modality == "video":
        placeholder = "<|video_pad|>"

    prompts = [
        (
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
            f"{question}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


# R-4B
def run_r_vl(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
@@ -1645,6 +1719,8 @@ model_example_map = {
    "qwen2_vl": run_qwen2_vl,
    "qwen2_5_vl": run_qwen2_5_vl,
    "qwen2_5_omni": run_qwen2_5_omni,
+    "qwen3_vl": run_qwen3_vl,
+    "qwen3_vl_moe": run_qwen3_vl_moe,
    "rvl": run_r_vl,
    "skywork_chat": run_skyworkr1v,
    "smolvlm": run_smolvlm,
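These map keys double as the example script's model-type selector, so the new runners can be exercised from the command line. An invocation along these lines should work (script path and flag names assumed from vLLM's existing vision-language example, not verified here):

python examples/offline_inference/vision_language.py \
    --model-type qwen3_vl --modality video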
@@ -1658,6 +1734,8 @@ MODELS_NEED_VIDEO_METADATA = [
    "glm4_1v",
    "glm4_5v",
    "glm4_5v_fp8",
+    "qwen3_vl",
+    "qwen3_vl_moe",
]
@@ -31,6 +31,7 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
    """
    # Ensure video metadata is included
    if "video" in mm_data:
        # GLM4.1V doesn't support multiple videos
        video = mm_data["video"]
        num_frames = len(video)
        mm_data["video"] = (video, {
@@ -44,6 +45,34 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
    return mm_data


def qwen3_vl_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
    """
    Patch the multimodal data for Qwen3-VL model.
    """

    def create_metadata(frames: np.ndarray):
        num_frames = len(frames)
        return {
            "total_num_frames": num_frames,
            "fps": 2.0,
            "duration": num_frames / 2.0,
            "video_backend": "opencv",
            "frames_indices": list(range(num_frames)),
            "do_sample_frames": True,
        }

    # Ensure video metadata is included
    if "video" in mm_data:
        video = mm_data["video"]
        if isinstance(video, list):
            # multiple videos
            mm_data["video"] = [(vid, create_metadata(vid)) for vid in video]
        else:
            # single video
            mm_data["video"] = (video, create_metadata(video))
    return mm_data
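To make the patched structure concrete, a small sanity check (illustrative only; the dummy frames and the assertions simply mirror the hard-coded 2.0 fps metadata above):

import numpy as np

frames = np.zeros((8, 224, 224, 3), dtype=np.uint8)  # 8 dummy RGB frames
patched = qwen3_vl_patch_mm_data({"video": frames})
video, meta = patched["video"]  # frames are now paired with their metadata
assert meta["total_num_frames"] == 8
assert meta["duration"] == 4.0  # 8 frames at the assumed 2.0 fps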

def _test_processing_correctness(
    model_id_or_arch: str,
    hit_rate: float,
@@ -182,8 +211,10 @@ _IGNORE_MM_KEYS = {
}

MM_DATA_PATCHES = {
-    # GLM4.1V requires video metadata to be included in the input
+    # GLM4.1V and Qwen3-VL require video metadata to be included in the input
    "glm4v": glm4_1v_patch_mm_data,
+    "qwen3_vl": qwen3_vl_patch_mm_data,
+    "qwen3_vl_moe": qwen3_vl_patch_mm_data,
}
@@ -326,6 +357,8 @@ def _test_processing_correctness_one(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    "Qwen/Qwen2-Audio-7B-Instruct",
    "Qwen/Qwen2.5-Omni-3B",
+    "Qwen/Qwen3-VL-4B-Instruct",
+    "Qwen/Qwen3-VL-30B-A3B-Instruct",
    "YannQi/R-4B",
    "Skywork/Skywork-R1V-38B",
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
@@ -557,6 +557,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                             max_model_len=4096),
    "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"),
    "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"),  # noqa: E501
+    "Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct",  # noqa: E501
+                                                       max_model_len=4096,
+                                                       min_transformers_version="4.57"),  # noqa: E501
+    "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct",  # noqa: E501
+                                                          max_model_len=4096,
+                                                          min_transformers_version="4.57"),
    "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
                                                 trust_remote_code=True),
    "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
@@ -103,6 +103,8 @@ def get_rope(
                is_neox_style,
                dtype,
                mrope_section=rope_scaling["mrope_section"],
+                mrope_interleaved=rope_scaling.get("mrope_interleaved",
+                                                   False),
            )
        else:
            rotary_emb = RotaryEmbedding(
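The new flag is read straight from the checkpoint's rope_scaling entry and defaults to False for every pre-Qwen3-VL model. For orientation, a Qwen3-VL-style fragment might look like this (the field values here are assumptions for illustration, not copied from a released config):

# Illustrative rope_scaling entry consumed by get_rope() above:
rope_scaling = {
    "rope_type": "default",
    "mrope_section": [24, 20, 20],  # T/H/W frequency split; sums to rotary_dim // 2
    "mrope_interleaved": True,      # the flag introduced by this commit
}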
@@ -177,6 +177,18 @@ def triton_mrope(
    return q, k


def apply_interleaved_rope(x: torch.Tensor,
                           mrope_section: list[int]) -> torch.Tensor:
    """Apply interleaved MRoPE to 3D rotary embeddings.
    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
    interleaved [THWTHWTHW...TT], preserving frequency continuity.
    """
    x_t = x[0].clone()
    x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
    x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
    return x_t
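A toy run makes the reordering visible (my own sketch, not from the commit; the T/H/W rows are tagged with +0/+10/+20 so the output layout can be read off directly):

import torch

mrope_section = [2, 1, 1]  # counts of T, H, W frequencies; sum == rotary_dim // 2
freqs = torch.arange(4.0).repeat(3, 1)  # rows 0/1/2 hold the T/H/W frequencies
freqs[1] += 10  # tag the H row
freqs[2] += 20  # tag the W row

out = apply_interleaved_rope(freqs, mrope_section)
print(out)  # tensor([ 0., 11., 22.,  3.])  ->  T H W T: interleaved, with T tail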
class MRotaryEmbedding(RotaryEmbedding):
    """Rotary Embedding with Multimodal Sections."""

@@ -189,6 +201,7 @@ class MRotaryEmbedding(RotaryEmbedding):
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: Optional[list[int]] = None,
+        mrope_interleaved: Optional[bool] = False,
    ) -> None:
        # In Qwen2.5-VL, the maximum index value is related to the duration of
        # the input video. We enlarge max_position_embeddings to 4 times to get
@@ -198,6 +211,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                         base, is_neox_style, dtype)

        self.mrope_section = mrope_section
+        self.mrope_interleaved = mrope_interleaved
        if self.mrope_section:
            assert sum(self.mrope_section) == rotary_dim // 2
@@ -225,17 +239,20 @@ class MRotaryEmbedding(RotaryEmbedding):
        cos, sin = cos_sin.chunk(2, dim=-1)
        if positions.ndim == 2:
            assert self.mrope_section

-            cos = torch.cat([
-                m[i]
-                for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
-            ],
-                            dim=-1)
-            sin = torch.cat([
-                m[i]
-                for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
-            ],
-                            dim=-1)
+            if self.mrope_interleaved:
+                cos = apply_interleaved_rope(cos, self.mrope_section)
+                sin = apply_interleaved_rope(sin, self.mrope_section)
+            else:
+                cos = torch.cat([
+                    m[i] for i, m in enumerate(
+                        cos.split(self.mrope_section, dim=-1))
+                ],
+                                dim=-1)
+                sin = torch.cat([
+                    m[i] for i, m in enumerate(
+                        sin.split(self.mrope_section, dim=-1))
+                ],
+                                dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
@@ -265,6 +282,10 @@ class MRotaryEmbedding(RotaryEmbedding):
        assert positions.ndim == 1 or positions.ndim == 2
        assert key is not None

+        if self.mrope_interleaved:
+            # TODO: add triton implementation to support mrope-interleaved
+            return self.forward_native(positions, query, key)
+
        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
@@ -388,6 +409,15 @@ class MRotaryEmbedding(RotaryEmbedding):
                context_len=context_len,
                seq_len=seq_len,
            )
+        elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
+            return cls._qwen3vl_get_input_positions_tensor(
+                input_tokens=input_tokens,
+                hf_config=hf_config,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                context_len=context_len,
+                seq_len=seq_len,
+            )
        elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
            return cls._ernie_get_input_positions_tensor(
                input_tokens=input_tokens,
@@ -526,6 +556,98 @@ class MRotaryEmbedding(RotaryEmbedding):
                                len(input_tokens)).item()
        return llm_positions, mrope_position_delta
    @classmethod
    def _qwen3vl_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value."""

        video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
                          for _ in range(t)]

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = \
                t, h // spatial_merge_size, w // spatial_merge_size
            text_len = ed - st

            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w).flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]
        return llm_positions, mrope_position_delta
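Two details are worth calling out. First, the opening comprehension splits a video of t temporal patches into t separate [1, h, w] grids, because Qwen3-VL indexes each frame independently rather than running one temporal index across the clip. Second, a hand-worked miniature of the resulting layout (token ids and grid sizes invented for illustration):

# tokens:  [text, text, V, V, V, V, text]   # V = patches of one 1x2x2 grid
# t-row:   [   0,    1, 2, 2, 2, 2,    4]
# h-row:   [   0,    1, 2, 2, 3, 3,    4]
# w-row:   [   0,    1, 2, 3, 2, 3,    4]
# mrope_position_delta = max(4) + 1 - len(7) = -2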
    @classmethod
    def _ernie_get_input_positions_tensor(
        cls,
@@ -285,7 +285,7 @@ class Qwen2Model(nn.Module):
                 decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer):
        super().__init__()

-        config = vllm_config.model_config.hf_config
+        config = vllm_config.model_config.hf_config.get_text_config()
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
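The get_text_config() switch here (and the matching one in Qwen3MoeModel below) lets the same decoder be built whether the top-level config is a plain LM config or a composite VLM config with a nested text_config; transformers' PretrainedConfig.get_text_config returns self in the plain case. A quick sketch of the behaviour (model ids taken from this commit; needs a transformers version that ships the Qwen3-VL configs):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
text_cfg = cfg.get_text_config()  # the nested LM config of a composite VLM
assert text_cfg is cfg.text_config

lm_cfg = AutoConfig.from_pretrained("Qwen/Qwen3-4B")
assert lm_cfg.get_text_config() is lm_cfg  # a plain LM config returns itself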
@@ -83,7 +83,7 @@ from .vision import get_vit_attn_backend

logger = init_logger(__name__)

# For profile run
-_MAX_FRAMES_PER_VIDEO = 16
+_MAX_FRAMES_PER_VIDEO = 600

# === Vision Inputs === #
@@ -378,7 +378,7 @@ class Qwen3MoeModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

-        config = vllm_config.model_config.hf_config
+        config = vllm_config.model_config.hf_config.get_text_config()
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
vllm/model_executor/models/qwen3_vl.py (new file, 1478 lines)
File diff suppressed because it is too large

vllm/model_executor/models/qwen3_vl_moe.py (new file, 344 lines)
@@ -0,0 +1,344 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights."""
import typing
from collections.abc import Iterable
from typing import Callable, Optional, Union

import torch
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
    Qwen3VLMoeConfig)

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors

from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from .qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder,
                       Qwen3VLForConditionalGeneration,
                       Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
from .utils import is_pp_missing_parameter, maybe_prefix

logger = init_logger(__name__)
class Qwen3VLMoeProcessingInfo(Qwen3VLProcessingInfo):

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen3VLMoeConfig)


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
        # the same shape as input_embeds
        "deepstack_input_embeds": 0
    })
class Qwen3MoeLLMModel(Qwen3MoeModel):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        if not get_pp_group().is_first_rank:
            assert self.start_layer >= len(
                vllm_config.model_config.hf_config.vision_config.
                deepstack_visual_indexes), (
                    "start_layer should be greater than or equal to "
                    "len(deepstack_visual_indexes)")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer_idx, layer in enumerate(
                self.layers[self.start_layer:self.end_layer]):
            layer_idx = layer_idx + self.start_layer

            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

            if deepstack_input_embeds is not None and \
                    layer_idx in range(0, len(deepstack_input_embeds)):
                hidden_states = hidden_states + deepstack_input_embeds[
                    f"deepstack_input_embeds_{layer_idx}"]

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
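The deepstack pass above is the distinguishing trick of this model family: multi-level visual features are added into the hidden states of the first len(deepstack_visual_indexes) decoder layers, keyed by layer index. A minimal illustration of the lookup convention the loop relies on (tensor sizes invented):

import torch

from vllm.sequence import IntermediateTensors

# Hypothetical sizes: 2 deepstack levels, 5 tokens, hidden size 8.
deepstack_input_embeds = IntermediateTensors({
    f"deepstack_input_embeds_{i}": torch.zeros(5, 8)
    for i in range(2)
})
# Inside forward(), layer 0 and layer 1 each add their own level;
# layers with index >= 2 run unmodified.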

    def load_fused_expert_weights(self, name: str, params_dict: dict,
                                  loaded_weight: torch.Tensor, shard_id: str,
                                  num_experts: int):
        param = params_dict[name]
        weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
        for expert_id in range(num_experts):
            curr_expert_weight = loaded_weight[expert_id]
            success = weight_loader(param,
                                    curr_expert_weight,
                                    name,
                                    shard_id,
                                    expert_id,
                                    return_success=True)
            if not success:
                return False
        return True
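This helper exists because HF-format Qwen3-VL-MoE checkpoints store each expert projection as a single fused 3-D tensor, with experts stacked on dim 0 and the weights laid out [in, out] rather than [out, in], hence the transpose in load_weights below. A shape-level sketch of what the loop iterates over (sizes hypothetical):

import torch

num_experts, hidden, intermediate = 4, 8, 16  # hypothetical sizes
# Checkpoint layout: gate and up fused on the output dim, [in, out] order.
gate_up = torch.randn(num_experts, hidden, 2 * intermediate)
gate_up = gate_up.transpose(-1, -2)       # -> [experts, 2*intermediate, hidden]
w1, w3 = gate_up.chunk(2, dim=-2)         # gate (w1) and up (w3) halves
assert w1[0].shape == (intermediate, hidden)  # per-expert slice fed to the loader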

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale",
                           ".v_scale", "_v_scale", ".weight_scale",
                           "_weight_scale", ".input_scale", "_input_scale")
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        is_fused_expert = False
        fused_expert_params_mapping = [
            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
        ]
        num_experts = self.config.num_experts
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if ("experts.gate_up_proj" in name
                        or "experts.down_proj" in name):
                    is_fused_expert = True
                    expert_params_mapping = fused_expert_params_mapping

                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)
                    if is_fused_expert:
                        loaded_weight = loaded_weight.transpose(-1,
                                                                -2)  # no bias
                        if "experts.gate_up_proj" in name:
                            loaded_weight = loaded_weight.chunk(2, dim=-2)
                            success_w1 = self.load_fused_expert_weights(
                                name_mapped, params_dict, loaded_weight[0],
                                "w1", num_experts)
                            success_w3 = self.load_fused_expert_weights(
                                name_mapped, params_dict, loaded_weight[1],
                                "w3", num_experts)
                            success = success_w1 and success_w3
                        else:
                            # down_proj
                            success = self.load_fused_expert_weights(
                                name_mapped, params_dict, loaded_weight,
                                shard_id, num_experts)
                    else:
                        if is_pp_missing_parameter(name_mapped, self):
                            continue
                        # Skip loading extra parameters for GPTQ/modelopt models
                        if name_mapped.endswith(
                                ignore_suffixes
                        ) and name_mapped not in params_dict:
                            continue
                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(Callable[..., bool],
                                                    param.weight_loader)
                        success = weight_loader(param,
                                                loaded_weight,
                                                name_mapped,
                                                shard_id=shard_id,
                                                expert_id=expert_id,
                                                return_success=True)
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue
                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if name.endswith(
                            ignore_suffixes) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super(Qwen3MoeForCausalLM, self).__init__()
        self.config = vllm_config.model_config.hf_config.text_config
        self.quant_config = vllm_config.quant_config
        self.model = Qwen3MoeLLMModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(self.config.vocab_size,
                                      self.config.hidden_size,
                                      quant_config=self.quant_config)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
                                        info=Qwen3VLMoeProcessingInfo,
                                        dummy_inputs=Qwen3VLDummyInputsBuilder)
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super(Qwen3VLForConditionalGeneration, self).__init__()
        config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.visual = Qwen3_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )

        self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
                                                     prefix=maybe_prefix(
                                                         prefix,
                                                         "language_model"))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

        self.use_deepstack = hasattr(config.vision_config,
                                     'deepstack_visual_indexes')
        self.deepstack_num_level = len(
            config.vision_config.deepstack_visual_indexes
        ) if self.use_deepstack else 0
        # register buffer for deepstack
        self.deepstack_input_embeds = [
            torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
                        config.text_config.hidden_size)
            for _ in range(self.deepstack_num_level)
        ] if self.use_deepstack else None
@@ -259,11 +259,13 @@ _MULTIMODAL_MODELS = {
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
-    "UltravoxModel": ("ultravox", "UltravoxModel"),
+    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
+    "Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
+    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
@@ -156,7 +156,7 @@ class OpenCVVideoBackend(VideoLoader):
        # can cause incorrect timestamp calculation without num_frames=-1.
        metadata = {
            "total_num_frames": num_frames,
-            "fps": original_fps,
+            "fps": num_frames / duration,
            "duration": duration,
            "video_backend": "opencv",
            "frames_indices": list(range(num_frames)),
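This change reports the effective frame rate of the frames actually returned rather than the container's nominal fps, which keeps total_num_frames, duration, and fps mutually consistent for Qwen3-VL's timestamp-based frame handling. A numeric illustration (values invented):

num_frames, duration, original_fps = 20, 10.0, 30.0  # 20 frames kept from a 10 s clip
effective_fps = num_frames / duration                # 2.0, the value now reported
assert effective_fps * duration == num_frames        # the triple stays consistent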