diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index b3c42c257256..88813490c0fb 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -66,6 +66,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                          BaseProcessingInfo, PromptReplacement,
                                          PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -217,17 +218,20 @@ class Qwen2VisionMLP(nn.Module):
         act_layer: type[nn.Module] = QuickGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.fc1 = ColumnParallelLinear(in_features,
                                         hidden_features,
                                         quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
+                                        prefix=f"{prefix}.fc1",
+                                        disable_tp=use_data_parallel)
         self.act = act_layer()
         self.fc2 = RowParallelLinear(hidden_features,
                                      in_features,
                                      quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+                                     prefix=f"{prefix}.fc2",
+                                     disable_tp=use_data_parallel)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_parallel, _ = self.fc1(x)
@@ -293,25 +297,28 @@ class Qwen2VisionAttention(nn.Module):
         projection_size: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.tp_size = world_size
+        self.tp_size = (1 if use_data_parallel else
+                        parallel_state.get_tensor_model_parallel_world_size())
         self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size)
+            num_heads, self.tp_size)

         self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                         output_size=3 * projection_size,
                                         quant_config=quant_config,
-                                        prefix=f"{prefix}.qkv")
+                                        prefix=f"{prefix}.qkv",
+                                        disable_tp=use_data_parallel)
         self.proj = RowParallelLinear(input_size=projection_size,
                                       output_size=embed_dim,
                                       quant_config=quant_config,
-                                      prefix=f"{prefix}.proj")
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)

         # Detect attention implementation.
         self.attn_backend = get_vit_attn_backend(
@@ -453,6 +460,7 @@ class Qwen2VisionBlock(nn.Module):
         norm_layer: Optional[Callable[[int], nn.Module]] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -465,12 +473,14 @@
                                          num_heads=num_heads,
                                          projection_size=dim,
                                          quant_config=quant_config,
-                                         prefix=f"{prefix}.attn")
+                                         prefix=f"{prefix}.attn",
+                                         use_data_parallel=use_data_parallel)
         self.mlp = Qwen2VisionMLP(dim,
                                   mlp_hidden_dim,
                                   act_layer=act_layer,
                                   quant_config=quant_config,
-                                  prefix=f"{prefix}.mlp")
+                                  prefix=f"{prefix}.mlp",
+                                  use_data_parallel=use_data_parallel)

     def forward(
         self,
@@ -531,6 +541,7 @@ class Qwen2VisionPatchMerger(nn.Module):
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -542,13 +553,15 @@
                                  self.hidden_size,
                                  bias=True,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mlp.0"),
+                                 prefix=f"{prefix}.mlp.0",
+                                 disable_tp=use_data_parallel),
             nn.GELU(),
             RowParallelLinear(self.hidden_size,
                               d_model,
                               bias=True,
                               quant_config=quant_config,
-                              prefix=f"{prefix}.mlp.2"),
+                              prefix=f"{prefix}.mlp.2",
+                              disable_tp=use_data_parallel),
         ])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -600,6 +613,7 @@ class Qwen2VisionTransformer(nn.Module):
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()

@@ -613,6 +627,9 @@
         num_heads = vision_config.num_heads
         mlp_ratio = vision_config.mlp_ratio

+        self.use_data_parallel = use_data_parallel
+        self.out_hidden_size = vision_config.hidden_size
+
         self.spatial_merge_size = spatial_merge_size
         self.num_heads = num_heads
         self.embed_dim = embed_dim
@@ -634,7 +651,8 @@
                              mlp_ratio=mlp_ratio,
                              norm_layer=norm_layer,
                              quant_config=quant_config,
-                             prefix=f"{prefix}.blocks.{layer_idx}")
+                             prefix=f"{prefix}.blocks.{layer_idx}",
+                             use_data_parallel=use_data_parallel)
             for layer_idx in range(depth)
         ])
         self.merger = Qwen2VisionPatchMerger(
@@ -643,6 +661,7 @@
             norm_layer=norm_layer,
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
+            use_data_parallel=use_data_parallel,
         )
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
@@ -659,8 +678,9 @@
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device

-    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
         pos_ids = []
+        max_grid_size = 0
         for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
@@ -678,8 +698,8 @@
             ).permute(0, 2, 1, 3).flatten()
             pos_ids.append(
                 torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+            max_grid_size = max(max_grid_size, h, w)
         pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
@@ -698,7 +718,7 @@
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: torch.Tensor,
+        grid_thw: list[list[int]],
     ) -> torch.Tensor:
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
@@ -708,8 +728,9 @@
         rotary_pos_emb = self.rot_pos_emb(grid_thw)

         # compute cu_seqlens
-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
-                                             grid_thw[:, 0]).cumsum(
+        grid_thw_ = torch.tensor(grid_thw)
+        cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
+                                             grid_thw_[:, 0]).cumsum(
                                                  dim=0, dtype=torch.int32)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

@@ -1112,6 +1133,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         "model.": "language_model.model.",
     })

+    supports_encoder_tp_data = True
+
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
@@ -1239,6 +1262,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config

+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
         self.config = config
         self.multimodal_config = multimodal_config

@@ -1249,6 +1273,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=self._maybe_ignore_quant_config(quant_config),
                 prefix=maybe_prefix(prefix, "visual"),
+                use_data_parallel=self.use_data_parallel,
             )
         else:
             self.visual = None
@@ -1357,7 +1382,15 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_embeds = image_input["image_embeds"]
         else:
             pixel_values = image_input["pixel_values"]
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
+            else:
+                image_embeds = self.visual(pixel_values,
+                                           grid_thw=grid_thw_list)

         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
@@ -1377,7 +1410,14 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             video_embeds = video_input["video_embeds"]
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
-            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values_videos,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
+            else:
+                video_embeds = self.visual(pixel_values_videos,
+                                           grid_thw=grid_thw_list)

         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
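For reference, a standalone sanity check of the cu_seqlens arithmetic that the `grid_thw: list[list[int]]` change feeds into, assuming only plain PyTorch (the grid values below are hypothetical and not from the patch): each `(t, h, w)` entry contributes `t` frames of `h * w` patches, and the left-padded cumulative sum gives the per-frame sequence boundaries consumed by the attention backend.

```python
# Minimal sketch, not part of the patch: reproduces the cu_seqlens math above.
import torch
import torch.nn.functional as F

# Hypothetical grids: one 4x6 image and a 2-frame 2x2 video.
grid_thw = [[1, 4, 6], [2, 2, 2]]
grid_thw_ = torch.tensor(grid_thw)

# Each frame contributes h * w patches, repeated t times per item.
cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
                                     grid_thw_[:, 0]).cumsum(
                                         dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
print(cu_seqlens.tolist())  # [0, 24, 28, 32]
```

The loop-based `max_grid_size` in `rot_pos_emb` yields the same value as the removed `grid_thw[:, 1:].max()`, just computed without first materializing a tensor from the list input.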