[MM Encoder] Apply DP ViT for Qwen3-VL model series (#24955)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Huang Jie <92386084+JJJYmmm@users.noreply.github.com>
Co-authored-by: 松灵 <26085463+wulipc@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 4ac510f484
commit 3127274d02
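For context: this change threads a `use_data_parallel` flag through the Qwen3-VL vision stack so the ViT can run with replicated weights per rank (data parallel) instead of tensor parallel. On the user side the mode is selected through the multimodal config that the diff reads (`multimodal_config.mm_encoder_tp_mode == "data"`). A minimal usage sketch, assuming that knob is exposed as an engine argument of the same name; the model name is illustrative:

    from vllm import LLM

    # Sketch only: request a data-parallel vision encoder for Qwen3-VL.
    # Each TP rank then keeps the full ViT weights and processes a subset
    # of the input images instead of holding a tensor-parallel shard.
    llm = LLM(
        model="Qwen/Qwen3-VL-235B-A22B-Instruct",  # illustrative
        tensor_parallel_size=4,
        mm_encoder_tp_mode="data",
    )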
vllm/model_executor/models/qwen3_vl.py

@@ -126,20 +126,23 @@ class Qwen3_VisionMLP(nn.Module):
                  bias: bool = False,
                  act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.linear_fc1 = ColumnParallelLinear(in_features,
                                                hidden_features,
                                                bias=bias,
                                                quant_config=quant_config,
                                                return_bias=False,
-                                               prefix=f"{prefix}.linear_fc1")
+                                               prefix=f"{prefix}.linear_fc1",
+                                               disable_tp=use_data_parallel)
         self.linear_fc2 = RowParallelLinear(hidden_features,
                                             in_features,
                                             bias=bias,
                                             quant_config=quant_config,
                                             return_bias=False,
-                                            prefix=f"{prefix}.linear_fc2")
+                                            prefix=f"{prefix}.linear_fc2",
+                                            disable_tp=use_data_parallel)
         self.act_fn = act_fn
 
     def forward(self, x: torch.Tensor):
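The load-bearing change in this hunk is `disable_tp`: when the encoder runs data-parallel, each rank must hold the full `linear_fc1`/`linear_fc2` weights rather than a tensor-parallel shard. A toy sketch of that semantics (not vLLM's actual parallel-layer code), assuming a `tp_size`-way column shard in the TP case:

    import torch.nn as nn

    # Toy illustration of disable_tp: with a data-parallel ViT the "parallel"
    # linear degenerates to a plain replicated nn.Linear; otherwise the
    # output dimension is sharded across the tensor-parallel group.
    def make_fc1(in_features: int, hidden_features: int, tp_size: int,
                 disable_tp: bool) -> nn.Linear:
        if disable_tp:
            return nn.Linear(in_features, hidden_features)          # full copy
        return nn.Linear(in_features, hidden_features // tp_size)   # TP shard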
@@ -158,23 +161,27 @@ class Qwen3_VisionBlock(nn.Module):
         norm_layer: Optional[Callable[[int], nn.Module]] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
         self.norm1 = norm_layer(dim)
         self.norm2 = norm_layer(dim)
-        self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
-                                            num_heads=num_heads,
-                                            projection_size=dim,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.attn")
+        self.attn = Qwen2_5_VisionAttention(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+            use_data_parallel=use_data_parallel)
         self.mlp = Qwen3_VisionMLP(dim,
                                    mlp_hidden_dim,
                                    act_fn=act_fn,
                                    bias=True,
                                    quant_config=quant_config,
-                                   prefix=f"{prefix}.mlp")
+                                   prefix=f"{prefix}.mlp",
+                                   use_data_parallel=use_data_parallel)
 
     def forward(
         self,
@@ -205,6 +212,7 @@ class Qwen3_VisionPatchMerger(nn.Module):
         use_postshuffle_norm: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -222,13 +230,15 @@ class Qwen3_VisionPatchMerger(nn.Module):
                                                self.hidden_size,
                                                bias=True,
                                                quant_config=quant_config,
-                                               prefix=f"{prefix}.linear_fc1")
+                                               prefix=f"{prefix}.linear_fc1",
+                                               disable_tp=use_data_parallel)
         self.act_fn = nn.GELU()
         self.linear_fc2 = RowParallelLinear(self.hidden_size,
                                             d_model,
                                             bias=True,
                                             quant_config=quant_config,
-                                            prefix=f"{prefix}.linear_fc2")
+                                            prefix=f"{prefix}.linear_fc2",
+                                            disable_tp=use_data_parallel)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.use_postshuffle_norm:
@@ -250,6 +260,7 @@ class Qwen3_VisionTransformer(nn.Module):
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = vision_config.hidden_size
@@ -260,6 +271,12 @@ class Qwen3_VisionTransformer(nn.Module):
         self.spatial_merge_unit = self.spatial_merge_size**2
         self.temporal_patch_size = vision_config.temporal_patch_size
         self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
+        self.use_data_parallel = use_data_parallel
+
+        # NOTE: This is used for creating empty tensor for all_gather for
+        # DP ViT. Here out_hidden_size is enlarged due to deepstack
+        self.out_hidden_size = (vision_config.out_hidden_size *
+                                (1 + len(self.deepstack_visual_indexes)))
 
         self.patch_embed = Qwen3_VisionPatchEmbed(
             patch_size=self.patch_size,
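The NOTE above is about buffer sizing: every deepstack merger contributes an extra `out_hidden_size`-wide slice to the features that get all-gathered across DP ranks, so the gather buffer is scaled by `1 + len(deepstack_visual_indexes)`. A worked example with illustrative numbers (not taken from a real Qwen3-VL config):

    # Illustrative values only.
    out_hidden_size = 4096                   # vision_config.out_hidden_size
    deepstack_visual_indexes = [5, 11, 17]   # three deepstack taps
    gather_width = out_hidden_size * (1 + len(deepstack_visual_indexes))
    assert gather_width == 16384             # main merger + 3 deepstack slices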
@@ -283,7 +300,8 @@ class Qwen3_VisionTransformer(nn.Module):
                 act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                 norm_layer=norm_layer,
                 quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}")
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel)
             for layer_idx in range(vision_config.depth)
         ])
 
@@ -294,6 +312,7 @@ class Qwen3_VisionTransformer(nn.Module):
             spatial_merge_size=self.spatial_merge_size,
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
+            use_data_parallel=use_data_parallel,
         )
 
         self.deepstack_merger_list = nn.ModuleList([
@@ -304,7 +323,8 @@ class Qwen3_VisionTransformer(nn.Module):
                 use_postshuffle_norm=True,
                 norm_layer=norm_layer,
                 quant_config=quant_config,
-                prefix=f"{prefix}.deepstack_merger_list.{layer_idx}")
+                prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
+                use_data_parallel=use_data_parallel)
             for layer_idx in range(len(self.deepstack_visual_indexes))
         ])
 
@@ -325,7 +345,14 @@ class Qwen3_VisionTransformer(nn.Module):
 
     def rot_pos_emb(self, grid_thw):
         pos_ids = []
-        for t, h, w in grid_thw:
+        # Support both Tensor and list inputs for DP path
+        if isinstance(grid_thw, list):
+            grid_list = grid_thw
+            max_grid_size = max(max(h, w) for _, h, w in grid_list)
+        else:
+            grid_list = grid_thw.tolist()
+            max_grid_size = int(grid_thw[:, 1:].max().item())
+        for t, h, w in grid_list:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             hpos_ids = hpos_ids.reshape(
                 h // self.spatial_merge_size,
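For readers unfamiliar with the position ids being built here, a standalone toy run of the `hpos_ids` construction (the same reshape/permute as the surrounding code, with toy sizes h = w = 4 and spatial_merge_size = 2):

    import torch

    h, w, merge = 4, 4, 2
    hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
    hpos = hpos.reshape(h // merge, merge, w // merge, merge)
    # Group rows merge-window by merge-window, matching the code above.
    hpos = hpos.permute(0, 2, 1, 3).flatten()
    # tensor([0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3])
    print(hpos)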
@@ -348,7 +375,6 @@ class Qwen3_VisionTransformer(nn.Module):
             pos_ids.append(
                 torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
@@ -453,10 +479,18 @@ class Qwen3_VisionTransformer(nn.Module):
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
+        if isinstance(grid_thw, list):
+            grid_thw_tensor = torch.tensor(grid_thw,
+                                           device=hidden_states.device,
+                                           dtype=torch.int32)
+        else:
+            grid_thw_tensor = grid_thw
+
         cu_seqlens = torch.repeat_interleave(
-            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
+            grid_thw_tensor[:, 0]).cumsum(
                 dim=0,
-                dtype=grid_thw.dtype
+                dtype=grid_thw_tensor.dtype
                 if torch.jit.is_tracing() else torch.int32,
             )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
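A worked example of the `cu_seqlens` computation above, with toy grids (one 1x4x4 image and one 2x2x2 video):

    import torch
    import torch.nn.functional as F

    grid_thw = torch.tensor([[1, 4, 4], [2, 2, 2]], dtype=torch.int32)
    # Patches per frame (h*w), repeated t times per item:
    seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                      grid_thw[:, 0])
    # -> tensor([16, 4, 4])
    cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
    # -> tensor([ 0, 16, 20, 24]): per-frame boundaries for varlen attention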
@@ -984,6 +1018,9 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             "up_proj",
         ],
     }
+
+    supports_encoder_tp_data = True
+
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
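`supports_encoder_tp_data = True` advertises that this model can honor `mm_encoder_tp_mode="data"`. A hypothetical sketch of how such a capability flag could gate the mode; this is an assumption for illustration, not vLLM's actual dispatch code:

    # Hypothetical gate (assumed helper, not a vLLM API): fall back to the
    # default weight-sharded TP mode when a model lacks the capability flag.
    def resolve_mm_encoder_tp_mode(model_cls: type, requested: str) -> str:
        if requested == "data" and not getattr(
                model_cls, "supports_encoder_tp_data", False):
            return "weights"
        return requested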
@@ -1009,12 +1046,14 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
         self.visual = Qwen3_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
+            use_data_parallel=self.use_data_parallel,
         )
 
         self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
@@ -1177,7 +1216,15 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+            if self.use_data_parallel:
+                from vllm.multimodal.utils import (
+                    run_dp_sharded_mrope_vision_model)
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
+            else:
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each image item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
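The helper's internals are not part of this diff; presumably `run_dp_sharded_mrope_vision_model` assigns whole images to the ranks of the TP group, runs the replicated ViT locally on each rank's share, and all-gathers the outputs. A rough standalone sketch of the sharding idea only (not the actual helper):

    # Rough sketch (NOT the real helper): balance whole images across ranks
    # by patch count t*h*w, returning the owning rank for each image.
    def greedy_shard(grid_thw_list: list[list[int]], world_size: int) -> list[int]:
        loads = [0] * world_size
        owners = []
        for t, h, w in grid_thw_list:
            rank = min(range(world_size), key=loads.__getitem__)
            owners.append(rank)
            loads[rank] += t * h * w
        return owners

    print(greedy_shard([[1, 32, 32], [1, 16, 16], [2, 16, 16]], world_size=2))
    # -> [0, 1, 1]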
@@ -1199,7 +1246,16 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         else:
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+            if self.use_data_parallel:
+                from vllm.multimodal.utils import (
+                    run_dp_sharded_mrope_vision_model)
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values_videos,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
+            else:
+                video_embeds = self.visual(pixel_values_videos,
+                                           grid_thw=grid_thw)
 
         # Split concatenated embeddings for each video item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
vllm/model_executor/models/qwen3_vl_moe.py

@@ -315,12 +315,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
 
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
         self.visual = Qwen3_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
+            use_data_parallel=self.use_data_parallel,
         )
 
         self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,