mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 07:44:55 +08:00
[Refactor]: Use M-RoPE interface directly while defining model class instead of maintaining model specific M-RoPE implementation in mrope.py (#24172)
Signed-off-by: Divyansh Singhvi <divyanshsinghvi@gmail.com> Signed-off-by: dsinghvi <divyanshsinghvi@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
parent
55392bc879
commit
727144bed1
File diff suppressed because it is too large
Load Diff
@ -23,6 +23,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Erine VL model compatible with HuggingFace weights."""
|
||||
|
||||
import itertools
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
@ -33,7 +34,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from transformers import BatchFeature
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.layer import (
|
||||
@ -76,6 +77,7 @@ from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
@ -1271,7 +1273,7 @@ class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessing
|
||||
dummy_inputs=Ernie4_5_VLDummyInputsBuilder,
|
||||
)
|
||||
class Ernie4_5_VLMoeForConditionalGeneration(
|
||||
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
|
||||
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||
):
|
||||
merge_by_field_config = True
|
||||
|
||||
@ -1388,6 +1390,151 @@ class Ernie4_5_VLMoeForConditionalGeneration(
|
||||
else:
|
||||
self.visual_token_mask = None
|
||||
|
||||
@classmethod
|
||||
def get_mrope_input_positions(
|
||||
cls,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get mrope input positions and delta value for Ernie VL."""
|
||||
|
||||
image_token_id = hf_config.im_patch_id
|
||||
video_start_token_id = hf_config.video_start_token_id
|
||||
video_end_token_id = hf_config.video_end_token_id
|
||||
spatial_conv_size = hf_config.spatial_conv_size
|
||||
temporal_conv_size = hf_config.temporal_conv_size
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
if not (image_grid_thw is None and video_grid_thw is None):
|
||||
if isinstance(image_grid_thw, torch.Tensor):
|
||||
image_grid_thw = image_grid_thw.tolist()
|
||||
|
||||
input_token_type: list[str] = []
|
||||
video_check_flg = False
|
||||
for token in input_tokens:
|
||||
if token == video_start_token_id:
|
||||
video_check_flg = True
|
||||
elif token == video_end_token_id:
|
||||
video_check_flg = False
|
||||
|
||||
if (token == image_token_id) and (video_check_flg is False):
|
||||
input_token_type.append("image")
|
||||
elif (token == image_token_id) and (video_check_flg is True):
|
||||
input_token_type.append("video")
|
||||
else:
|
||||
input_token_type.append("text")
|
||||
|
||||
input_type_group: list[tuple[str, int, int]] = []
|
||||
for key, group_iter in itertools.groupby(
|
||||
enumerate(input_token_type), lambda x: x[1]
|
||||
):
|
||||
group_list = list(group_iter)
|
||||
start_index = group_list[0][0]
|
||||
end_index = group_list[-1][0] + 1
|
||||
input_type_group.append((key, start_index, end_index))
|
||||
|
||||
video_frame_num = 1
|
||||
mm_data_idx = 0
|
||||
for modality_type, start_idx, end_idx in input_type_group:
|
||||
st_idx = (
|
||||
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
)
|
||||
if modality_type == "image":
|
||||
t, h, w = (
|
||||
image_grid_thw[mm_data_idx][0],
|
||||
image_grid_thw[mm_data_idx][1],
|
||||
image_grid_thw[mm_data_idx][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_conv_size,
|
||||
w // spatial_conv_size,
|
||||
)
|
||||
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||
)
|
||||
mm_data_idx += 1
|
||||
|
||||
elif modality_type == "video":
|
||||
t, h, w = (
|
||||
video_grid_thw[mm_data_idx][0],
|
||||
video_grid_thw[mm_data_idx][1],
|
||||
video_grid_thw[mm_data_idx][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t // temporal_conv_size,
|
||||
h // spatial_conv_size,
|
||||
w // spatial_conv_size,
|
||||
)
|
||||
|
||||
for t_idx in range(llm_grid_t):
|
||||
t_index = (
|
||||
torch.tensor(t_idx)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(1, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(1, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||
)
|
||||
|
||||
mm_data_idx += 1
|
||||
video_frame_num += 1
|
||||
|
||||
else:
|
||||
text_len = end_idx - start_idx
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
video_frame_num = 1
|
||||
|
||||
else:
|
||||
text_len = len(input_tokens)
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
# https://github.com/zai-org/CogAgent
|
||||
"""Inference-only CogAgent model compatible with THUDM weights."""
|
||||
|
||||
import itertools
|
||||
from argparse import Namespace
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
@ -14,7 +15,7 @@ from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from transformers import BatchFeature, PreTrainedTokenizer, TensorType
|
||||
from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
@ -54,6 +55,7 @@ from .chatglm import ChatGLMBaseModel, ChatGLMModel
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
@ -554,7 +556,9 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
|
||||
info=GLM4VProcessingInfo,
|
||||
dummy_inputs=GLM4VDummyInputsBuilder,
|
||||
)
|
||||
class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP):
|
||||
class GLM4VForCausalLM(
|
||||
ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||
):
|
||||
merge_by_field_config = True
|
||||
|
||||
packed_modules_mapping = {
|
||||
@ -615,6 +619,150 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, Suppo
|
||||
|
||||
return self.transformer.vision(pixel_values)
|
||||
|
||||
@classmethod
|
||||
def get_mrope_input_positions(
|
||||
cls,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get mrope input positions and delta value for GLM4V."""
|
||||
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_start_token_id = hf_config.video_start_token_id
|
||||
video_end_token_id = hf_config.video_end_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
if not (image_grid_thw is None and video_grid_thw is None):
|
||||
if isinstance(image_grid_thw, torch.Tensor):
|
||||
image_grid_thw = image_grid_thw.tolist()
|
||||
|
||||
input_token_type: list[str] = []
|
||||
video_check_flg = False
|
||||
for token in input_tokens:
|
||||
if token == video_start_token_id:
|
||||
video_check_flg = True
|
||||
elif token == video_end_token_id:
|
||||
video_check_flg = False
|
||||
|
||||
if (token == image_token_id) and (video_check_flg is False):
|
||||
input_token_type.append("image")
|
||||
elif (token == image_token_id) and (video_check_flg is True):
|
||||
input_token_type.append("video")
|
||||
else:
|
||||
input_token_type.append("text")
|
||||
|
||||
input_type_group: list[tuple[str, int, int]] = []
|
||||
for key, group_iter in itertools.groupby(
|
||||
enumerate(input_token_type), lambda x: x[1]
|
||||
):
|
||||
group_list = list(group_iter)
|
||||
start_index = group_list[0][0]
|
||||
end_index = group_list[-1][0] + 1
|
||||
input_type_group.append((key, start_index, end_index))
|
||||
|
||||
video_frame_num = 1
|
||||
mm_data_idx = 0
|
||||
for modality_type, start_idx, end_idx in input_type_group:
|
||||
st_idx = (
|
||||
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
)
|
||||
if modality_type == "image":
|
||||
t, h, w = (
|
||||
image_grid_thw[mm_data_idx][0],
|
||||
image_grid_thw[mm_data_idx][1],
|
||||
image_grid_thw[mm_data_idx][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
w // spatial_merge_size,
|
||||
)
|
||||
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||
)
|
||||
mm_data_idx += 1
|
||||
|
||||
elif modality_type == "video":
|
||||
t, h, w = (
|
||||
video_frame_num,
|
||||
image_grid_thw[mm_data_idx][1],
|
||||
image_grid_thw[mm_data_idx][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
w // spatial_merge_size,
|
||||
)
|
||||
|
||||
for t_idx in range(llm_grid_t):
|
||||
t_index = (
|
||||
torch.tensor(t_idx)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(1, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(1, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||
)
|
||||
|
||||
mm_data_idx += 1
|
||||
video_frame_num += 1
|
||||
|
||||
else:
|
||||
text_len = end_idx - start_idx
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
video_frame_num = 1
|
||||
|
||||
else:
|
||||
text_len = len(input_tokens)
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.transformer
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .interfaces import SupportsLoRA, SupportsMRoPE, SupportsMultiModal, SupportsPP
|
||||
from .keye import (
|
||||
BaseKeyeModule,
|
||||
BaseMultiModalProcessor,
|
||||
@ -493,7 +493,7 @@ class KeyeVL1_5DummyInputsBuilder(
|
||||
dummy_inputs=KeyeVL1_5DummyInputsBuilder,
|
||||
)
|
||||
class KeyeVL1_5ForConditionalGeneration(
|
||||
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
|
||||
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||
):
|
||||
def _build_projector(
|
||||
self,
|
||||
@ -589,3 +589,143 @@ class KeyeVL1_5ForConditionalGeneration(
|
||||
end = patch_cu_seqlens[idx + 1]
|
||||
new_video_embeds.append(video_embeds[start:end])
|
||||
return tuple(new_video_embeds)
|
||||
|
||||
@classmethod
|
||||
def get_mrope_input_positions(
|
||||
cls,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
|
||||
video_grid_thw = video_grid_thw[0]
|
||||
"""Get mrope input positions and delta value (Keye series)."""
|
||||
|
||||
def split_thw(grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]:
|
||||
"""
|
||||
Split grid_thw along the t dimension.
|
||||
|
||||
Args:
|
||||
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
|
||||
|
||||
Returns:
|
||||
List of [1, h, w] rows, repeated t times for each original row.
|
||||
"""
|
||||
|
||||
if isinstance(grid_thw, list):
|
||||
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
|
||||
|
||||
if grid_thw.numel() == 0:
|
||||
return []
|
||||
|
||||
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
|
||||
ones = torch.ones_like(hw[:, :1]) # [N,1]
|
||||
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
|
||||
return out.tolist()
|
||||
|
||||
video_grid_thw = split_thw(video_grid_thw)
|
||||
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_token_id = hf_config.video_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
|
||||
image_nums = len(image_grid_thw)
|
||||
frame_nums = len(video_grid_thw)
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
st = 0
|
||||
remain_images, remain_frames = image_nums, frame_nums
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
for _ in range(image_nums + frame_nums):
|
||||
if remain_images > 0:
|
||||
try:
|
||||
ed_image = input_tokens.index(image_token_id, st)
|
||||
except ValueError:
|
||||
ed_image = len(input_tokens) + 1
|
||||
else:
|
||||
ed_image = len(input_tokens) + 1
|
||||
if remain_frames > 0:
|
||||
try:
|
||||
ed_video = input_tokens.index(video_token_id, st)
|
||||
except ValueError:
|
||||
ed_video = len(input_tokens) + 1
|
||||
else:
|
||||
ed_video = len(input_tokens) + 1
|
||||
|
||||
if ed_image < ed_video:
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
ed = ed_image
|
||||
else:
|
||||
t, h, w = (
|
||||
video_grid_thw[video_index][0],
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
video_index += 1
|
||||
remain_frames -= 1
|
||||
ed = ed_video
|
||||
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
w // spatial_merge_size,
|
||||
)
|
||||
text_len = ed - st
|
||||
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
t_index = (
|
||||
(
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
)
|
||||
.long()
|
||||
.flatten()
|
||||
)
|
||||
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
||||
)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
@ -29,6 +29,7 @@ from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
|
||||
Qwen2_5OmniConfig,
|
||||
@ -45,7 +46,6 @@ from transformers.models.whisper import WhisperFeatureExtractor
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.qwen2_5_vl import (
|
||||
Qwen2_5_VisionTransformer,
|
||||
@ -93,6 +93,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
@ -101,7 +102,9 @@ from .utils import (
|
||||
WeightsMapper,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
split_list_into_ranges,
|
||||
)
|
||||
from .vision import get_llm_pos_ids_for_vision
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
@ -412,6 +415,59 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
|
||||
|
||||
return prompt_ids, mm_placeholders
|
||||
|
||||
@classmethod
|
||||
def omni_get_updates_use_audio_in_video(
|
||||
cls,
|
||||
thinker_config: PretrainedConfig,
|
||||
audio_len: int,
|
||||
video_grid_thw: Union[list[int], torch.Tensor],
|
||||
video_second_per_grid_t: float,
|
||||
) -> list[int]:
|
||||
"""Get video prompt updates when `use_audio_in_video` is True.
|
||||
|
||||
In this case, audio and vision update ids will be split into
|
||||
chunks and interleaved (details in `_omni_get_input_positions_tensor`).
|
||||
|
||||
<|video_bos|><|VIDEO|><|video_eos|> =>
|
||||
<|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
|
||||
"""
|
||||
|
||||
audio_token_id = thinker_config.audio_token_index
|
||||
video_token_id = thinker_config.video_token_index
|
||||
audio_start_token_id = thinker_config.audio_start_token_id
|
||||
audio_end_token_id = thinker_config.audio_end_token_id
|
||||
seconds_per_chunk = thinker_config.seconds_per_chunk
|
||||
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
|
||||
tokens_per_second = getattr(
|
||||
thinker_config.vision_config, "tokens_per_second", 25
|
||||
)
|
||||
|
||||
grid_t = video_grid_thw[0]
|
||||
grid_h = video_grid_thw[1]
|
||||
grid_w = video_grid_thw[2]
|
||||
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
|
||||
t_index = (
|
||||
torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second
|
||||
).long()
|
||||
t_index_split_chunk = split_list_into_ranges(t_index, t_ntoken_per_chunk)
|
||||
|
||||
updates = [audio_start_token_id]
|
||||
added_audio_len = 0
|
||||
for t_chunk in t_index_split_chunk:
|
||||
vision_ntoken_per_chunk = (
|
||||
len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
|
||||
)
|
||||
updates.extend([video_token_id] * vision_ntoken_per_chunk)
|
||||
|
||||
audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len)
|
||||
updates.extend(audio_chunk_size * [audio_token_id])
|
||||
added_audio_len += audio_chunk_size
|
||||
if added_audio_len < audio_len:
|
||||
updates.extend((audio_len - added_audio_len) * [audio_token_id])
|
||||
updates.extend([audio_end_token_id])
|
||||
|
||||
return updates
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
@ -491,7 +547,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
|
||||
else:
|
||||
video_second_per_grid_t = 1.0
|
||||
|
||||
return MRotaryEmbedding.omni_get_updates_use_audio_in_video(
|
||||
return self.omni_get_updates_use_audio_in_video(
|
||||
thinker_config=thinker_config,
|
||||
audio_len=audio_num_features,
|
||||
video_grid_thw=video_grid_thw,
|
||||
@ -808,6 +864,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
Qwen2_5OmniConditionalGenerationMixin,
|
||||
):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
@ -929,6 +986,216 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
@classmethod
|
||||
def get_mrope_input_positions(
|
||||
cls,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
|
||||
|
||||
Differences from MRotaryEmbedding:
|
||||
1. Add audio support (and related `audio_feature_lengths`).
|
||||
2. Add `use_audio_in_video` option to read audio from video inputs.
|
||||
In this case, audio and vision position ids will be split into
|
||||
chunks and interleaved.
|
||||
|
||||
Example:
|
||||
|
||||
(V_i are vision position ids, A_i are audio position ids)
|
||||
|
||||
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|
||||
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
|
||||
"""
|
||||
|
||||
# TODO(fyabc): refactor and share more code with
|
||||
# _vl_get_input_positions_tensor.
|
||||
|
||||
thinker_config = hf_config.thinker_config
|
||||
audio_token_id = thinker_config.audio_token_index
|
||||
image_token_id = thinker_config.image_token_index
|
||||
video_token_id = thinker_config.video_token_index
|
||||
audio_start_token_id = thinker_config.audio_start_token_id
|
||||
audio_end_token_id = thinker_config.audio_end_token_id
|
||||
vision_start_token_id = thinker_config.vision_start_token_id
|
||||
vision_end_token_id = thinker_config.vision_end_token_id
|
||||
seconds_per_chunk = thinker_config.seconds_per_chunk
|
||||
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
|
||||
tokens_per_second = getattr(
|
||||
thinker_config.vision_config, "tokens_per_second", 25
|
||||
)
|
||||
|
||||
if isinstance(image_grid_thw, list):
|
||||
image_grid_thw = torch.tensor(image_grid_thw)
|
||||
if isinstance(video_grid_thw, list):
|
||||
video_grid_thw = torch.tensor(video_grid_thw)
|
||||
|
||||
src_item = input_tokens
|
||||
audio_seqlens = audio_feature_lengths
|
||||
if not second_per_grid_ts:
|
||||
second_per_grid_ts = [1] * video_grid_thw.shape[0]
|
||||
audio_idx = 0
|
||||
video_idx = 0
|
||||
image_idx = 0
|
||||
new_src_item: list[int] = []
|
||||
llm_pos_ids_list: list[torch.Tensor] = []
|
||||
|
||||
idx = 0
|
||||
while idx < len(src_item):
|
||||
new_src_item_len = len(new_src_item)
|
||||
start_idx = (
|
||||
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
)
|
||||
if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]:
|
||||
if use_audio_in_video and idx > 0:
|
||||
if (
|
||||
src_item[idx] == vision_end_token_id
|
||||
and src_item[idx - 1] == audio_end_token_id
|
||||
):
|
||||
# processing the <|audio_eos|> before <|vision_eos|>
|
||||
start_idx -= 1
|
||||
elif (
|
||||
src_item[idx] == audio_start_token_id
|
||||
and src_item[idx - 1] == vision_start_token_id
|
||||
):
|
||||
# processing the <|audio_bos|> after <|vision_eos|>
|
||||
start_idx -= 1
|
||||
new_src_item.append(src_item[idx])
|
||||
llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
elif src_item[idx] == audio_token_id:
|
||||
assert audio_seqlens is not None
|
||||
audio_seqlen = audio_seqlens[audio_idx]
|
||||
place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
|
||||
new_src_item.extend([audio_token_id] * place_num)
|
||||
llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
audio_idx += 1
|
||||
elif src_item[idx] == image_token_id:
|
||||
grid_t = image_grid_thw[image_idx][0]
|
||||
grid_hs = image_grid_thw[:, 1]
|
||||
grid_ws = image_grid_thw[:, 2]
|
||||
t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
|
||||
llm_pos_ids = get_llm_pos_ids_for_vision(
|
||||
start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
vision_seqlen = image_grid_thw[image_idx].prod() // (
|
||||
spatial_merge_size**2
|
||||
)
|
||||
new_src_item.extend([image_token_id] * vision_seqlen)
|
||||
image_idx += 1
|
||||
elif src_item[idx] == video_token_id and not use_audio_in_video:
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_index = (
|
||||
torch.arange(grid_t)
|
||||
* second_per_grid_ts[video_idx]
|
||||
* tokens_per_second
|
||||
).long()
|
||||
llm_pos_ids = get_llm_pos_ids_for_vision(
|
||||
start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
vision_seqlen = video_grid_thw[video_idx].prod() // (
|
||||
spatial_merge_size**2
|
||||
)
|
||||
new_src_item.extend([video_token_id] * vision_seqlen)
|
||||
video_idx += 1
|
||||
else:
|
||||
# read audio from video
|
||||
assert audio_seqlens is not None
|
||||
audio_seqlen = audio_seqlens[audio_idx]
|
||||
vision_seqlen = video_grid_thw[video_idx].prod() // (
|
||||
spatial_merge_size**2
|
||||
)
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_h = video_grid_thw[video_idx][1]
|
||||
grid_w = video_grid_thw[video_idx][2]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
|
||||
t_index = (
|
||||
torch.arange(grid_t)
|
||||
* second_per_grid_ts[video_idx]
|
||||
* tokens_per_second
|
||||
).long()
|
||||
t_index_split_chunk = split_list_into_ranges(
|
||||
t_index, t_ntoken_per_chunk
|
||||
)
|
||||
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
|
||||
pure_audio_len = place_num - 2
|
||||
added_audio_len = 0
|
||||
audio_llm_pos_ids_list: list[torch.Tensor] = []
|
||||
for t_chunk in t_index_split_chunk:
|
||||
vision_ntoken_per_chunk = (
|
||||
len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
|
||||
)
|
||||
new_src_item.extend([video_token_id] * vision_ntoken_per_chunk)
|
||||
vision_llm_pos_ids_list = get_llm_pos_ids_for_vision(
|
||||
start_idx,
|
||||
video_idx,
|
||||
spatial_merge_size,
|
||||
t_chunk,
|
||||
grid_hs,
|
||||
grid_ws,
|
||||
).split(1, dim=1)
|
||||
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
|
||||
new_src_item.extend(
|
||||
min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)
|
||||
* [audio_token_id]
|
||||
)
|
||||
audio_start_idx = (
|
||||
start_idx
|
||||
if len(audio_llm_pos_ids_list) == 0
|
||||
else audio_llm_pos_ids_list[-1][0].item() + 1
|
||||
)
|
||||
if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0:
|
||||
audio_llm_pos_ids_list = (
|
||||
torch.arange(
|
||||
min(
|
||||
t_ntoken_per_chunk, pure_audio_len - added_audio_len
|
||||
)
|
||||
).expand(3, -1)
|
||||
+ audio_start_idx
|
||||
).split(1, dim=1)
|
||||
else:
|
||||
audio_llm_pos_ids_list = []
|
||||
added_audio_len += min(
|
||||
t_ntoken_per_chunk, pure_audio_len - added_audio_len
|
||||
)
|
||||
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
|
||||
if added_audio_len < pure_audio_len:
|
||||
new_src_item.extend(
|
||||
(pure_audio_len - added_audio_len) * [audio_token_id]
|
||||
)
|
||||
audio_llm_pos_ids_list = (
|
||||
torch.arange(pure_audio_len - added_audio_len).expand(3, -1)
|
||||
+ llm_pos_ids_list[-1].max()
|
||||
+ 1
|
||||
).split(1, dim=1)
|
||||
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
|
||||
audio_idx += 1
|
||||
video_idx += 1
|
||||
# move to the next token
|
||||
idx += len(new_src_item) - new_src_item_len
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1)
|
||||
mrope_position_delta = (
|
||||
torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item)
|
||||
)
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not mm_input_by_modality:
|
||||
|
||||
@ -34,7 +34,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import BatchFeature
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig,
|
||||
@ -79,6 +79,7 @@ from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsMultiModalPruning,
|
||||
SupportsPP,
|
||||
@ -1053,6 +1054,7 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
SupportsQuant,
|
||||
SupportsEagle3,
|
||||
SupportsMultiModalPruning,
|
||||
SupportsMRoPE,
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
@ -1073,6 +1075,132 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
|
||||
supports_encoder_tp_data = True
|
||||
|
||||
@classmethod
|
||||
def get_mrope_input_positions(
|
||||
cls,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
second_per_grid_ts: list[float],
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get mrope input positions and delta value."""
|
||||
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_token_id = hf_config.video_token_id
|
||||
vision_start_token_id = hf_config.vision_start_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)
|
||||
|
||||
input_tokens_tensor = torch.tensor(input_tokens)
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_tokens_tensor == vision_start_token_id
|
||||
).squeeze(1)
|
||||
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (vision_tokens == video_token_id).sum()
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
st = 0
|
||||
remain_images, remain_videos = image_nums, video_nums
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
for _ in range(image_nums + video_nums):
|
||||
video_second_per_grid_t = 0.0
|
||||
if remain_images > 0:
|
||||
try:
|
||||
ed_image = input_tokens.index(image_token_id, st)
|
||||
except ValueError:
|
||||
ed_image = len(input_tokens) + 1
|
||||
else:
|
||||
ed_image = len(input_tokens) + 1
|
||||
if remain_videos > 0:
|
||||
try:
|
||||
ed_video = input_tokens.index(video_token_id, st)
|
||||
except ValueError:
|
||||
ed_video = len(input_tokens) + 1
|
||||
else:
|
||||
ed_video = len(input_tokens) + 1
|
||||
if ed_image < ed_video:
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
ed = ed_image
|
||||
else:
|
||||
t, h, w = (
|
||||
video_grid_thw[video_index][0],
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
video_second_per_grid_t = 1.0
|
||||
if second_per_grid_ts:
|
||||
video_second_per_grid_t = second_per_grid_ts[video_index]
|
||||
video_index += 1
|
||||
remain_videos -= 1
|
||||
ed = ed_video
|
||||
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
w // spatial_merge_size,
|
||||
)
|
||||
text_len = ed - st
|
||||
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
t_index = (
|
||||
(
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
* video_second_per_grid_t
|
||||
* tokens_per_second
|
||||
)
|
||||
.long()
|
||||
.flatten()
|
||||
)
|
||||
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
||||
)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
|
||||
@ -33,7 +33,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import BatchFeature
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
|
||||
smart_resize as image_smart_resize,
|
||||
@ -84,6 +84,7 @@ from vllm.utils import is_list_of
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
@ -1174,7 +1175,7 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder,
|
||||
)
|
||||
class Qwen3VLForConditionalGeneration(
|
||||
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
|
||||
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
@ -1480,6 +1481,116 @@ class Qwen3VLForConditionalGeneration(
|
||||
)
|
||||
return mm_input_by_modality
|
||||
|
||||
@classmethod
|
||||
def get_mrope_input_positions(
|
||||
cls,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get mrope input positions and delta value."""
|
||||
|
||||
video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)]
|
||||
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_token_id = hf_config.video_token_id
|
||||
vision_start_token_id = hf_config.vision_start_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
|
||||
input_tokens_tensor = torch.tensor(input_tokens)
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_tokens_tensor == vision_start_token_id
|
||||
).squeeze(1)
|
||||
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (vision_tokens == video_token_id).sum()
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
st = 0
|
||||
remain_images, remain_videos = image_nums, video_nums
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
for _ in range(image_nums + video_nums):
|
||||
if image_token_id in input_tokens and remain_images > 0:
|
||||
ed_image = input_tokens.index(image_token_id, st)
|
||||
else:
|
||||
ed_image = len(input_tokens) + 1
|
||||
if video_token_id in input_tokens and remain_videos > 0:
|
||||
ed_video = input_tokens.index(video_token_id, st)
|
||||
else:
|
||||
ed_video = len(input_tokens) + 1
|
||||
if ed_image < ed_video:
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
ed = ed_image
|
||||
else:
|
||||
t, h, w = (
|
||||
video_grid_thw[video_index][0],
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
video_index += 1
|
||||
remain_videos -= 1
|
||||
ed = ed_video
|
||||
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
w // spatial_merge_size,
|
||||
)
|
||||
text_len = ed - st
|
||||
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
||||
)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
|
||||
@ -410,6 +410,14 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
|
||||
return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
|
||||
|
||||
|
||||
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
|
||||
ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)]
|
||||
for num in lst:
|
||||
index = num // interval
|
||||
ranges[index].append(num)
|
||||
return ranges
|
||||
|
||||
|
||||
def _merge_multimodal_embeddings(
|
||||
inputs_embeds: torch.Tensor,
|
||||
multimodal_embeddings: NestedTensors,
|
||||
|
||||
@ -875,30 +875,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if mm_input.get("use_audio_in_video") is True:
|
||||
use_audio_in_video = True
|
||||
|
||||
if supports_mrope(self.get_model()):
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = (
|
||||
self.model.get_mrope_input_positions(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
)
|
||||
else:
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = (
|
||||
MRotaryEmbedding.get_input_positions_tensor(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
assert supports_mrope(self.get_model()), "M-RoPE support is not implemented."
|
||||
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = (
|
||||
self.model.get_mrope_input_positions(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
)
|
||||
|
||||
def _extract_mm_kwargs(
|
||||
self,
|
||||
@ -2900,7 +2889,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
logger.info("Loading drafter model...")
|
||||
self.drafter.load_model(self.model)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
if not supports_eagle3(self.model):
|
||||
if not supports_eagle3(self.get_model()):
|
||||
raise RuntimeError(
|
||||
"Model does not support EAGLE3 interface but "
|
||||
"aux_hidden_state_outputs was requested"
|
||||
@ -2928,7 +2917,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
prepare_communication_buffer_for_model(self.model)
|
||||
|
||||
self.is_multimodal_pruning_enabled = (
|
||||
supports_multimodal_pruning(self.model)
|
||||
supports_multimodal_pruning(self.get_model())
|
||||
and self.model_config.multimodal_config.is_multimodal_pruning_enabled()
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user