[Bugfix] Fix broken MRoPE for GLM-4.1V/GLM-4.5V (#27860)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 675704ac01
commit 7e06c40e63
@@ -26,6 +26,7 @@
 # limitations under the License.
 """Inference-only GLM-4V model compatible with HuggingFace weights."""

+import itertools
 import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
@@ -36,7 +37,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import BatchFeature
+from transformers import BatchFeature, PretrainedConfig
 from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
 from transformers.models.glm4v.image_processing_glm4v import (
     Glm4vImageProcessor,
@@ -89,6 +90,7 @@ from ..layers.activation import SiluAndMul
 from .interfaces import (
     MultiModalEmbeddings,
     SupportsLoRA,
+    SupportsMRoPE,
     SupportsMultiModal,
     SupportsPP,
 )
@@ -1386,7 +1388,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
     dummy_inputs=Glm4vDummyInputsBuilder,
 )
 class Glm4vForConditionalGeneration(
-    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
+    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     merge_by_field_config = True

@@ -1613,6 +1615,149 @@ class Glm4vForConditionalGeneration(
             multimodal_embeddings += tuple(video_embeddings)
         return multimodal_embeddings

+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: "PretrainedConfig",
+        image_grid_thw: list[list[int]] | torch.Tensor | None,
+        video_grid_thw: list[list[int]] | torch.Tensor | None,
+        second_per_grid_ts: list[float] | None = None,
+        context_len: int = 0,
+        seq_len: int | None = None,
+        audio_feature_lengths: torch.Tensor | None = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value for GLM4V."""
+
+        image_token_id = hf_config.image_token_id
+        video_start_token_id = hf_config.video_start_token_id
+        video_end_token_id = hf_config.video_end_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+        llm_pos_ids_list: list = []
+
+        if not (image_grid_thw is None and video_grid_thw is None):
+            if isinstance(image_grid_thw, torch.Tensor):
+                image_grid_thw = image_grid_thw.tolist()
+
+            input_token_type: list[str] = []
+            video_check_flg = False
+            for token in input_tokens:
+                if token == video_start_token_id:
+                    video_check_flg = True
+                elif token == video_end_token_id:
+                    video_check_flg = False
+
+                if (token == image_token_id) and (video_check_flg is False):
+                    input_token_type.append("image")
+                elif (token == image_token_id) and (video_check_flg is True):
+                    input_token_type.append("video")
+                else:
+                    input_token_type.append("text")
+
+            input_type_group: list[tuple[str, int, int]] = []
+            for key, group_iter in itertools.groupby(
+                enumerate(input_token_type), lambda x: x[1]
+            ):
+                group_list = list(group_iter)
+                start_index = group_list[0][0]
+                end_index = group_list[-1][0] + 1
+                input_type_group.append((key, start_index, end_index))
+
+            video_frame_num = 1
+            mm_data_idx = 0
+            for modality_type, start_idx, end_idx in input_type_group:
+                st_idx = (
+                    llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+                )
+                if modality_type == "image":
+                    t, h, w = (
+                        image_grid_thw[mm_data_idx][0],
+                        image_grid_thw[mm_data_idx][1],
+                        image_grid_thw[mm_data_idx][2],
+                    )
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t,
+                        h // spatial_merge_size,
+                        w // spatial_merge_size,
+                    )
+
+                    t_index = (
+                        torch.arange(llm_grid_t)
+                        .view(-1, 1)
+                        .expand(-1, llm_grid_h * llm_grid_w)
+                        .flatten()
+                    )
+                    h_index = (
+                        torch.arange(llm_grid_h)
+                        .view(1, -1, 1)
+                        .expand(llm_grid_t, -1, llm_grid_w)
+                        .flatten()
+                    )
+                    w_index = (
+                        torch.arange(llm_grid_w)
+                        .view(1, 1, -1)
+                        .expand(llm_grid_t, llm_grid_h, -1)
+                        .flatten()
+                    )
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + st_idx
+                    )
+                    mm_data_idx += 1
+
+                elif modality_type == "video":
+                    t, h, w = (
+                        video_frame_num,
+                        image_grid_thw[mm_data_idx][1],
+                        image_grid_thw[mm_data_idx][2],
+                    )
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t,
+                        h // spatial_merge_size,
+                        w // spatial_merge_size,
+                    )
+
+                    for t_idx in range(llm_grid_t):
+                        t_index = (
+                            torch.tensor(t_idx)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                        h_index = (
+                            torch.arange(llm_grid_h)
+                            .view(1, -1, 1)
+                            .expand(1, -1, llm_grid_w)
+                            .flatten()
+                        )
+                        w_index = (
+                            torch.arange(llm_grid_w)
+                            .view(1, 1, -1)
+                            .expand(1, llm_grid_h, -1)
+                            .flatten()
+                        )
+                        llm_pos_ids_list.append(
+                            torch.stack([t_index, h_index, w_index]) + st_idx
+                        )
+
+                    mm_data_idx += 1
+                    video_frame_num += 1
+
+                else:
+                    text_len = end_idx - start_idx
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+                    video_frame_num = 1
+
+        else:
+            text_len = len(input_tokens)
+            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = llm_positions[:, context_len:seq_len]
+        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
+        return llm_positions, mrope_position_delta
+
     def forward(
         self,
         input_ids: torch.Tensor,
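
For reference, a minimal standalone sketch (not part of this commit) of the 3D position ids the new get_mrope_input_positions() produces. It assumes a toy prompt of 3 text tokens, one image with a hypothetical grid (t, h, w) = (1, 4, 4) and spatial_merge_size = 2 (so 1x2x2 = 4 placeholder tokens), then 2 trailing text tokens; all numbers are made up for illustration, and only the indexing mirrors the method above.

# Illustrative only -- replays the indexing of get_mrope_input_positions()
# for one hand-picked prompt; this is not the vLLM API.
import torch

spatial_merge_size = 2          # assumed value for a GLM-4.1V-style config
t, h, w = 1, 4, 4               # hypothetical image grid (time, height, width)
llm_t, llm_h, llm_w = t, h // spatial_merge_size, w // spatial_merge_size

blocks = []
# 3 leading text tokens: positions 0..2, identical on the t/h/w rows.
blocks.append(torch.arange(3).view(1, -1).expand(3, -1))
# Image block starts at previous max position + 1 (st_idx = 3 here).
st_idx = blocks[-1].max() + 1
t_idx = torch.arange(llm_t).view(-1, 1).expand(-1, llm_h * llm_w).flatten()
h_idx = torch.arange(llm_h).view(1, -1, 1).expand(llm_t, -1, llm_w).flatten()
w_idx = torch.arange(llm_w).view(1, 1, -1).expand(llm_t, llm_h, -1).flatten()
blocks.append(torch.stack([t_idx, h_idx, w_idx]) + st_idx)
# 2 trailing text tokens resume after the image block's max position.
st_idx = blocks[-1].max() + 1
blocks.append(torch.arange(2).view(1, -1).expand(3, -1) + st_idx)

positions = torch.cat(blocks, dim=1)        # shape (3, num_tokens)
num_tokens = 3 + llm_t * llm_h * llm_w + 2  # 9 tokens in this toy prompt
delta = (positions.max() + 1 - num_tokens).item()
print(positions)
# tensor([[0, 1, 2, 3, 3, 3, 3, 5, 6],
#         [0, 1, 2, 3, 3, 4, 4, 5, 6],
#         [0, 1, 2, 3, 4, 3, 4, 5, 6]])
print(delta)  # -2, the mrope_position_delta for this sequence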