[MM] Move Qwen3Omni MRoPE impl to model file (#26608)

Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Roger Wang 2025-10-10 22:17:24 -07:00 committed by GitHub
parent 27ed39a347
commit ddaff2938e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 368 additions and 387 deletions

View File

@ -426,7 +426,7 @@ class MRotaryEmbedding(RotaryEmbedding):
) -> tuple[torch.Tensor, int]: ) -> tuple[torch.Tensor, int]:
from vllm.transformers_utils.config import thinker_uses_mrope from vllm.transformers_utils.config import thinker_uses_mrope
if thinker_uses_mrope(hf_config): if thinker_uses_mrope(hf_config) and hf_config.model_type == "qwen2_5_omni":
return cls._omni_get_input_positions_tensor( return cls._omni_get_input_positions_tensor(
input_tokens=input_tokens, input_tokens=input_tokens,
hf_config=hf_config, hf_config=hf_config,
@ -1119,339 +1119,6 @@ class MRotaryEmbedding(RotaryEmbedding):
return llm_positions, mrope_position_delta return llm_positions, mrope_position_delta
@classmethod
def _omni3_get_input_positions_tensor(
cls,
config,
input_ids: torch.Tensor,
image_grid_thw: torch.Tensor,
video_grid_thw: torch.Tensor,
use_audio_in_video: bool = False,
audio_seqlens: Optional[torch.Tensor] = None,
second_per_grids: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
if input_ids is None or input_ids.ndim != 1:
raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
seq_len = input_ids.shape[0]
device = input_ids.device
dtype = input_ids.dtype
if image_grid_thw is not None:
image_grid_thw = image_grid_thw.to(device=device, dtype=torch.long)
if video_grid_thw is not None:
video_grid_thw = video_grid_thw.to(device=device, dtype=torch.long)
if second_per_grids is None:
if video_grid_thw is not None and video_grid_thw.numel() > 0:
second_per_grids = torch.ones(
video_grid_thw.shape[0], dtype=torch.float32, device=device
)
else:
second_per_grids = torch.tensor([], dtype=torch.float32, device=device)
else:
second_per_grids = second_per_grids.to(device=device, dtype=torch.float32)
if audio_seqlens is not None:
audio_seqlens = audio_seqlens.to(device=device, dtype=torch.long)
spatial_merge_size = config.vision_config.spatial_merge_size
image_token_id = config.image_token_id
video_token_id = config.video_token_id
audio_token_id = config.audio_token_id
vision_start_token_id = config.vision_start_token_id
audio_start_token_id = config.audio_start_token_id
position_id_per_seconds = config.position_id_per_seconds
vision_start_indices = torch.argwhere(
input_ids == vision_start_token_id
).squeeze(1)
if vision_start_indices.numel() > 0:
vision_tokens = input_ids[vision_start_indices + 1]
else:
vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
audio_nums = torch.sum(input_ids == audio_start_token_id)
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
)
input_tokens = input_ids.tolist()
llm_pos_ids_list: list[torch.Tensor] = []
st = 0
image_idx = 0
video_idx = 0
audio_idx = 0
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501
multimodal_nums = (
image_nums + audio_nums
if use_audio_in_video
else image_nums + video_nums + audio_nums
) # noqa: E501
for _ in range(multimodal_nums):
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
if (image_token_id in input_tokens or video_token_id in input_tokens) and (
remain_videos > 0 or remain_images > 0
):
ed_vision_start = input_tokens.index(vision_start_token_id, st)
else:
ed_vision_start = len(input_tokens) + 1
if audio_token_id in input_tokens and remain_audios > 0:
ed_audio_start = input_tokens.index(audio_start_token_id, st)
else:
ed_audio_start = len(input_tokens) + 1
min_ed = min(ed_vision_start, ed_audio_start)
if min_ed == ed_audio_start:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
llm_pos_ids = (
torch.arange(audio_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st += text_len + bos_len + audio_len + eos_len
audio_idx += 1
remain_audios -= 1
elif (
min_ed == ed_vision_start
and input_ids[ed_vision_start + 1] == image_token_id
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = torch.arange(grid_t, device=device) * position_id_per_seconds
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st += text_len + bos_len + image_len + eos_len
image_idx += 1
remain_images -= 1
elif (
min_ed == ed_vision_start
and input_ids[ed_vision_start + 1] == video_token_id
and not use_audio_in_video
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t, device=device)
* float(second_per_grids[video_idx].item())
* position_id_per_seconds
)
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st += text_len + bos_len + video_len + eos_len
video_idx += 1
remain_videos -= 1
elif (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
and use_audio_in_video
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
bos_block = (
torch.arange(bos_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(bos_block)
llm_pos_ids_list.append(bos_block)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
audio_llm_pos_ids = (
torch.arange(audio_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t, device=device)
* float(second_per_grids[video_idx].item())
* position_id_per_seconds
)
video_llm_pos_ids = cls._get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_data_index, audio_data_index = 0, 0
while (
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1]
):
if (
video_llm_pos_ids[0][video_data_index]
<= audio_llm_pos_ids[0][audio_data_index]
):
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_data_index + 1
]
)
video_data_index += 1
else:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_data_index + 1
]
)
audio_data_index += 1
if video_data_index < video_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_llm_pos_ids.shape[-1]
]
)
if audio_data_index < audio_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_llm_pos_ids.shape[-1]
]
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
eos_block = (
torch.arange(eos_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(eos_block)
llm_pos_ids_list.append(eos_block)
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len, device=device, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
if llm_positions.shape[1] != seq_len:
raise RuntimeError("Position ids length mismatch with input ids length")
position_ids = llm_positions.to(device=device, dtype=dtype)
mrope_position_delta = llm_positions.max() + 1 - seq_len
return position_ids, mrope_position_delta
@classmethod @classmethod
def _omni_get_input_positions_tensor( def _omni_get_input_positions_tensor(
cls, cls,
@ -1483,8 +1150,6 @@ class MRotaryEmbedding(RotaryEmbedding):
# TODO(fyabc): refactor and share more code with # TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor. # _vl_get_input_positions_tensor.
model_type = hf_config.model_type
thinker_config = hf_config.thinker_config thinker_config = hf_config.thinker_config
if isinstance(image_grid_thw, list): if isinstance(image_grid_thw, list):
@ -1492,30 +1157,6 @@ class MRotaryEmbedding(RotaryEmbedding):
if isinstance(video_grid_thw, list): if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw) video_grid_thw = torch.tensor(video_grid_thw)
if "qwen3_omni" in model_type:
input_tensor = torch.tensor(input_tokens)
audio_lengths_tensor = audio_feature_lengths
if audio_lengths_tensor is not None and not isinstance(
audio_lengths_tensor, torch.Tensor
):
audio_lengths_tensor = torch.as_tensor(
audio_lengths_tensor, dtype=torch.long
)
second_per_grids_tensor = (
torch.tensor(second_per_grid_ts) if second_per_grid_ts else None
)
llm_positions, mrope_position_delta = cls._omni3_get_input_positions_tensor( # noqa: E501
thinker_config,
input_tensor,
image_grid_thw,
video_grid_thw,
use_audio_in_video,
audio_lengths_tensor,
second_per_grids_tensor,
)
return llm_positions, mrope_position_delta
audio_token_id = thinker_config.audio_token_index audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index video_token_id = thinker_config.video_token_index

View File

@ -72,7 +72,12 @@ from vllm.multimodal.processing import (
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import (
MultiModalEmbeddings,
SupportsMRoPE,
SupportsMultiModal,
SupportsPP,
)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
@ -96,7 +101,7 @@ from .utils import (
_merge_multimodal_embeddings, _merge_multimodal_embeddings,
maybe_prefix, maybe_prefix,
) )
from .vision import get_vit_attn_backend from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend
try: try:
import flash_attn import flash_attn
@ -106,6 +111,15 @@ except (ImportError, ModuleNotFoundError):
logger = init_logger(__name__) logger = init_logger(__name__)
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return feat_lengths, output_lengths
class Qwen3_VisionPatchEmbed(nn.Module): class Qwen3_VisionPatchEmbed(nn.Module):
def __init__( def __init__(
self, self,
@ -679,16 +693,6 @@ Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder
class Qwen3OmniMoeThinkerMultiModalProcessor( class Qwen3OmniMoeThinkerMultiModalProcessor(
Qwen2_5OmniThinkerMultiModalProcessor, Qwen2_5OmniThinkerMultiModalProcessor,
): ):
def _get_feat_extract_output_lengths(
self, input_lengths: torch.Tensor
) -> torch.Tensor:
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return feat_lengths, output_lengths
def _call_hf_processor( def _call_hf_processor(
self, self,
prompt: str, prompt: str,
@ -882,13 +886,13 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if audio_feature_lengths is None and feature_attention_mask is None: if audio_feature_lengths is None and feature_attention_mask is None:
audio_output_lengths = [] audio_output_lengths = []
elif audio_feature_lengths is not None: elif audio_feature_lengths is not None:
_, audio_output_lens = self._get_feat_extract_output_lengths( _, audio_output_lens = _get_feat_extract_output_lengths(
audio_feature_lengths audio_feature_lengths
) )
audio_output_lengths = audio_output_lens.tolist() audio_output_lengths = audio_output_lens.tolist()
elif feature_attention_mask is not None: elif feature_attention_mask is not None:
assert isinstance(feature_attention_mask, torch.Tensor) assert isinstance(feature_attention_mask, torch.Tensor)
_, audio_output_lens = self._get_feat_extract_output_lengths( _, audio_output_lens = _get_feat_extract_output_lengths(
feature_attention_mask.sum(-1) feature_attention_mask.sum(-1)
) )
audio_output_lengths = audio_output_lens.tolist() audio_output_lengths = audio_output_lens.tolist()
@ -1044,16 +1048,6 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
else: else:
return torch.concat(mm_input, dim=dim) return torch.concat(mm_input, dim=dim)
def _get_feat_extract_output_lengths(
self, input_lengths: torch.Tensor
) -> torch.Tensor:
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths, output_lengths
def _process_audio_input( def _process_audio_input(
self, self,
audio_input: Qwen2AudioFeatureInputs, audio_input: Qwen2AudioFeatureInputs,
@ -1072,8 +1066,8 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
if audio_feature_lengths.ndim == 2: if audio_feature_lengths.ndim == 2:
audio_feature_lengths = audio_feature_lengths.reshape(-1) audio_feature_lengths = audio_feature_lengths.reshape(-1)
audio_feat_lengths, audio_output_lengths = ( audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
self._get_feat_extract_output_lengths(audio_feature_lengths) audio_feature_lengths
) )
audio_outputs = self.audio_tower( audio_outputs = self.audio_tower(
@ -1094,6 +1088,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
nn.Module, nn.Module,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
SupportsMRoPE,
Qwen3OmniMoeConditionalGenerationMixin, Qwen3OmniMoeConditionalGenerationMixin,
): ):
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
@ -1407,3 +1402,311 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded_weights return loaded_weights
@classmethod
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
second_per_grid_ts: Optional[list[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
config = hf_config.thinker_config
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
input_ids = torch.tensor(input_tokens)
if input_ids is None or input_ids.ndim != 1:
raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
seq_len = input_ids.shape[0]
if audio_feature_lengths is not None and not isinstance(
audio_feature_lengths, torch.Tensor
):
audio_feature_lengths = torch.as_tensor(
audio_feature_lengths, dtype=torch.long
)
if second_per_grid_ts is None:
if video_grid_thw is not None and video_grid_thw.numel() > 0:
second_per_grids = torch.ones(
video_grid_thw.shape[0], dtype=torch.float32
)
else:
second_per_grids = torch.tensor([], dtype=torch.float32)
else:
second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
spatial_merge_size = config.vision_config.spatial_merge_size
image_token_id = config.image_token_id
video_token_id = config.video_token_id
audio_token_id = config.audio_token_id
vision_start_token_id = config.vision_start_token_id
audio_start_token_id = config.audio_start_token_id
position_id_per_seconds = config.position_id_per_seconds
vision_start_indices = torch.argwhere(
input_ids == vision_start_token_id
).squeeze(1)
if vision_start_indices.numel() > 0:
vision_tokens = input_ids[vision_start_indices + 1]
else:
vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
audio_nums = torch.sum(input_ids == audio_start_token_id)
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
)
llm_pos_ids_list: list[torch.Tensor] = []
st = 0
image_idx = 0
video_idx = 0
audio_idx = 0
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501
multimodal_nums = (
image_nums + audio_nums
if use_audio_in_video
else image_nums + video_nums + audio_nums
) # noqa: E501
for _ in range(multimodal_nums):
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
if (image_token_id in input_tokens or video_token_id in input_tokens) and (
remain_videos > 0 or remain_images > 0
):
ed_vision_start = input_tokens.index(vision_start_token_id, st)
else:
ed_vision_start = len(input_tokens) + 1
if audio_token_id in input_tokens and remain_audios > 0:
ed_audio_start = input_tokens.index(audio_start_token_id, st)
else:
ed_audio_start = len(input_tokens) + 1
min_ed = min(ed_vision_start, ed_audio_start)
if min_ed == ed_audio_start:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
_, audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx]
)
llm_pos_ids = (
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st += text_len + bos_len + audio_len + eos_len
audio_idx += 1
remain_audios -= 1
elif (
min_ed == ed_vision_start
and input_ids[ed_vision_start + 1] == image_token_id
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = torch.arange(grid_t) * position_id_per_seconds
llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st += text_len + bos_len + image_len + eos_len
image_idx += 1
remain_images -= 1
elif (
min_ed == ed_vision_start
and input_ids[ed_vision_start + 1] == video_token_id
and not use_audio_in_video
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* float(second_per_grids[video_idx].item())
* position_id_per_seconds
)
llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st += text_len + bos_len + video_len + eos_len
video_idx += 1
remain_videos -= 1
elif (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
and use_audio_in_video
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
bos_block = (
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(bos_block)
llm_pos_ids_list.append(bos_block)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
_, audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx]
)
audio_llm_pos_ids = (
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* float(second_per_grids[video_idx].item())
* position_id_per_seconds
)
video_llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_data_index, audio_data_index = 0, 0
while (
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1]
):
if (
video_llm_pos_ids[0][video_data_index]
<= audio_llm_pos_ids[0][audio_data_index]
):
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_data_index + 1
]
)
video_data_index += 1
else:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_data_index + 1
]
)
audio_data_index += 1
if video_data_index < video_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_llm_pos_ids.shape[-1]
]
)
if audio_data_index < audio_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_llm_pos_ids.shape[-1]
]
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
eos_block = (
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(eos_block)
llm_pos_ids_list.append(eos_block)
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
if llm_positions.shape[1] != seq_len:
raise RuntimeError("Position ids length mismatch with input ids length")
mrope_position_delta = llm_positions.max() + 1 - seq_len
return llm_positions, mrope_position_delta

View File

@ -499,3 +499,40 @@ def run_dp_sharded_mrope_vision_model(
"Found unassigned embeddings" "Found unassigned embeddings"
) )
return out_embeddings return out_embeddings
def get_llm_pos_ids_for_vision(
start_idx: int,
vision_idx: int,
spatial_merge_size: int,
t_index: list[int],
grid_hs: torch.Tensor,
grid_ws: torch.Tensor,
) -> torch.Tensor:
llm_pos_ids_list = []
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(len(t_index), -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(len(t_index), llm_grid_h, -1)
.flatten()
)
t_index_tensor = (
torch.Tensor(t_index)
.to(llm_grid_h.device)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.long()
.flatten()
)
_llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids

View File

@ -875,7 +875,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if mm_input.get("use_audio_in_video") is True: if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True use_audio_in_video = True
if supports_mrope(self.model): if supports_mrope(self.get_model()):
req_state.mrope_positions, req_state.mrope_position_delta = ( req_state.mrope_positions, req_state.mrope_position_delta = (
self.model.get_mrope_input_positions( self.model.get_mrope_input_positions(
req_state.prompt_token_ids, req_state.prompt_token_ids,