diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 9f1439e21ef7..3e243385fd04 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -26,6 +26,7 @@ # limitations under the License. """Inference-only GLM-4V model compatible with HuggingFace weights.""" +import itertools import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial @@ -36,7 +37,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from transformers import BatchFeature +from transformers import BatchFeature, PretrainedConfig from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig from transformers.models.glm4v.image_processing_glm4v import ( Glm4vImageProcessor, @@ -89,6 +90,7 @@ from ..layers.activation import SiluAndMul from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, + SupportsMRoPE, SupportsMultiModal, SupportsPP, ) @@ -1386,7 +1388,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): dummy_inputs=Glm4vDummyInputsBuilder, ) class Glm4vForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): merge_by_field_config = True @@ -1613,6 +1615,149 @@ class Glm4vForConditionalGeneration( multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: "PretrainedConfig", + image_grid_thw: list[list[int]] | torch.Tensor | None, + video_grid_thw: list[list[int]] | torch.Tensor | None, + second_per_grid_ts: list[float] | None = None, + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for GLM4V.""" + + image_token_id = hf_config.image_token_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1] + ): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_frame_num, + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + for t_idx in range(llm_grid_t): + t_index = ( + torch.tensor(t_idx) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(1, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(1, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return llm_positions, mrope_position_delta + def forward( self, input_ids: torch.Tensor,