Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 04:04:57 +08:00)
[Model] Enable DP for ViT in Qwen2-VL (#25445)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent 5774b0a1da
commit c98be0a232
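
This change lets the Qwen2-VL vision transformer (ViT) run data-parallel instead of tensor-parallel: a use_data_parallel flag is threaded through the vision MLP, attention, block, and patch-merger modules, the parallel linear layers are created with disable_tp=use_data_parallel so each rank keeps full weights, and the top-level model routes pixel inputs through run_dp_sharded_mrope_vision_model when multimodal_config.mm_encoder_tp_mode == "data". A minimal usage sketch follows; the exact keyword argument name passed to LLM is an assumption inferred from the config field read in this diff.

# Hedged usage sketch: `mm_encoder_tp_mode` as an LLM() kwarg is an assumption
# inferred from `multimodal_config.mm_encoder_tp_mode` in this diff.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    tensor_parallel_size=4,
    mm_encoder_tp_mode="data",  # run the ViT data-parallel across the TP ranks
)
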
@@ -66,6 +66,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -217,17 +218,20 @@ class Qwen2VisionMLP(nn.Module):
         act_layer: type[nn.Module] = QuickGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.fc1 = ColumnParallelLinear(in_features,
                                         hidden_features,
                                         quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
+                                        prefix=f"{prefix}.fc1",
+                                        disable_tp=use_data_parallel)
         self.act = act_layer()
         self.fc2 = RowParallelLinear(hidden_features,
                                      in_features,
                                      quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+                                     prefix=f"{prefix}.fc2",
+                                     disable_tp=use_data_parallel)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_parallel, _ = self.fc1(x)
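
The pattern in this hunk, repeated throughout the file, is to pass disable_tp=use_data_parallel into ColumnParallelLinear and RowParallelLinear. Conceptually, disabling TP means the layer keeps its full weight matrix on every rank, which is what a per-rank (data-parallel) copy of the ViT requires. A toy stand-in illustrating the idea, not vLLM's implementation:

# Toy sketch of the disable_tp idea (illustrative only, not vLLM's classes):
# under TP the output dimension is sharded across ranks; with disable_tp the
# full weight is replicated on every rank.
import torch
import torch.nn as nn


class ToyColumnParallelLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int, tp_size: int,
                 disable_tp: bool = False):
        super().__init__()
        out_per_rank = out_features if disable_tp else out_features // tp_size
        self.weight = nn.Parameter(torch.randn(out_per_rank, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With disable_tp=True this is a plain, unsharded linear projection.
        return x @ self.weight.t()
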
@@ -293,25 +297,28 @@ class Qwen2VisionAttention(nn.Module):
         projection_size: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.tp_size = world_size
+        self.tp_size = (1 if use_data_parallel else
+                        parallel_state.get_tensor_model_parallel_world_size())
         self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size)
+            num_heads, self.tp_size)

         self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                         output_size=3 * projection_size,
                                         quant_config=quant_config,
-                                        prefix=f"{prefix}.qkv")
+                                        prefix=f"{prefix}.qkv",
+                                        disable_tp=use_data_parallel)
         self.proj = RowParallelLinear(input_size=projection_size,
                                       output_size=embed_dim,
                                       quant_config=quant_config,
-                                      prefix=f"{prefix}.proj")
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)

         # Detect attention implementation.
         self.attn_backend = get_vit_attn_backend(
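
With use_data_parallel=True, self.tp_size is forced to 1, so num_attention_heads_per_partition becomes the full head count and each rank runs complete (unsharded) attention. Illustrative arithmetic with made-up values:

# Illustrative head-partition arithmetic; the numbers are assumptions, not
# taken from any Qwen2-VL config.
num_heads = 16
tp_world_size = 4

heads_per_rank_dp = num_heads // 1              # use_data_parallel=True  -> 16
heads_per_rank_tp = num_heads // tp_world_size  # use_data_parallel=False -> 4
assert (heads_per_rank_dp, heads_per_rank_tp) == (16, 4)
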
@@ -453,6 +460,7 @@ class Qwen2VisionBlock(nn.Module):
         norm_layer: Optional[Callable[[int], nn.Module]] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -465,12 +473,14 @@ class Qwen2VisionBlock(nn.Module):
                                         num_heads=num_heads,
                                         projection_size=dim,
                                         quant_config=quant_config,
-                                        prefix=f"{prefix}.attn")
+                                        prefix=f"{prefix}.attn",
+                                        use_data_parallel=use_data_parallel)
         self.mlp = Qwen2VisionMLP(dim,
                                   mlp_hidden_dim,
                                   act_layer=act_layer,
                                   quant_config=quant_config,
-                                  prefix=f"{prefix}.mlp")
+                                  prefix=f"{prefix}.mlp",
+                                  use_data_parallel=use_data_parallel)

     def forward(
         self,
@@ -531,6 +541,7 @@ class Qwen2VisionPatchMerger(nn.Module):
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -542,13 +553,15 @@ class Qwen2VisionPatchMerger(nn.Module):
                                  self.hidden_size,
                                  bias=True,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mlp.0"),
+                                 prefix=f"{prefix}.mlp.0",
+                                 disable_tp=use_data_parallel),
            nn.GELU(),
            RowParallelLinear(self.hidden_size,
                              d_model,
                              bias=True,
                              quant_config=quant_config,
-                             prefix=f"{prefix}.mlp.2"),
+                             prefix=f"{prefix}.mlp.2",
+                             disable_tp=use_data_parallel),
        ])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -600,6 +613,7 @@ class Qwen2VisionTransformer(nn.Module):
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()

@@ -613,6 +627,9 @@ class Qwen2VisionTransformer(nn.Module):
         num_heads = vision_config.num_heads
         mlp_ratio = vision_config.mlp_ratio

+        self.use_data_parallel = use_data_parallel
+        self.out_hidden_size = vision_config.hidden_size
+
         self.spatial_merge_size = spatial_merge_size
         self.num_heads = num_heads
         self.embed_dim = embed_dim
@@ -634,7 +651,8 @@ class Qwen2VisionTransformer(nn.Module):
                              mlp_ratio=mlp_ratio,
                              norm_layer=norm_layer,
                              quant_config=quant_config,
-                             prefix=f"{prefix}.blocks.{layer_idx}")
+                             prefix=f"{prefix}.blocks.{layer_idx}",
+                             use_data_parallel=use_data_parallel)
            for layer_idx in range(depth)
        ])
        self.merger = Qwen2VisionPatchMerger(
@@ -643,6 +661,7 @@ class Qwen2VisionTransformer(nn.Module):
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
+           use_data_parallel=use_data_parallel,
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype())
@@ -659,8 +678,9 @@ class Qwen2VisionTransformer(nn.Module):
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device

-    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
         pos_ids = []
+        max_grid_size = 0
         for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
@@ -678,8 +698,8 @@ class Qwen2VisionTransformer(nn.Module):
             ).permute(0, 2, 1, 3).flatten()
             pos_ids.append(
                 torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+            max_grid_size = max(max_grid_size, h, w)
         pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
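
Since grid_thw is now a plain list[list[int]] rather than a tensor, max_grid_size is accumulated inside the loop instead of being taken from grid_thw[:, 1:].max(). The two computations agree, as this small check with made-up grid values illustrates:

# Equivalence check for the two max_grid_size computations (grid values are
# illustrative, not from a real request).
import torch

grid_thw_list = [[1, 16, 24], [2, 32, 20]]  # [t, h, w] per image/video

max_grid_size = 0
for t, h, w in grid_thw_list:
    max_grid_size = max(max_grid_size, h, w)

assert max_grid_size == int(torch.tensor(grid_thw_list)[:, 1:].max())  # 32
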
@@ -698,7 +718,7 @@ class Qwen2VisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: torch.Tensor,
+        grid_thw: list[list[int]],
     ) -> torch.Tensor:
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
@@ -708,8 +728,9 @@ class Qwen2VisionTransformer(nn.Module):
         rotary_pos_emb = self.rot_pos_emb(grid_thw)

         # compute cu_seqlens
-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
-                                             grid_thw[:, 0]).cumsum(
+        grid_thw_ = torch.tensor(grid_thw)
+        cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
+                                             grid_thw_[:, 0]).cumsum(
                                                  dim=0, dtype=torch.int32)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

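
Because the forward pass now receives grid_thw as a list, it is converted to a tensor locally before building cu_seqlens, the cumulative per-sequence boundaries consumed by variable-length attention. A standalone sketch with made-up grid values:

# Standalone cu_seqlens sketch (grid values are illustrative). Each item
# contributes t sequences of h * w patch tokens.
import torch
import torch.nn.functional as F

grid_thw = [[1, 4, 6], [2, 2, 3]]  # [t, h, w] per item
grid_thw_ = torch.tensor(grid_thw)

cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
                                     grid_thw_[:, 0]).cumsum(dim=0,
                                                             dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
print(cu_seqlens)  # tensor([ 0, 24, 30, 36], dtype=torch.int32)
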
@@ -1112,6 +1133,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             "model.": "language_model.model.",
         })

+    supports_encoder_tp_data = True
+
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
@@ -1239,6 +1262,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config

+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
         self.config = config
         self.multimodal_config = multimodal_config

@@ -1249,6 +1273,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=self._maybe_ignore_quant_config(quant_config),
                 prefix=maybe_prefix(prefix, "visual"),
+                use_data_parallel=self.use_data_parallel,
             )
         else:
             self.visual = None
@@ -1357,7 +1382,15 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_embeds = image_input["image_embeds"]
         else:
             pixel_values = image_input["pixel_values"]
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
+            else:
+                image_embeds = self.visual(pixel_values,
+                                           grid_thw=grid_thw_list)

         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
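
When use_data_parallel is set, image (and, below, video) inputs bypass the direct self.visual call and go through run_dp_sharded_mrope_vision_model, which runs the full, unsharded ViT on each rank over a disjoint share of the items and gathers the results. The helper's internals are not part of this diff; the toy below only illustrates the round-robin idea of assigning whole items to ranks (function name and strategy are assumptions):

# Conceptual toy, not vLLM's run_dp_sharded_mrope_vision_model: assign whole
# images/videos to ranks so each rank encodes a disjoint subset with a full
# copy of the vision tower.
def assign_items_to_rank(num_items: int, dp_rank: int, dp_size: int) -> list[int]:
    """Indices of the multimodal items this rank should encode."""
    return [i for i in range(num_items) if i % dp_size == dp_rank]


# Example: 5 images across 2 ranks.
assert assign_items_to_rank(5, 0, 2) == [0, 2, 4]
assert assign_items_to_rank(5, 1, 2) == [1, 3]
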
@@ -1377,7 +1410,14 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             video_embeds = video_input["video_embeds"]
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
-            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values_videos,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
+            else:
+                video_embeds = self.visual(pixel_values_videos,
+                                           grid_thw=grid_thw_list)

         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size