mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:06:06 +08:00
[MM] Move Qwen3Omni MRoPE impl to model file (#26608)
Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
27ed39a347
commit
ddaff2938e
@ -426,7 +426,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
from vllm.transformers_utils.config import thinker_uses_mrope
|
||||
|
||||
if thinker_uses_mrope(hf_config):
|
||||
if thinker_uses_mrope(hf_config) and hf_config.model_type == "qwen2_5_omni":
|
||||
return cls._omni_get_input_positions_tensor(
|
||||
input_tokens=input_tokens,
|
||||
hf_config=hf_config,
|
||||
@ -1119,339 +1119,6 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
@classmethod
|
||||
def _omni3_get_input_positions_tensor(
|
||||
cls,
|
||||
config,
|
||||
input_ids: torch.Tensor,
|
||||
image_grid_thw: torch.Tensor,
|
||||
video_grid_thw: torch.Tensor,
|
||||
use_audio_in_video: bool = False,
|
||||
audio_seqlens: Optional[torch.Tensor] = None,
|
||||
second_per_grids: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = (
|
||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
)
|
||||
return output_lengths
|
||||
|
||||
if input_ids is None or input_ids.ndim != 1:
|
||||
raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
|
||||
|
||||
seq_len = input_ids.shape[0]
|
||||
device = input_ids.device
|
||||
dtype = input_ids.dtype
|
||||
|
||||
if image_grid_thw is not None:
|
||||
image_grid_thw = image_grid_thw.to(device=device, dtype=torch.long)
|
||||
if video_grid_thw is not None:
|
||||
video_grid_thw = video_grid_thw.to(device=device, dtype=torch.long)
|
||||
|
||||
if second_per_grids is None:
|
||||
if video_grid_thw is not None and video_grid_thw.numel() > 0:
|
||||
second_per_grids = torch.ones(
|
||||
video_grid_thw.shape[0], dtype=torch.float32, device=device
|
||||
)
|
||||
else:
|
||||
second_per_grids = torch.tensor([], dtype=torch.float32, device=device)
|
||||
else:
|
||||
second_per_grids = second_per_grids.to(device=device, dtype=torch.float32)
|
||||
|
||||
if audio_seqlens is not None:
|
||||
audio_seqlens = audio_seqlens.to(device=device, dtype=torch.long)
|
||||
|
||||
spatial_merge_size = config.vision_config.spatial_merge_size
|
||||
image_token_id = config.image_token_id
|
||||
video_token_id = config.video_token_id
|
||||
audio_token_id = config.audio_token_id
|
||||
vision_start_token_id = config.vision_start_token_id
|
||||
audio_start_token_id = config.audio_start_token_id
|
||||
position_id_per_seconds = config.position_id_per_seconds
|
||||
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_ids == vision_start_token_id
|
||||
).squeeze(1)
|
||||
if vision_start_indices.numel() > 0:
|
||||
vision_tokens = input_ids[vision_start_indices + 1]
|
||||
else:
|
||||
vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
|
||||
audio_nums = torch.sum(input_ids == audio_start_token_id)
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (
|
||||
(vision_tokens == audio_start_token_id).sum()
|
||||
if use_audio_in_video
|
||||
else (vision_tokens == video_token_id).sum()
|
||||
)
|
||||
|
||||
input_tokens = input_ids.tolist()
|
||||
llm_pos_ids_list: list[torch.Tensor] = []
|
||||
st = 0
|
||||
image_idx = 0
|
||||
video_idx = 0
|
||||
audio_idx = 0
|
||||
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501
|
||||
multimodal_nums = (
|
||||
image_nums + audio_nums
|
||||
if use_audio_in_video
|
||||
else image_nums + video_nums + audio_nums
|
||||
) # noqa: E501
|
||||
|
||||
for _ in range(multimodal_nums):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
if (image_token_id in input_tokens or video_token_id in input_tokens) and (
|
||||
remain_videos > 0 or remain_images > 0
|
||||
):
|
||||
ed_vision_start = input_tokens.index(vision_start_token_id, st)
|
||||
else:
|
||||
ed_vision_start = len(input_tokens) + 1
|
||||
if audio_token_id in input_tokens and remain_audios > 0:
|
||||
ed_audio_start = input_tokens.index(audio_start_token_id, st)
|
||||
else:
|
||||
ed_audio_start = len(input_tokens) + 1
|
||||
min_ed = min(ed_vision_start, ed_audio_start)
|
||||
|
||||
if min_ed == ed_audio_start:
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(bos_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
|
||||
llm_pos_ids = (
|
||||
torch.arange(audio_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(eos_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st += text_len + bos_len + audio_len + eos_len
|
||||
audio_idx += 1
|
||||
remain_audios -= 1
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and input_ids[ed_vision_start + 1] == image_token_id
|
||||
):
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(bos_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
grid_t = image_grid_thw[image_idx][0]
|
||||
grid_hs = image_grid_thw[:, 1]
|
||||
grid_ws = image_grid_thw[:, 2]
|
||||
t_index = torch.arange(grid_t, device=device) * position_id_per_seconds
|
||||
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
|
||||
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(eos_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st += text_len + bos_len + image_len + eos_len
|
||||
image_idx += 1
|
||||
remain_images -= 1
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and input_ids[ed_vision_start + 1] == video_token_id
|
||||
and not use_audio_in_video
|
||||
):
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(bos_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_index = (
|
||||
torch.arange(grid_t, device=device)
|
||||
* float(second_per_grids[video_idx].item())
|
||||
* position_id_per_seconds
|
||||
)
|
||||
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
|
||||
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(eos_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st += text_len + bos_len + video_len + eos_len
|
||||
video_idx += 1
|
||||
remain_videos -= 1
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and ed_vision_start + 1 == ed_audio_start
|
||||
and use_audio_in_video
|
||||
):
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
bos_block = (
|
||||
torch.arange(bos_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(bos_block)
|
||||
llm_pos_ids_list.append(bos_block)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
|
||||
audio_llm_pos_ids = (
|
||||
torch.arange(audio_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_index = (
|
||||
torch.arange(grid_t, device=device)
|
||||
* float(second_per_grids[video_idx].item())
|
||||
* position_id_per_seconds
|
||||
)
|
||||
video_llm_pos_ids = cls._get_llm_pos_ids_for_vision(
|
||||
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
video_data_index, audio_data_index = 0, 0
|
||||
while (
|
||||
video_data_index < video_llm_pos_ids.shape[-1]
|
||||
and audio_data_index < audio_llm_pos_ids.shape[-1]
|
||||
):
|
||||
if (
|
||||
video_llm_pos_ids[0][video_data_index]
|
||||
<= audio_llm_pos_ids[0][audio_data_index]
|
||||
):
|
||||
llm_pos_ids_list.append(
|
||||
video_llm_pos_ids[
|
||||
:, video_data_index : video_data_index + 1
|
||||
]
|
||||
)
|
||||
video_data_index += 1
|
||||
else:
|
||||
llm_pos_ids_list.append(
|
||||
audio_llm_pos_ids[
|
||||
:, audio_data_index : audio_data_index + 1
|
||||
]
|
||||
)
|
||||
audio_data_index += 1
|
||||
if video_data_index < video_llm_pos_ids.shape[-1]:
|
||||
llm_pos_ids_list.append(
|
||||
video_llm_pos_ids[
|
||||
:, video_data_index : video_llm_pos_ids.shape[-1]
|
||||
]
|
||||
)
|
||||
if audio_data_index < audio_llm_pos_ids.shape[-1]:
|
||||
llm_pos_ids_list.append(
|
||||
audio_llm_pos_ids[
|
||||
:, audio_data_index : audio_llm_pos_ids.shape[-1]
|
||||
]
|
||||
)
|
||||
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
eos_block = (
|
||||
torch.arange(eos_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(eos_block)
|
||||
llm_pos_ids_list.append(eos_block)
|
||||
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501
|
||||
audio_idx += 1
|
||||
video_idx += 1
|
||||
remain_videos -= 1
|
||||
remain_audios -= 1
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, device=device, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
if llm_positions.shape[1] != seq_len:
|
||||
raise RuntimeError("Position ids length mismatch with input ids length")
|
||||
|
||||
position_ids = llm_positions.to(device=device, dtype=dtype)
|
||||
mrope_position_delta = llm_positions.max() + 1 - seq_len
|
||||
return position_ids, mrope_position_delta
|
||||
|
||||
@classmethod
|
||||
def _omni_get_input_positions_tensor(
|
||||
cls,
|
||||
@ -1483,8 +1150,6 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
# TODO(fyabc): refactor and share more code with
|
||||
# _vl_get_input_positions_tensor.
|
||||
|
||||
model_type = hf_config.model_type
|
||||
thinker_config = hf_config.thinker_config
|
||||
|
||||
if isinstance(image_grid_thw, list):
|
||||
@ -1492,30 +1157,6 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
if isinstance(video_grid_thw, list):
|
||||
video_grid_thw = torch.tensor(video_grid_thw)
|
||||
|
||||
if "qwen3_omni" in model_type:
|
||||
input_tensor = torch.tensor(input_tokens)
|
||||
audio_lengths_tensor = audio_feature_lengths
|
||||
if audio_lengths_tensor is not None and not isinstance(
|
||||
audio_lengths_tensor, torch.Tensor
|
||||
):
|
||||
audio_lengths_tensor = torch.as_tensor(
|
||||
audio_lengths_tensor, dtype=torch.long
|
||||
)
|
||||
second_per_grids_tensor = (
|
||||
torch.tensor(second_per_grid_ts) if second_per_grid_ts else None
|
||||
)
|
||||
|
||||
llm_positions, mrope_position_delta = cls._omni3_get_input_positions_tensor( # noqa: E501
|
||||
thinker_config,
|
||||
input_tensor,
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
use_audio_in_video,
|
||||
audio_lengths_tensor,
|
||||
second_per_grids_tensor,
|
||||
)
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
audio_token_id = thinker_config.audio_token_index
|
||||
image_token_id = thinker_config.image_token_index
|
||||
video_token_id = thinker_config.video_token_index
|
||||
|
||||
@ -72,7 +72,12 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -96,7 +101,7 @@ from .utils import (
|
||||
_merge_multimodal_embeddings,
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import get_vit_attn_backend
|
||||
from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
@ -106,6 +111,15 @@ except (ImportError, ModuleNotFoundError):
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = (
|
||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
)
|
||||
return feat_lengths, output_lengths
|
||||
|
||||
|
||||
class Qwen3_VisionPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -679,16 +693,6 @@ Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder
|
||||
class Qwen3OmniMoeThinkerMultiModalProcessor(
|
||||
Qwen2_5OmniThinkerMultiModalProcessor,
|
||||
):
|
||||
def _get_feat_extract_output_lengths(
|
||||
self, input_lengths: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = (
|
||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
)
|
||||
return feat_lengths, output_lengths
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -882,13 +886,13 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
|
||||
if audio_feature_lengths is None and feature_attention_mask is None:
|
||||
audio_output_lengths = []
|
||||
elif audio_feature_lengths is not None:
|
||||
_, audio_output_lens = self._get_feat_extract_output_lengths(
|
||||
_, audio_output_lens = _get_feat_extract_output_lengths(
|
||||
audio_feature_lengths
|
||||
)
|
||||
audio_output_lengths = audio_output_lens.tolist()
|
||||
elif feature_attention_mask is not None:
|
||||
assert isinstance(feature_attention_mask, torch.Tensor)
|
||||
_, audio_output_lens = self._get_feat_extract_output_lengths(
|
||||
_, audio_output_lens = _get_feat_extract_output_lengths(
|
||||
feature_attention_mask.sum(-1)
|
||||
)
|
||||
audio_output_lengths = audio_output_lens.tolist()
|
||||
@ -1044,16 +1048,6 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
|
||||
else:
|
||||
return torch.concat(mm_input, dim=dim)
|
||||
|
||||
def _get_feat_extract_output_lengths(
|
||||
self, input_lengths: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = (
|
||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
)
|
||||
return output_lengths, output_lengths
|
||||
|
||||
def _process_audio_input(
|
||||
self,
|
||||
audio_input: Qwen2AudioFeatureInputs,
|
||||
@ -1072,8 +1066,8 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
|
||||
if audio_feature_lengths.ndim == 2:
|
||||
audio_feature_lengths = audio_feature_lengths.reshape(-1)
|
||||
|
||||
audio_feat_lengths, audio_output_lengths = (
|
||||
self._get_feat_extract_output_lengths(audio_feature_lengths)
|
||||
audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
|
||||
audio_feature_lengths
|
||||
)
|
||||
|
||||
audio_outputs = self.audio_tower(
|
||||
@ -1094,6 +1088,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
nn.Module,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsMRoPE,
|
||||
Qwen3OmniMoeConditionalGenerationMixin,
|
||||
):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
@ -1407,3 +1402,311 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
return loaded_weights
|
||||
|
||||
@classmethod
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
|
||||
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
config = hf_config.thinker_config
|
||||
if isinstance(image_grid_thw, list):
|
||||
image_grid_thw = torch.tensor(image_grid_thw)
|
||||
if isinstance(video_grid_thw, list):
|
||||
video_grid_thw = torch.tensor(video_grid_thw)
|
||||
input_ids = torch.tensor(input_tokens)
|
||||
if input_ids is None or input_ids.ndim != 1:
|
||||
raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
|
||||
|
||||
seq_len = input_ids.shape[0]
|
||||
if audio_feature_lengths is not None and not isinstance(
|
||||
audio_feature_lengths, torch.Tensor
|
||||
):
|
||||
audio_feature_lengths = torch.as_tensor(
|
||||
audio_feature_lengths, dtype=torch.long
|
||||
)
|
||||
if second_per_grid_ts is None:
|
||||
if video_grid_thw is not None and video_grid_thw.numel() > 0:
|
||||
second_per_grids = torch.ones(
|
||||
video_grid_thw.shape[0], dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
second_per_grids = torch.tensor([], dtype=torch.float32)
|
||||
else:
|
||||
second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
|
||||
|
||||
spatial_merge_size = config.vision_config.spatial_merge_size
|
||||
image_token_id = config.image_token_id
|
||||
video_token_id = config.video_token_id
|
||||
audio_token_id = config.audio_token_id
|
||||
vision_start_token_id = config.vision_start_token_id
|
||||
audio_start_token_id = config.audio_start_token_id
|
||||
position_id_per_seconds = config.position_id_per_seconds
|
||||
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_ids == vision_start_token_id
|
||||
).squeeze(1)
|
||||
if vision_start_indices.numel() > 0:
|
||||
vision_tokens = input_ids[vision_start_indices + 1]
|
||||
else:
|
||||
vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
|
||||
audio_nums = torch.sum(input_ids == audio_start_token_id)
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (
|
||||
(vision_tokens == audio_start_token_id).sum()
|
||||
if use_audio_in_video
|
||||
else (vision_tokens == video_token_id).sum()
|
||||
)
|
||||
|
||||
llm_pos_ids_list: list[torch.Tensor] = []
|
||||
st = 0
|
||||
image_idx = 0
|
||||
video_idx = 0
|
||||
audio_idx = 0
|
||||
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501
|
||||
multimodal_nums = (
|
||||
image_nums + audio_nums
|
||||
if use_audio_in_video
|
||||
else image_nums + video_nums + audio_nums
|
||||
) # noqa: E501
|
||||
|
||||
for _ in range(multimodal_nums):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
if (image_token_id in input_tokens or video_token_id in input_tokens) and (
|
||||
remain_videos > 0 or remain_images > 0
|
||||
):
|
||||
ed_vision_start = input_tokens.index(vision_start_token_id, st)
|
||||
else:
|
||||
ed_vision_start = len(input_tokens) + 1
|
||||
if audio_token_id in input_tokens and remain_audios > 0:
|
||||
ed_audio_start = input_tokens.index(audio_start_token_id, st)
|
||||
else:
|
||||
ed_audio_start = len(input_tokens) + 1
|
||||
min_ed = min(ed_vision_start, ed_audio_start)
|
||||
|
||||
if min_ed == ed_audio_start:
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
_, audio_len = _get_feat_extract_output_lengths(
|
||||
audio_feature_lengths[audio_idx]
|
||||
)
|
||||
llm_pos_ids = (
|
||||
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st += text_len + bos_len + audio_len + eos_len
|
||||
audio_idx += 1
|
||||
remain_audios -= 1
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and input_ids[ed_vision_start + 1] == image_token_id
|
||||
):
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
grid_t = image_grid_thw[image_idx][0]
|
||||
grid_hs = image_grid_thw[:, 1]
|
||||
grid_ws = image_grid_thw[:, 2]
|
||||
t_index = torch.arange(grid_t) * position_id_per_seconds
|
||||
llm_pos_ids = get_llm_pos_ids_for_vision(
|
||||
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st += text_len + bos_len + image_len + eos_len
|
||||
image_idx += 1
|
||||
remain_images -= 1
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and input_ids[ed_vision_start + 1] == video_token_id
|
||||
and not use_audio_in_video
|
||||
):
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_index = (
|
||||
torch.arange(grid_t)
|
||||
* float(second_per_grids[video_idx].item())
|
||||
* position_id_per_seconds
|
||||
)
|
||||
llm_pos_ids = get_llm_pos_ids_for_vision(
|
||||
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st += text_len + bos_len + video_len + eos_len
|
||||
video_idx += 1
|
||||
remain_videos -= 1
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and ed_vision_start + 1 == ed_audio_start
|
||||
and use_audio_in_video
|
||||
):
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
bos_block = (
|
||||
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(bos_block)
|
||||
llm_pos_ids_list.append(bos_block)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
_, audio_len = _get_feat_extract_output_lengths(
|
||||
audio_feature_lengths[audio_idx]
|
||||
)
|
||||
audio_llm_pos_ids = (
|
||||
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_index = (
|
||||
torch.arange(grid_t)
|
||||
* float(second_per_grids[video_idx].item())
|
||||
* position_id_per_seconds
|
||||
)
|
||||
video_llm_pos_ids = get_llm_pos_ids_for_vision(
|
||||
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
video_data_index, audio_data_index = 0, 0
|
||||
while (
|
||||
video_data_index < video_llm_pos_ids.shape[-1]
|
||||
and audio_data_index < audio_llm_pos_ids.shape[-1]
|
||||
):
|
||||
if (
|
||||
video_llm_pos_ids[0][video_data_index]
|
||||
<= audio_llm_pos_ids[0][audio_data_index]
|
||||
):
|
||||
llm_pos_ids_list.append(
|
||||
video_llm_pos_ids[
|
||||
:, video_data_index : video_data_index + 1
|
||||
]
|
||||
)
|
||||
video_data_index += 1
|
||||
else:
|
||||
llm_pos_ids_list.append(
|
||||
audio_llm_pos_ids[
|
||||
:, audio_data_index : audio_data_index + 1
|
||||
]
|
||||
)
|
||||
audio_data_index += 1
|
||||
if video_data_index < video_llm_pos_ids.shape[-1]:
|
||||
llm_pos_ids_list.append(
|
||||
video_llm_pos_ids[
|
||||
:, video_data_index : video_llm_pos_ids.shape[-1]
|
||||
]
|
||||
)
|
||||
if audio_data_index < audio_llm_pos_ids.shape[-1]:
|
||||
llm_pos_ids_list.append(
|
||||
audio_llm_pos_ids[
|
||||
:, audio_data_index : audio_llm_pos_ids.shape[-1]
|
||||
]
|
||||
)
|
||||
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
eos_block = (
|
||||
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(eos_block)
|
||||
llm_pos_ids_list.append(eos_block)
|
||||
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501
|
||||
audio_idx += 1
|
||||
video_idx += 1
|
||||
remain_videos -= 1
|
||||
remain_audios -= 1
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
if llm_positions.shape[1] != seq_len:
|
||||
raise RuntimeError("Position ids length mismatch with input ids length")
|
||||
|
||||
mrope_position_delta = llm_positions.max() + 1 - seq_len
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
@ -499,3 +499,40 @@ def run_dp_sharded_mrope_vision_model(
|
||||
"Found unassigned embeddings"
|
||||
)
|
||||
return out_embeddings
|
||||
|
||||
|
||||
def get_llm_pos_ids_for_vision(
|
||||
start_idx: int,
|
||||
vision_idx: int,
|
||||
spatial_merge_size: int,
|
||||
t_index: list[int],
|
||||
grid_hs: torch.Tensor,
|
||||
grid_ws: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
llm_pos_ids_list = []
|
||||
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
|
||||
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(len(t_index), -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(len(t_index), llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
t_index_tensor = (
|
||||
torch.Tensor(t_index)
|
||||
.to(llm_grid_h.device)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.long()
|
||||
.flatten()
|
||||
)
|
||||
_llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
|
||||
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
|
||||
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
|
||||
return llm_pos_ids
|
||||
|
||||
@ -875,7 +875,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if mm_input.get("use_audio_in_video") is True:
|
||||
use_audio_in_video = True
|
||||
|
||||
if supports_mrope(self.model):
|
||||
if supports_mrope(self.get_model()):
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = (
|
||||
self.model.get_mrope_input_positions(
|
||||
req_state.prompt_token_ids,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user