diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md
index b0ea9621d545..0ab2ae58ad86 100644
--- a/docs/configuration/optimization.md
+++ b/docs/configuration/optimization.md
@@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to use this feature.
 
 Known supported models:
 
+- GLM-4.5V, GLM-4.1V ()
 - Kimi-VL ()
 - Llama4 ()
 - MiniCPM-V-2.5 or above (, )
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 662728e6b139..f9fd5163d66b 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -45,15 +45,20 @@ from transformers.models.glm4v.video_processing_glm4v import (
 from transformers.video_utils import VideoMetadata
 
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              parallel_state)
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.layernorm import RMSNorm
+# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
+                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
+# yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -66,6 +71,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -153,7 +159,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema):
 
 Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs]
 
-# === Vision Encoder === #
+# ==== Vision Encoder ==== #
 
 
 class Glm4vVisionMLP(nn.Module):
@@ -165,19 +171,23 @@ class Glm4vVisionMLP(nn.Module):
         bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_size=in_features,
-            output_sizes=[hidden_features] * 2,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(hidden_features,
-                                           in_features,
-                                           bias=bias,
-                                           quant_config=quant_config,
-                                           prefix=f"{prefix}.down_proj")
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(input_size=in_features,
+                                        output_sizes=[hidden_features] * 2,
+                                        bias=bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.gate_up_proj")
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(hidden_features,
+                                  in_features,
+                                  bias=bias,
+                                  quant_config=quant_config,
+                                  prefix=f"{prefix}.down_proj")
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor):
@@ -218,33 +228,54 @@ class Glm4vVisionAttention(nn.Module):
         projection_size: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
-        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = (1 if use_data_parallel else
+                        get_tensor_model_parallel_world_size())
         self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        self.qkv = QKVParallelLinear(
-            hidden_size=embed_dim,
-            head_size=self.hidden_size_per_attention_head,
-            total_num_heads=num_heads,
-            total_num_kv_heads=num_heads,
-            bias=False,
-            quant_config=quant_config,
-            # Change qkv prefix to align with GLM-4.5V-FP8 quantization config
-            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
-        )
-        self.proj = RowParallelLinear(
-            input_size=projection_size,
-            output_size=embed_dim,
-            quant_config=quant_config,
-            prefix=f"{prefix}.proj",
-            bias=False,
-        )
+        if use_data_parallel:
+            self.qkv = ReplicatedLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = ReplicatedLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )
+        else:
+            self.qkv = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.hidden_size_per_attention_head,
+                total_num_heads=num_heads,
+                total_num_kv_heads=num_heads,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = RowParallelLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
@@ -375,6 +406,7 @@ class Glm4vVisionBlock(nn.Module):
         norm_layer: Optional[Callable[[int], nn.Module]] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -387,6 +419,7 @@ class Glm4vVisionBlock(nn.Module):
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            use_data_parallel=use_data_parallel,
         )
         self.mlp = Glm4vVisionMLP(
             dim,
@@ -394,6 +427,7 @@ class Glm4vVisionBlock(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
         )
 
     def forward(
@@ -456,24 +490,40 @@ class Glm4vPatchMerger(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        self.proj = ColumnParallelLinear(self.hidden_size,
-                                         self.hidden_size,
-                                         bias=bias,
-                                         gather_output=True,
-                                         quant_config=quant_config,
-                                         prefix=f"{prefix}.proj")
+        if use_data_parallel:
+            self.proj = ReplicatedLinear(
+                input_size=self.hidden_size,
+                output_size=self.hidden_size,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
+        else:
+            self.proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=bias,
+                gather_output=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
-        self.gate_up_proj = MergedColumnParallelLinear(
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
         )
-        self.down_proj = RowParallelLinear(
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(
             context_dim,
             self.hidden_size,
             bias=bias,
@@ -548,14 +598,33 @@ class Glm4vVisionEmbeddings(nn.Module):
                           dtype=torch.float32))
 
         # Calculate target dimensions for each patch
-        target_h = torch.cat([
-            image_shapes[i, 1].repeat(lengths[i])
-            for i in range(len(lengths))
-        ]).to(device=device, dtype=torch.float32)
-        target_w = torch.cat([
-            image_shapes[i, 2].repeat(lengths[i])
-            for i in range(len(lengths))
-        ]).to(device=device, dtype=torch.float32)
+        # Add bounds checking for data parallel mode
+        if len(lengths) > image_shapes.shape[0]:
+            # In data parallel mode, some GPUs might not have all of the
+            # image shapes; use the available shapes, cycling through
+            # them as necessary.
+            target_h_list = []
+            target_w_list = []
+            for i in range(len(lengths)):
+                # Cycle through available shapes
+                shape_idx = i % image_shapes.shape[0]
+                target_h_list.append(image_shapes[shape_idx,
+                                                  1].repeat(lengths[i]))
+                target_w_list.append(image_shapes[shape_idx,
+                                                  2].repeat(lengths[i]))
+            target_h = torch.cat(target_h_list).to(device=device,
+                                                   dtype=torch.float32)
+            target_w = torch.cat(target_w_list).to(device=device,
+                                                   dtype=torch.float32)
+        else:
+            target_h = torch.cat([
+                image_shapes[i, 1].repeat(lengths[i])
+                for i in range(len(lengths))
+            ]).to(device=device, dtype=torch.float32)
+            target_w = torch.cat([
+                image_shapes[i, 2].repeat(lengths[i])
+                for i in range(len(lengths))
+            ]).to(device=device, dtype=torch.float32)
 
         # Normalize coordinates to [-1, 1] range for grid_sample
         h_coords = h_coords.to(device=device, dtype=torch.float32)
@@ -629,6 +698,7 @@ class Glm4vVisionTransformer(nn.Module):
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -638,6 +708,7 @@ class Glm4vVisionTransformer(nn.Module):
         depth = vision_config.depth
         self.hidden_size = vision_config.hidden_size
         self.num_heads = vision_config.num_heads
+        self.use_data_parallel = use_data_parallel
 
         self.patch_size = vision_config.patch_size
         self.spatial_merge_size = vision_config.spatial_merge_size
@@ -661,6 +732,7 @@ class Glm4vVisionTransformer(nn.Module):
                 norm_layer=norm_layer,
                 quant_config=quant_config,
                 prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=self.use_data_parallel,
             ) for layer_idx in range(depth)
         ])
         self.merger = Glm4vPatchMerger(
@@ -669,6 +741,7 @@ class Glm4vVisionTransformer(nn.Module):
             quant_config=quant_config,
             bias=False,
             prefix=f"{prefix}.merger",
+            use_data_parallel=self.use_data_parallel,
         )
 
         self.embeddings = Glm4vVisionEmbeddings(vision_config)
@@ -731,8 +804,11 @@ class Glm4vVisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: torch.Tensor,
+        grid_thw: list[list[int]],
     ) -> torch.Tensor:
+        # Convert grid_thw to tensor (always expecting list format now)
+        grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
+
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
         x = self.patch_embed(x)
@@ -1250,6 +1326,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         "model.visual.": "visual.",
     })
 
+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -1267,12 +1345,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
+            use_data_parallel=self.use_data_parallel,
         )
 
         if config.model_type == "glm4v":
@@ -1382,8 +1462,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
-
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values,
+                                                         grid_thw.tolist(),
+                                                         rope_type="rope_3d")
+            else:
+                image_embeds = self.visual(pixel_values,
+                                           grid_thw=grid_thw.tolist())
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
         return image_embeds.split(sizes.tolist())
@@ -1393,23 +1479,22 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
 
-        device = self.visual.device
-        flat_grid_thw = torch.cat([
-            torch.tensor([[1, h, w]] * t, device=device)
-            for t, h, w in grid_thw
-        ])
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos,
-                                       grid_thw=flat_grid_thw)
-
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values_videos,
+                                                         grid_thw.tolist(),
+                                                         rope_type="rope_3d")
+            else:
+                video_embeds = self.visual(pixel_values_videos,
+                                           grid_thw=grid_thw.tolist())
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
-
         return video_embeds.split(sizes.tolist())
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
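
Reviewer note: to exercise this path end to end, only the engine argument referenced in the docs hunk needs to change at launch time. A minimal sketch, assuming a 4-GPU node and passing engine arguments through `LLM` (the checkpoint name and sizes are illustrative):

```python
from vllm import LLM

# The language model stays tensor-parallel across 4 GPUs, while the vision
# encoder runs data-parallel: each rank keeps a full (replicated) copy of
# the ViT weights and processes a slice of the image batch.
llm = LLM(
    model="zai-org/GLM-4.5V",   # illustrative checkpoint
    tensor_parallel_size=4,
    mm_encoder_tp_mode="data",  # the engine argument from the docs change
)
```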
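The per-request dispatch in `_process_image_input` / `_process_video_input` hands the real work to `run_dp_sharded_mrope_vision_model`, which shards the batch across ranks and gathers per-item embeddings back. The sketch below is not that helper's implementation, only a self-contained illustration of the batch-level balancing idea behind it; the function name and the greedy policy are hypothetical:

```python
# Hypothetical sketch of batch-level DP sharding. Images are assigned to
# ranks so per-rank patch counts stay roughly equal; each rank then runs
# the full vision tower on its slice and the outputs are gathered.
def shard_by_patch_count(grid_thw: list[list[int]],
                         dp_size: int) -> list[list[int]]:
    """Greedily assign image indices to ranks, balancing t*h*w patch counts."""
    loads = [0] * dp_size
    shards: list[list[int]] = [[] for _ in range(dp_size)]
    # Place the largest images first so the greedy balance stays tight.
    order = sorted(range(len(grid_thw)),
                   key=lambda i: grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2],
                   reverse=True)
    for i in order:
        rank = loads.index(min(loads))  # least-loaded rank so far
        shards[rank].append(i)
        loads[rank] += grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]
    return shards

# Two ranks, four images of mixed sizes; the heaviest image gets a rank to
# itself: [[0], [3, 1, 2]].
print(shard_by_patch_count([[1, 32, 32], [1, 16, 16], [1, 8, 8], [1, 24, 24]], 2))
```

This uneven per-rank composition also appears to be what motivates the new bounds check in `Glm4vVisionEmbeddings`: as the added comment notes, a rank's local slice can contain more patch groups than it has matching image shapes.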
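One more detail worth checking while reviewing the split logic: `Glm4vVisionTransformer.forward` now takes `grid_thw` as a plain nested list and materializes the tensor itself, while the caller still derives per-item split sizes as t*h*w divided by merge_size squared. A small worked check of that arithmetic (grid values illustrative):

```python
import torch

# One [t, h, w] row per image; with spatial_merge_size=2, every 2x2 block
# of patches is merged into a single output embedding row.
grid_thw = torch.tensor([[1, 32, 32], [1, 16, 16]])
merge_size = 2
sizes = grid_thw.prod(-1) // merge_size // merge_size
print(sizes.tolist())  # [256, 64] -> lengths for splitting the concatenated output
```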