[MM] Move Qwen3Omni MRoPE impl to model file (#26608)

Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

parent 27ed39a347
commit ddaff2938e
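
This commit moves the Qwen3-Omni MRoPE position computation out of the shared MRotaryEmbedding layer and into the model itself, exposed through the SupportsMRoPE interface; the GPU model runner then asks the model for its positions via supports_mrope. A minimal sketch of that dispatch path follows (illustrative only: the wrapper function and the import path are assumptions inferred from the hunks below; only supports_mrope, SupportsMRoPE, and get_mrope_input_positions come from the diff).

# Hedged sketch of the new dispatch path; not the runner's actual code.
from typing import Optional

import torch
from transformers import PretrainedConfig

# Import path assumed from the "from .interfaces import ..." hunk below.
from vllm.model_executor.models.interfaces import supports_mrope


def mrope_positions_for_request(
    model,
    prompt_token_ids: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Optional[torch.Tensor] = None,
    video_grid_thw: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, int]:
    # Models like Qwen3OmniMoeThinkerForConditionalGeneration now declare
    # SupportsMRoPE, so the runner can delegate position computation to them.
    if not supports_mrope(model):
        raise TypeError("model does not implement SupportsMRoPE")
    return model.get_mrope_input_positions(
        prompt_token_ids,
        hf_config,
        image_grid_thw,
        video_grid_thw,
    )
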
@@ -426,7 +426,7 @@ class MRotaryEmbedding(RotaryEmbedding):
     ) -> tuple[torch.Tensor, int]:
         from vllm.transformers_utils.config import thinker_uses_mrope
 
-        if thinker_uses_mrope(hf_config):
+        if thinker_uses_mrope(hf_config) and hf_config.model_type == "qwen2_5_omni":
             return cls._omni_get_input_positions_tensor(
                 input_tokens=input_tokens,
                 hf_config=hf_config,
@@ -1119,339 +1119,6 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         return llm_positions, mrope_position_delta
 
-    @classmethod
-    def _omni3_get_input_positions_tensor(
-        cls,
-        config,
-        input_ids: torch.Tensor,
-        image_grid_thw: torch.Tensor,
-        video_grid_thw: torch.Tensor,
-        use_audio_in_video: bool = False,
-        audio_seqlens: Optional[torch.Tensor] = None,
-        second_per_grids: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
-            input_lengths_leave = input_lengths % 100
-            feat_lengths = (input_lengths_leave - 1) // 2 + 1
-            output_lengths = (
-                ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
-            )
-            return output_lengths
-
-        if input_ids is None or input_ids.ndim != 1:
-            raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
-
-        seq_len = input_ids.shape[0]
-        device = input_ids.device
-        dtype = input_ids.dtype
-
-        if image_grid_thw is not None:
-            image_grid_thw = image_grid_thw.to(device=device, dtype=torch.long)
-        if video_grid_thw is not None:
-            video_grid_thw = video_grid_thw.to(device=device, dtype=torch.long)
-
-        if second_per_grids is None:
-            if video_grid_thw is not None and video_grid_thw.numel() > 0:
-                second_per_grids = torch.ones(
-                    video_grid_thw.shape[0], dtype=torch.float32, device=device
-                )
-            else:
-                second_per_grids = torch.tensor([], dtype=torch.float32, device=device)
-        else:
-            second_per_grids = second_per_grids.to(device=device, dtype=torch.float32)
-
-        if audio_seqlens is not None:
-            audio_seqlens = audio_seqlens.to(device=device, dtype=torch.long)
-
-        spatial_merge_size = config.vision_config.spatial_merge_size
-        image_token_id = config.image_token_id
-        video_token_id = config.video_token_id
-        audio_token_id = config.audio_token_id
-        vision_start_token_id = config.vision_start_token_id
-        audio_start_token_id = config.audio_start_token_id
-        position_id_per_seconds = config.position_id_per_seconds
-
-        vision_start_indices = torch.argwhere(
-            input_ids == vision_start_token_id
-        ).squeeze(1)
-        if vision_start_indices.numel() > 0:
-            vision_tokens = input_ids[vision_start_indices + 1]
-        else:
-            vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
-        audio_nums = torch.sum(input_ids == audio_start_token_id)
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (
-            (vision_tokens == audio_start_token_id).sum()
-            if use_audio_in_video
-            else (vision_tokens == video_token_id).sum()
-        )
-
-        input_tokens = input_ids.tolist()
-        llm_pos_ids_list: list[torch.Tensor] = []
-        st = 0
-        image_idx = 0
-        video_idx = 0
-        audio_idx = 0
-        remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums  # noqa: E501
-        multimodal_nums = (
-            image_nums + audio_nums
-            if use_audio_in_video
-            else image_nums + video_nums + audio_nums
-        )  # noqa: E501
-
-        for _ in range(multimodal_nums):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-            if (image_token_id in input_tokens or video_token_id in input_tokens) and (
-                remain_videos > 0 or remain_images > 0
-            ):
-                ed_vision_start = input_tokens.index(vision_start_token_id, st)
-            else:
-                ed_vision_start = len(input_tokens) + 1
-            if audio_token_id in input_tokens and remain_audios > 0:
-                ed_audio_start = input_tokens.index(audio_start_token_id, st)
-            else:
-                ed_audio_start = len(input_tokens) + 1
-            min_ed = min(ed_vision_start, ed_audio_start)
-
-            if min_ed == ed_audio_start:
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, device=device, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(bos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
-                llm_pos_ids = (
-                    torch.arange(audio_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                llm_pos_ids_list.append(llm_pos_ids)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(eos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st += text_len + bos_len + audio_len + eos_len
-                audio_idx += 1
-                remain_audios -= 1
-            elif (
-                min_ed == ed_vision_start
-                and input_ids[ed_vision_start + 1] == image_token_id
-            ):
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, device=device, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(bos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                grid_t = image_grid_thw[image_idx][0]
-                grid_hs = image_grid_thw[:, 1]
-                grid_ws = image_grid_thw[:, 2]
-                t_index = torch.arange(grid_t, device=device) * position_id_per_seconds
-                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
-                    st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
-                )
-                image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
-                llm_pos_ids_list.append(llm_pos_ids)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(eos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st += text_len + bos_len + image_len + eos_len
-                image_idx += 1
-                remain_images -= 1
-            elif (
-                min_ed == ed_vision_start
-                and input_ids[ed_vision_start + 1] == video_token_id
-                and not use_audio_in_video
-            ):
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, device=device, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(bos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                grid_t = video_grid_thw[video_idx][0]
-                grid_hs = video_grid_thw[:, 1]
-                grid_ws = video_grid_thw[:, 2]
-                t_index = (
-                    torch.arange(grid_t, device=device)
-                    * float(second_per_grids[video_idx].item())
-                    * position_id_per_seconds
-                )
-                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
-                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
-                )
-                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
-                llm_pos_ids_list.append(llm_pos_ids)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(eos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st += text_len + bos_len + video_len + eos_len
-                video_idx += 1
-                remain_videos -= 1
-            elif (
-                min_ed == ed_vision_start
-                and ed_vision_start + 1 == ed_audio_start
-                and use_audio_in_video
-            ):
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, device=device, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                bos_block = (
-                    torch.arange(bos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                llm_pos_ids_list.append(bos_block)
-                llm_pos_ids_list.append(bos_block)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
-                audio_llm_pos_ids = (
-                    torch.arange(audio_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                grid_t = video_grid_thw[video_idx][0]
-                grid_hs = video_grid_thw[:, 1]
-                grid_ws = video_grid_thw[:, 2]
-                t_index = (
-                    torch.arange(grid_t, device=device)
-                    * float(second_per_grids[video_idx].item())
-                    * position_id_per_seconds
-                )
-                video_llm_pos_ids = cls._get_llm_pos_ids_for_vision(
-                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
-                )
-                video_data_index, audio_data_index = 0, 0
-                while (
-                    video_data_index < video_llm_pos_ids.shape[-1]
-                    and audio_data_index < audio_llm_pos_ids.shape[-1]
-                ):
-                    if (
-                        video_llm_pos_ids[0][video_data_index]
-                        <= audio_llm_pos_ids[0][audio_data_index]
-                    ):
-                        llm_pos_ids_list.append(
-                            video_llm_pos_ids[
-                                :, video_data_index : video_data_index + 1
-                            ]
-                        )
-                        video_data_index += 1
-                    else:
-                        llm_pos_ids_list.append(
-                            audio_llm_pos_ids[
-                                :, audio_data_index : audio_data_index + 1
-                            ]
-                        )
-                        audio_data_index += 1
-                if video_data_index < video_llm_pos_ids.shape[-1]:
-                    llm_pos_ids_list.append(
-                        video_llm_pos_ids[
-                            :, video_data_index : video_llm_pos_ids.shape[-1]
-                        ]
-                    )
-                if audio_data_index < audio_llm_pos_ids.shape[-1]:
-                    llm_pos_ids_list.append(
-                        audio_llm_pos_ids[
-                            :, audio_data_index : audio_llm_pos_ids.shape[-1]
-                        ]
-                    )
-                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                eos_block = (
-                    torch.arange(eos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                llm_pos_ids_list.append(eos_block)
-                llm_pos_ids_list.append(eos_block)
-                st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2  # noqa: E501
-                audio_idx += 1
-                video_idx += 1
-                remain_videos -= 1
-                remain_audios -= 1
-
-        if st < len(input_tokens):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-            text_len = len(input_tokens) - st
-            llm_pos_ids_list.append(
-                torch.arange(text_len, device=device, dtype=torch.long)
-                .view(1, -1)
-                .expand(3, -1)
-                + st_idx
-            )
-
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        if llm_positions.shape[1] != seq_len:
-            raise RuntimeError("Position ids length mismatch with input ids length")
-
-        position_ids = llm_positions.to(device=device, dtype=dtype)
-        mrope_position_delta = llm_positions.max() + 1 - seq_len
-        return position_ids, mrope_position_delta
-
     @classmethod
     def _omni_get_input_positions_tensor(
         cls,
@@ -1483,8 +1150,6 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         # TODO(fyabc): refactor and share more code with
         # _vl_get_input_positions_tensor.
-
-        model_type = hf_config.model_type
         thinker_config = hf_config.thinker_config
 
         if isinstance(image_grid_thw, list):
@@ -1492,30 +1157,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         if isinstance(video_grid_thw, list):
             video_grid_thw = torch.tensor(video_grid_thw)
 
-        if "qwen3_omni" in model_type:
-            input_tensor = torch.tensor(input_tokens)
-            audio_lengths_tensor = audio_feature_lengths
-            if audio_lengths_tensor is not None and not isinstance(
-                audio_lengths_tensor, torch.Tensor
-            ):
-                audio_lengths_tensor = torch.as_tensor(
-                    audio_lengths_tensor, dtype=torch.long
-                )
-            second_per_grids_tensor = (
-                torch.tensor(second_per_grid_ts) if second_per_grid_ts else None
-            )
-
-            llm_positions, mrope_position_delta = cls._omni3_get_input_positions_tensor(  # noqa: E501
-                thinker_config,
-                input_tensor,
-                image_grid_thw,
-                video_grid_thw,
-                use_audio_in_video,
-                audio_lengths_tensor,
-                second_per_grids_tensor,
-            )
-            return llm_positions, mrope_position_delta
-
         audio_token_id = thinker_config.audio_token_index
         image_token_id = thinker_config.image_token_index
         video_token_id = thinker_config.video_token_index
@@ -72,7 +72,12 @@ from vllm.multimodal.processing import (
 )
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsMRoPE,
+    SupportsMultiModal,
+    SupportsPP,
+)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -96,7 +101,7 @@ from .utils import (
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend
+from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend
 
 try:
     import flash_attn
@@ -106,6 +111,15 @@ except (ImportError, ModuleNotFoundError):
 logger = init_logger(__name__)
 
 
+def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
+    input_lengths_leave = input_lengths % 100
+    feat_lengths = (input_lengths_leave - 1) // 2 + 1
+    output_lengths = (
+        ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
+    )
+    return feat_lengths, output_lengths
+
+
 class Qwen3_VisionPatchEmbed(nn.Module):
     def __init__(
         self,
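
The module-level helper added above maps raw audio feature lengths to a pair (conv feature length, encoder output length); the repeated // 2 steps presumably mirror the audio encoder's strided downsampling, and the * 13 term covers each full 100-frame chunk. A small self-contained check of the arithmetic with an arbitrarily chosen length (illustration only, not part of the commit):

# Standalone copy of the helper above, used only to illustrate the numbers.
import torch


def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    output_lengths = (
        ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
    )
    return feat_lengths, output_lengths


feat, out = _get_feat_extract_output_lengths(torch.tensor([250]))
assert feat.item() == 25  # (250 % 100 - 1) // 2 + 1
assert out.item() == 33   # 7 from the 50-frame remainder + 2 * 13 for the full chunks
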
@@ -679,16 +693,6 @@ Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder
 class Qwen3OmniMoeThinkerMultiModalProcessor(
     Qwen2_5OmniThinkerMultiModalProcessor,
 ):
-    def _get_feat_extract_output_lengths(
-        self, input_lengths: torch.Tensor
-    ) -> torch.Tensor:
-        input_lengths_leave = input_lengths % 100
-        feat_lengths = (input_lengths_leave - 1) // 2 + 1
-        output_lengths = (
-            ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
-        )
-        return feat_lengths, output_lengths
-
     def _call_hf_processor(
         self,
         prompt: str,
@@ -882,13 +886,13 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
         if audio_feature_lengths is None and feature_attention_mask is None:
             audio_output_lengths = []
         elif audio_feature_lengths is not None:
-            _, audio_output_lens = self._get_feat_extract_output_lengths(
+            _, audio_output_lens = _get_feat_extract_output_lengths(
                 audio_feature_lengths
             )
             audio_output_lengths = audio_output_lens.tolist()
         elif feature_attention_mask is not None:
             assert isinstance(feature_attention_mask, torch.Tensor)
-            _, audio_output_lens = self._get_feat_extract_output_lengths(
+            _, audio_output_lens = _get_feat_extract_output_lengths(
                 feature_attention_mask.sum(-1)
             )
             audio_output_lengths = audio_output_lens.tolist()
@@ -1044,16 +1048,6 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
         else:
             return torch.concat(mm_input, dim=dim)
 
-    def _get_feat_extract_output_lengths(
-        self, input_lengths: torch.Tensor
-    ) -> torch.Tensor:
-        input_lengths_leave = input_lengths % 100
-        feat_lengths = (input_lengths_leave - 1) // 2 + 1
-        output_lengths = (
-            ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
-        )
-        return output_lengths, output_lengths
-
     def _process_audio_input(
         self,
         audio_input: Qwen2AudioFeatureInputs,
@@ -1072,8 +1066,8 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
         if audio_feature_lengths.ndim == 2:
             audio_feature_lengths = audio_feature_lengths.reshape(-1)
 
-        audio_feat_lengths, audio_output_lengths = (
-            self._get_feat_extract_output_lengths(audio_feature_lengths)
+        audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
+            audio_feature_lengths
         )
 
         audio_outputs = self.audio_tower(
@@ -1094,6 +1088,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
     nn.Module,
     SupportsMultiModal,
     SupportsPP,
+    SupportsMRoPE,
     Qwen3OmniMoeConditionalGenerationMixin,
 ):
     hf_to_vllm_mapper = WeightsMapper(
@@ -1407,3 +1402,311 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
         return loaded_weights
+
+    @classmethod
+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        second_per_grid_ts: Optional[list[float]] = None,
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+        audio_feature_lengths: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        config = hf_config.thinker_config
+        if isinstance(image_grid_thw, list):
+            image_grid_thw = torch.tensor(image_grid_thw)
+        if isinstance(video_grid_thw, list):
+            video_grid_thw = torch.tensor(video_grid_thw)
+        input_ids = torch.tensor(input_tokens)
+        if input_ids is None or input_ids.ndim != 1:
+            raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
+
+        seq_len = input_ids.shape[0]
+        if audio_feature_lengths is not None and not isinstance(
+            audio_feature_lengths, torch.Tensor
+        ):
+            audio_feature_lengths = torch.as_tensor(
+                audio_feature_lengths, dtype=torch.long
+            )
+        if second_per_grid_ts is None:
+            if video_grid_thw is not None and video_grid_thw.numel() > 0:
+                second_per_grids = torch.ones(
+                    video_grid_thw.shape[0], dtype=torch.float32
+                )
+            else:
+                second_per_grids = torch.tensor([], dtype=torch.float32)
+        else:
+            second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
+
+        spatial_merge_size = config.vision_config.spatial_merge_size
+        image_token_id = config.image_token_id
+        video_token_id = config.video_token_id
+        audio_token_id = config.audio_token_id
+        vision_start_token_id = config.vision_start_token_id
+        audio_start_token_id = config.audio_start_token_id
+        position_id_per_seconds = config.position_id_per_seconds
+
+        vision_start_indices = torch.argwhere(
+            input_ids == vision_start_token_id
+        ).squeeze(1)
+        if vision_start_indices.numel() > 0:
+            vision_tokens = input_ids[vision_start_indices + 1]
+        else:
+            vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
+        audio_nums = torch.sum(input_ids == audio_start_token_id)
+        image_nums = (vision_tokens == image_token_id).sum()
+        video_nums = (
+            (vision_tokens == audio_start_token_id).sum()
+            if use_audio_in_video
+            else (vision_tokens == video_token_id).sum()
+        )
+
+        llm_pos_ids_list: list[torch.Tensor] = []
+        st = 0
+        image_idx = 0
+        video_idx = 0
+        audio_idx = 0
+        remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums  # noqa: E501
+        multimodal_nums = (
+            image_nums + audio_nums
+            if use_audio_in_video
+            else image_nums + video_nums + audio_nums
+        )  # noqa: E501
+
+        for _ in range(multimodal_nums):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+            if (image_token_id in input_tokens or video_token_id in input_tokens) and (
+                remain_videos > 0 or remain_images > 0
+            ):
+                ed_vision_start = input_tokens.index(vision_start_token_id, st)
+            else:
+                ed_vision_start = len(input_tokens) + 1
+            if audio_token_id in input_tokens and remain_audios > 0:
+                ed_audio_start = input_tokens.index(audio_start_token_id, st)
+            else:
+                ed_audio_start = len(input_tokens) + 1
+            min_ed = min(ed_vision_start, ed_audio_start)
+
+            if min_ed == ed_audio_start:
+                text_len = min_ed - st
+                if text_len != 0:
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len, dtype=torch.long)
+                        .view(1, -1)
+                        .expand(3, -1)
+                        + st_idx
+                    )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                bos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                _, audio_len = _get_feat_extract_output_lengths(
+                    audio_feature_lengths[audio_idx]
+                )
+                llm_pos_ids = (
+                    torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                llm_pos_ids_list.append(llm_pos_ids)
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                eos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st += text_len + bos_len + audio_len + eos_len
+                audio_idx += 1
+                remain_audios -= 1
+            elif (
+                min_ed == ed_vision_start
+                and input_ids[ed_vision_start + 1] == image_token_id
+            ):
+                text_len = min_ed - st
+                if text_len != 0:
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len, dtype=torch.long)
+                        .view(1, -1)
+                        .expand(3, -1)
+                        + st_idx
+                    )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                bos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                grid_t = image_grid_thw[image_idx][0]
+                grid_hs = image_grid_thw[:, 1]
+                grid_ws = image_grid_thw[:, 2]
+                t_index = torch.arange(grid_t) * position_id_per_seconds
+                llm_pos_ids = get_llm_pos_ids_for_vision(
+                    st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+                )
+                image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
+                llm_pos_ids_list.append(llm_pos_ids)
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                eos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st += text_len + bos_len + image_len + eos_len
+                image_idx += 1
+                remain_images -= 1
+            elif (
+                min_ed == ed_vision_start
+                and input_ids[ed_vision_start + 1] == video_token_id
+                and not use_audio_in_video
+            ):
+                text_len = min_ed - st
+                if text_len != 0:
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len, dtype=torch.long)
+                        .view(1, -1)
+                        .expand(3, -1)
+                        + st_idx
+                    )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                bos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                grid_t = video_grid_thw[video_idx][0]
+                grid_hs = video_grid_thw[:, 1]
+                grid_ws = video_grid_thw[:, 2]
+                t_index = (
+                    torch.arange(grid_t)
+                    * float(second_per_grids[video_idx].item())
+                    * position_id_per_seconds
+                )
+                llm_pos_ids = get_llm_pos_ids_for_vision(
+                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+                )
+                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
+                llm_pos_ids_list.append(llm_pos_ids)
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                eos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st += text_len + bos_len + video_len + eos_len
+                video_idx += 1
+                remain_videos -= 1
+            elif (
+                min_ed == ed_vision_start
+                and ed_vision_start + 1 == ed_audio_start
+                and use_audio_in_video
+            ):
+                text_len = min_ed - st
+                if text_len != 0:
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len, dtype=torch.long)
+                        .view(1, -1)
+                        .expand(3, -1)
+                        + st_idx
+                    )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                bos_len = 1
+                bos_block = (
+                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                llm_pos_ids_list.append(bos_block)
+                llm_pos_ids_list.append(bos_block)
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                _, audio_len = _get_feat_extract_output_lengths(
+                    audio_feature_lengths[audio_idx]
+                )
+                audio_llm_pos_ids = (
+                    torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                grid_t = video_grid_thw[video_idx][0]
+                grid_hs = video_grid_thw[:, 1]
+                grid_ws = video_grid_thw[:, 2]
+                t_index = (
+                    torch.arange(grid_t)
+                    * float(second_per_grids[video_idx].item())
+                    * position_id_per_seconds
+                )
+                video_llm_pos_ids = get_llm_pos_ids_for_vision(
+                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+                )
+                video_data_index, audio_data_index = 0, 0
+                while (
+                    video_data_index < video_llm_pos_ids.shape[-1]
+                    and audio_data_index < audio_llm_pos_ids.shape[-1]
+                ):
+                    if (
+                        video_llm_pos_ids[0][video_data_index]
+                        <= audio_llm_pos_ids[0][audio_data_index]
+                    ):
+                        llm_pos_ids_list.append(
+                            video_llm_pos_ids[
+                                :, video_data_index : video_data_index + 1
+                            ]
+                        )
+                        video_data_index += 1
+                    else:
+                        llm_pos_ids_list.append(
+                            audio_llm_pos_ids[
+                                :, audio_data_index : audio_data_index + 1
+                            ]
+                        )
+                        audio_data_index += 1
+                if video_data_index < video_llm_pos_ids.shape[-1]:
+                    llm_pos_ids_list.append(
+                        video_llm_pos_ids[
+                            :, video_data_index : video_llm_pos_ids.shape[-1]
+                        ]
+                    )
+                if audio_data_index < audio_llm_pos_ids.shape[-1]:
+                    llm_pos_ids_list.append(
+                        audio_llm_pos_ids[
+                            :, audio_data_index : audio_llm_pos_ids.shape[-1]
+                        ]
+                    )
+                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                eos_len = 1
+                eos_block = (
+                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                llm_pos_ids_list.append(eos_block)
+                llm_pos_ids_list.append(eos_block)
+                st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2  # noqa: E501
+                audio_idx += 1
+                video_idx += 1
+                remain_videos -= 1
+                remain_audios -= 1
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                + st_idx
+            )
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        if llm_positions.shape[1] != seq_len:
+            raise RuntimeError("Position ids length mismatch with input ids length")
+
+        mrope_position_delta = llm_positions.max() + 1 - seq_len
+        return llm_positions, mrope_position_delta
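
A quick way to read the method added above: when a prompt contains no image, video, or audio segments, the multimodal loop never runs and the trailing text block emits plain sequential indices on all three MRoPE axes, so the result is a (3, seq_len) tensor with mrope_position_delta == 0. A tiny standalone sketch of that degenerate case (it reproduces the final text block only, it does not call the method itself):

import torch

seq_len = 8
st_idx = 0
llm_positions = (
    torch.arange(seq_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx
)
mrope_position_delta = llm_positions.max().item() + 1 - seq_len

assert llm_positions.shape == (3, seq_len)
assert mrope_position_delta == 0  # text-only prompts need no position offset
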
@@ -499,3 +499,40 @@ def run_dp_sharded_mrope_vision_model(
         "Found unassigned embeddings"
     )
     return out_embeddings
+
+
+def get_llm_pos_ids_for_vision(
+    start_idx: int,
+    vision_idx: int,
+    spatial_merge_size: int,
+    t_index: list[int],
+    grid_hs: torch.Tensor,
+    grid_ws: torch.Tensor,
+) -> torch.Tensor:
+    llm_pos_ids_list = []
+    llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
+    llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
+    h_index = (
+        torch.arange(llm_grid_h)
+        .view(1, -1, 1)
+        .expand(len(t_index), -1, llm_grid_w)
+        .flatten()
+    )
+    w_index = (
+        torch.arange(llm_grid_w)
+        .view(1, 1, -1)
+        .expand(len(t_index), llm_grid_h, -1)
+        .flatten()
+    )
+    t_index_tensor = (
+        torch.Tensor(t_index)
+        .to(llm_grid_h.device)
+        .view(-1, 1)
+        .expand(-1, llm_grid_h * llm_grid_w)
+        .long()
+        .flatten()
+    )
+    _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
+    llm_pos_ids_list.append(_llm_pos_ids + start_idx)
+    llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
+    return llm_pos_ids
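
get_llm_pos_ids_for_vision, added above, builds the (3, t * h * w) block of temporal/height/width position ids for a single image or video after the spatial merge. An illustrative call with a made-up 4x4 patch grid, a single temporal index, and a 2x2 merge (the import path is inferred from the "from .vision import ..." hunk earlier in this diff):

import torch

from vllm.model_executor.models.vision import get_llm_pos_ids_for_vision

pos = get_llm_pos_ids_for_vision(
    start_idx=5,
    vision_idx=0,
    spatial_merge_size=2,
    t_index=[0],
    grid_hs=torch.tensor([4]),
    grid_ws=torch.tensor([4]),
)
# pos -> tensor([[5, 5, 5, 5],   # t: constant for the single frame
#                [5, 5, 6, 6],   # h: rows of the 2x2 merged grid
#                [5, 6, 5, 6]])  # w: columns of the 2x2 merged grid
assert pos.shape == (3, 4)
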
@@ -875,7 +875,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if mm_input.get("use_audio_in_video") is True:
                 use_audio_in_video = True
 
-        if supports_mrope(self.model):
+        if supports_mrope(self.get_model()):
            req_state.mrope_positions, req_state.mrope_position_delta = (
                self.model.get_mrope_input_positions(
                    req_state.prompt_token_ids,