[MM Encoder] Apply DP ViT for Qwen3-VL model series (#24955)

Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Huang Jie <92386084+JJJYmmm@users.noreply.github.com>
Co-authored-by: 松灵 <26085463+wulipc@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Authored by Roger Wang on 2025-09-17 21:04:21 -07:00; committed by GitHub
parent 4ac510f484
commit 3127274d02
2 changed files with 77 additions and 19 deletions
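
This change lets the Qwen3-VL vision encoder run data-parallel across the TP group instead of tensor-parallel: the ViT weights are replicated on every rank and the image batch is sharded across ranks. The diff only shows the model-side plumbing (it reads multimodal_config.mm_encoder_tp_mode == "data"); how the mode is exposed at the API surface is an assumption in the sketch below.

# Hedged sketch: enabling the DP ViT path. The mm_encoder_tp_mode keyword is an
# assumption about the engine-level API; this diff only reads
# multimodal_config.mm_encoder_tp_mode == "data" inside the model.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-VL-...",     # placeholder model id for illustration
    tensor_parallel_size=4,
    mm_encoder_tp_mode="data",     # replicate ViT weights, shard images across ranks
)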


@@ -126,20 +126,23 @@ class Qwen3_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
use_data_parallel: bool = False):
super().__init__()
self.linear_fc1 = ColumnParallelLinear(in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc1")
prefix=f"{prefix}.linear_fc1",
disable_tp=use_data_parallel)
self.linear_fc2 = RowParallelLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc2")
prefix=f"{prefix}.linear_fc2",
disable_tp=use_data_parallel)
self.act_fn = act_fn
def forward(self, x: torch.Tensor):
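
The disable_tp=use_data_parallel argument is the heart of this hunk: under the DP path the Column/RowParallelLinear layers keep full-width, unsharded weights on every rank, so each rank can run the whole MLP on its local slice of the images without intra-layer collectives. A toy, vLLM-free sketch of what the layer then effectively computes (plain nn.Linear stands in for the parallel linears; shapes are illustrative):

# Toy sketch (assumption: nn.Linear stands in for ColumnParallelLinear /
# RowParallelLinear with disable_tp=True). Under DP each rank holds the full
# MLP weights and only the image batch is split across ranks.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyVisionMLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.linear_fc1 = nn.Linear(dim, hidden_dim)  # full hidden width, not TP-sharded
        self.linear_fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_fc2(F.silu(self.linear_fc1(x)))

# Each DP rank would run this on its own shard of the patch sequence.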
@@ -158,23 +161,27 @@ class Qwen3_VisionBlock(nn.Module):
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.attn = Qwen2_5_VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel)
self.mlp = Qwen3_VisionMLP(dim,
mlp_hidden_dim,
act_fn=act_fn,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
def forward(
self,
@@ -205,6 +212,7 @@ class Qwen3_VisionPatchMerger(nn.Module):
use_postshuffle_norm: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -222,13 +230,15 @@ class Qwen3_VisionPatchMerger(nn.Module):
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc1")
prefix=f"{prefix}.linear_fc1",
disable_tp=use_data_parallel)
self.act_fn = nn.GELU()
self.linear_fc2 = RowParallelLinear(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc2")
prefix=f"{prefix}.linear_fc2",
disable_tp=use_data_parallel)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_postshuffle_norm:
@@ -250,6 +260,7 @@ class Qwen3_VisionTransformer(nn.Module):
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@@ -260,6 +271,12 @@ class Qwen3_VisionTransformer(nn.Module):
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.use_data_parallel = use_data_parallel
# NOTE: This is used to create the empty tensor for all_gather in
# DP ViT. out_hidden_size is enlarged here because of deepstack.
self.out_hidden_size = (vision_config.out_hidden_size *
(1 + len(self.deepstack_visual_indexes)))
self.patch_embed = Qwen3_VisionPatchEmbed(
patch_size=self.patch_size,
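
The enlarged out_hidden_size exists only so that ranks which receive no images can allocate a correctly sized empty tensor for the all_gather: with deepstack, each visual token's feature is the main merger output concatenated with one extra block per deepstack layer. A worked example with assumed config values:

# Worked example (assumed values): with out_hidden_size=2048 and three
# deepstack layers, each visual token's gathered feature is
# 2048 * (1 + 3) = 8192 wide, which is the width the empty all_gather
# buffer must have on DP ranks that hold no image patches.
out_hidden_size = 2048
deepstack_visual_indexes = [5, 11, 17]  # hypothetical layer indexes

gathered_width = out_hidden_size * (1 + len(deepstack_visual_indexes))
assert gathered_width == 8192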
@@ -283,7 +300,8 @@ class Qwen3_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}")
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(vision_config.depth)
])
@@ -294,6 +312,7 @@ class Qwen3_VisionTransformer(nn.Module):
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
self.deepstack_merger_list = nn.ModuleList([
@@ -304,7 +323,8 @@ class Qwen3_VisionTransformer(nn.Module):
use_postshuffle_norm=True,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}")
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(len(self.deepstack_visual_indexes))
])
@@ -325,7 +345,14 @@ class Qwen3_VisionTransformer(nn.Module):
def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
# Support both Tensor and list inputs for DP path
if isinstance(grid_thw, list):
grid_list = grid_thw
max_grid_size = max(max(h, w) for _, h, w in grid_list)
else:
grid_list = grid_thw.tolist()
max_grid_size = int(grid_thw[:, 1:].max().item())
for t, h, w in grid_list:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
@@ -348,7 +375,6 @@ class Qwen3_VisionTransformer(nn.Module):
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
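
The DP path passes grid_thw as a plain Python list of [t, h, w] entries (one per locally assigned item) rather than a tensor, so max_grid_size is now computed in whichever branch applies; both branches should agree. A quick check with assumed grids:

# Sanity sketch (assumed grids): the list branch and the tensor branch of the
# new max_grid_size computation give the same result.
import torch

grid_thw_list = [[1, 16, 24], [2, 8, 8]]        # DP path: plain list of [t, h, w]
grid_thw_tensor = torch.tensor(grid_thw_list)   # non-DP path: int tensor

max_from_list = max(max(h, w) for _, h, w in grid_thw_list)
max_from_tensor = int(grid_thw_tensor[:, 1:].max().item())
assert max_from_list == max_from_tensor == 24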
@@ -453,10 +479,18 @@ class Qwen3_VisionTransformer(nn.Module):
hidden_states = hidden_states + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
if isinstance(grid_thw, list):
grid_thw_tensor = torch.tensor(grid_thw,
device=hidden_states.device,
dtype=torch.int32)
else:
grid_thw_tensor = grid_thw
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
grid_thw_tensor[:, 0]).cumsum(
dim=0,
dtype=grid_thw.dtype
dtype=grid_thw_tensor.dtype
if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
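
cu_seqlens gives the varlen-attention boundaries: every [t, h, w] entry contributes t frames of h * w patches each. A worked example with assumed grids:

# Worked example (assumed grids): cu_seqlens derived from grid_thw, matching
# repeat_interleave(h*w, t).cumsum() followed by a left pad of 0.
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 4, 6], [2, 2, 2]], dtype=torch.int32)

cu_seqlens = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

# Image: one frame of 4*6=24 patches; video: two frames of 2*2=4 patches each.
assert cu_seqlens.tolist() == [0, 24, 28, 32]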
@@ -984,6 +1018,9 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"up_proj",
],
}
supports_encoder_tp_data = True
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
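
supports_encoder_tp_data advertises that this model can honor mm_encoder_tp_mode="data". The engine-side check is not part of this diff; a purely illustrative guard might look like:

# Illustrative only; not the actual vLLM check. It merely shows the intent of
# the supports_encoder_tp_data class attribute added above.
def validate_encoder_tp_mode(model_cls, mm_encoder_tp_mode: str) -> None:
    if (mm_encoder_tp_mode == "data"
            and not getattr(model_cls, "supports_encoder_tp_data", False)):
        raise ValueError(
            f"{model_cls.__name__} does not support mm_encoder_tp_mode='data'")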
@@ -1009,12 +1046,14 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
@@ -1177,7 +1216,15 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
if self.use_data_parallel:
from vllm.multimodal.utils import (
run_dp_sharded_mrope_vision_model)
return run_dp_sharded_mrope_vision_model(self.visual,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
# Split concatenated embeddings for each image item.
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
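
On the DP path the whole image batch is handed to run_dp_sharded_mrope_vision_model, which, as used here, assigns images to ranks, runs the replicated ViT on each rank's local shard, and gathers the embeddings so every rank ends up with features for every image. A single-process toy that mimics that flow (all names and shapes below are assumed; the real helper lives in vllm.multimodal.utils, uses an actual all_gather, and also carries rope_3d metadata):

# Toy, single-process simulation (assumptions throughout): shard images across
# "ranks", run a stand-in ViT per shard, then gather and restore order. The
# real run_dp_sharded_mrope_vision_model additionally handles rope_3d position
# metadata and ranks that receive no images.
import torch

def toy_vit(pixel_values: torch.Tensor) -> torch.Tensor:
    # Stand-in for the replicated vision tower: one scalar feature per image.
    return pixel_values.mean(dim=(1, 2, 3))

def dp_sharded_forward(images: torch.Tensor, world_size: int) -> torch.Tensor:
    shard_outputs, shard_indices = [], []
    for rank in range(world_size):          # real ranks run concurrently
        idx = torch.arange(len(images))[rank::world_size]
        shard_indices.append(idx)
        shard_outputs.append(toy_vit(images[idx]))
    gathered = torch.cat(shard_outputs)     # stand-in for all_gather
    order = torch.argsort(torch.cat(shard_indices))
    return gathered[order]                  # restore original image order

images = torch.randn(5, 3, 8, 8)
assert torch.allclose(dp_sharded_forward(images, world_size=2), toy_vit(images))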
@@ -1199,7 +1246,16 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
if self.use_data_parallel:
from vllm.multimodal.utils import (
run_dp_sharded_mrope_vision_model)
return run_dp_sharded_mrope_vision_model(self.visual,
pixel_values_videos,
grid_thw_list,
rope_type="rope_3d")
else:
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw)
# Split concatenated embeddings for each video item.
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync


@@ -315,12 +315,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,