[Bugfix] Fix broken MRoPE for GLM-4.1V/GLM-4.5V (#27860)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-11-01 01:04:51 +08:00 committed by GitHub
parent 675704ac01
commit 7e06c40e63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -26,6 +26,7 @@
# limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
import itertools
import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
@ -36,7 +37,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from transformers import BatchFeature, PretrainedConfig
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
from transformers.models.glm4v.image_processing_glm4v import (
Glm4vImageProcessor,
@ -89,6 +90,7 @@ from ..layers.activation import SiluAndMul
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
SupportsPP,
)
@ -1386,7 +1388,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
dummy_inputs=Glm4vDummyInputsBuilder,
)
class Glm4vForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
merge_by_field_config = True
@ -1613,6 +1615,149 @@ class Glm4vForConditionalGeneration(
multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: "PretrainedConfig",
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for GLM4V."""
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = []
if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_frame_num,
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
for t_idx in range(llm_grid_t):
t_index = (
torch.tensor(t_idx)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta
def forward(
self,
input_ids: torch.Tensor,