[Model][Qwen3VL] Use mm_position to compute mrope positions (#28730)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Lukas Geiger 2025-11-15 03:45:11 +00:00 committed by GitHub
parent 9fc81ec765
commit f05d474c8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -24,7 +24,7 @@
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
from collections.abc import Callable, Iterable, Mapping, Sequence
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial
from itertools import islice
from typing import Any
@ -1412,72 +1412,47 @@ class Qwen3VLForConditionalGeneration(
)
return mm_input_by_modality
def iter_mm_grid_hw(
    self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int]]:
    """Yield ``(token_offset, llm_grid_h, llm_grid_w)`` per image / video frame.

    Features are visited in order of their placeholder position in the token
    sequence. Grid height/width are divided by the vision encoder's spatial
    merge factor, i.e. they are expressed in LLM placeholder tokens.
    """
    video_token = self.config.video_token_id
    merge = self.config.vision_config.spatial_merge_size
    ordered = sorted(mm_features, key=lambda feat: feat.mm_position.offset)
    for feature in ordered:
        start = feature.mm_position.offset
        if feature.modality == "image":
            t, h, w = feature.data["image_grid_thw"].data.tolist()
            assert t == 1, f"Image must have 1 frame, got {t}"
            yield start, h // merge, w // merge
        elif feature.modality == "video":
            t, h, w = feature.data["video_grid_thw"].data.tolist()
            grid_h = h // merge
            grid_w = w // merge
            tokens_per_frame = grid_h * grid_w
            for _ in range(t):
                # Each frame's placeholder run starts at the next video token
                # at or after the current search position.
                start = input_tokens.index(video_token, start)
                yield start, grid_h, grid_w
                start += tokens_per_frame
        else:
            raise ValueError(f"Unsupported modality: {feature.modality}")
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)]
hf_config = self.config
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
input_tokens_array = np.array(input_tokens)
vision_start_mask = input_tokens_array == vision_start_token_id
vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1]
image_nums = np.count_nonzero(vision_tokens == image_token_id)
video_nums = np.count_nonzero(vision_tokens == video_token_id)
llm_pos_ids_list: list = []
llm_pos_ids_list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = image_grid_thw[image_index]
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = video_grid_thw[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
input_tokens, mm_features
):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
st = offset + llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0