[Model] Support DP for ViT on MiniCPM-V-4 (#23327)

Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
WeiQing Chen 2025-08-23 10:14:41 +08:00 committed by GitHub
parent add1adfec7
commit 23c939fd30
4 changed files with 105 additions and 30 deletions


@ -172,6 +172,7 @@ The availability of batch-level DP depends on the model implementation.
Currently, the following models support `mm_encoder_tp_mode="data"` (a usage sketch follows this list):
- Llama4 (<gh-pr:18368>)
- MiniCPM-V-4 (<gh-pr:23327>)
- Qwen2.5-VL (<gh-pr:22742>)
- Step3 (<gh-pr:22697>)
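
The snippet below is a minimal usage sketch for turning on batch-level DP for the vision encoder. It assumes `mm_encoder_tp_mode` is accepted as an engine argument in this vLLM version; the model id and `tensor_parallel_size` are illustrative values, not prescriptions.

```python
from vllm import LLM

# Hedged sketch: keep TP for the language model, but replicate the ViT and
# shard the image batch across ranks via mm_encoder_tp_mode="data".
llm = LLM(
    model="openbmb/MiniCPM-V-4",   # example model id
    trust_remote_code=True,
    tensor_parallel_size=2,        # example TP degree
    mm_encoder_tp_mode="data",
)
```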


@ -27,13 +27,15 @@ from transformers.models.idefics2.configuration_idefics2 import (
Idefics2Config, Idefics2VisionConfig)
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import run_dp_sharded_vision_model
class Idefics2VisionEmbeddings(nn.Module):
@ -118,6 +120,7 @@ class Idefics2VisionAttention(nn.Module):
config: Idefics2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.config = config
@ -130,22 +133,43 @@ class Idefics2VisionAttention(nn.Module):
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
assert self.num_heads % tp_size == 0
self.num_heads_per_partition = self.num_heads // tp_size
if use_data_parallel:
self.q_size = self.num_heads * self.head_dim
self.qkv_proj = ReplicatedLinear(
self.embed_dim,
3 * self.q_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = ReplicatedLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
else:
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
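
As a sanity check on the sizing in this hunk, the sketch below uses made-up dimensions (not read from any config) to contrast the two paths: under TP, `QKVParallelLinear` divides the heads across ranks, while under DP `tp_size` is treated as 1 and `qkv_proj` becomes a full-width `ReplicatedLinear`.

```python
# Illustrative numbers only; they are not taken from the MiniCPM-V-4 config.
embed_dim, num_heads, head_dim, tp_size = 1152, 16, 72, 4

# Tensor-parallel path: heads are sharded across ranks.
heads_per_rank_tp = num_heads // tp_size           # 4 heads per rank
qkv_width_tp = 3 * heads_per_rank_tp * head_dim    # 864 output features per rank

# Data-parallel path: tp_size is forced to 1, so every rank keeps all heads.
q_size = num_heads * head_dim                      # 1152
qkv_width_dp = 3 * q_size                          # 3456 output features, replicated
```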
@ -169,18 +193,23 @@ class Idefics2VisionMLP(nn.Module):
config: Idefics2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
self.fc1 = cls_fc1(
config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
cls_fc2 = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.fc2 = cls_fc2(
config.intermediate_size,
config.hidden_size,
bias=True,
@ -202,17 +231,21 @@ class Idefics2EncoderLayer(nn.Module):
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics2VisionAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.self_attn = Idefics2VisionAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel)
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = Idefics2VisionMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
@ -254,6 +287,7 @@ class Idefics2Encoder(nn.Module):
*,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
@ -267,7 +301,8 @@ class Idefics2Encoder(nn.Module):
self.layers = nn.ModuleList([
Idefics2EncoderLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
prefix=f"{prefix}.layers.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(num_hidden_layers)
])
@ -301,17 +336,20 @@ class Idefics2VisionTransformer(nn.Module):
num_hidden_layers_override: Optional[int] = None,
require_post_norm: bool = True,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
embed_dim = config.hidden_size
self.config = config
self.use_data_parallel = use_data_parallel
self.embeddings = Idefics2VisionEmbeddings(config)
self.encoder = Idefics2Encoder(
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder")
prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
@ -340,10 +378,38 @@ class Idefics2VisionTransformer(nn.Module):
patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes,
)
encoder_outputs = self.encoder(hidden_states)
if self.use_data_parallel:
encoder_outputs = run_dp_sharded_vision_model(
hidden_states, self.encoder)
else:
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
def _consolidate_qkv_weights(
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
qkv_idx_mappings = {
".self_attn.q_proj": 0,
".self_attn.k_proj": 1,
".self_attn.v_proj": 2,
}
qkv_weights = {}
for name, loaded_weight in weights:
for weight_name, idx in qkv_idx_mappings.items():
if weight_name not in name:
continue
new_name = name.replace(weight_name, ".self_attn.qkv_proj")
if new_name not in qkv_weights:
qkv_weights[new_name] = [None] * 3
qkv_weights[new_name][idx] = loaded_weight
break
else:
yield name, loaded_weight
for key, weight in qkv_weights.items():
qkv_weight = torch.cat(weight, dim=0)
yield key, qkv_weight
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@ -356,6 +422,9 @@ class Idefics2VisionTransformer(nn.Module):
loaded_params: set[str] = set()
layer_count = len(self.encoder.layers)
if self.use_data_parallel:
weights = self._consolidate_qkv_weights(weights)
for name, loaded_weight in weights:
# skip the pooling head
if name.startswith("head."):
@ -373,7 +442,7 @@ class Idefics2VisionTransformer(nn.Module):
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
if weight_name not in name or self.use_data_parallel:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
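
Because the DP path loads a fused `qkv_proj` from per-projection checkpoint weights, `_consolidate_qkv_weights` concatenates the `q_proj`/`k_proj`/`v_proj` tensors along dim 0 and passes everything else through untouched. The toy run below mirrors that logic with made-up weight names and shapes; it is a sketch of the mapping, not the loader itself.

```python
import torch

# Fabricated weight names and shapes, purely to illustrate the q/k/v fusion.
weights = [
    ("encoder.layers.0.self_attn.q_proj.weight", torch.randn(4, 8)),
    ("encoder.layers.0.self_attn.k_proj.weight", torch.randn(4, 8)),
    ("encoder.layers.0.self_attn.v_proj.weight", torch.randn(4, 8)),
    ("encoder.layers.0.mlp.fc1.weight", torch.randn(16, 8)),
]

qkv_idx = {".self_attn.q_proj": 0, ".self_attn.k_proj": 1, ".self_attn.v_proj": 2}
fused: dict[str, list] = {}
passthrough = []
for name, w in weights:
    for key, idx in qkv_idx.items():
        if key in name:
            new_name = name.replace(key, ".self_attn.qkv_proj")
            fused.setdefault(new_name, [None] * 3)[idx] = w
            break
    else:
        passthrough.append((name, w))        # non-qkv weights are yielded unchanged

for name, parts in fused.items():
    passthrough.append((name, torch.cat(parts, dim=0)))  # fused [q; k; v] along dim 0

for name, w in passthrough:
    print(name, tuple(w.shape))
# encoder.layers.0.mlp.fc1.weight (16, 8)
# encoder.layers.0.self_attn.qkv_proj.weight (12, 8)
```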


@ -778,6 +778,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
# and config class
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.version = get_version_by_config(self.config)
self.llm = self.init_llm(vllm_config=vllm_config,
@ -1325,9 +1326,11 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
prefix: str = "",
) -> nn.Module:
quant_config = self._maybe_ignore_quant_config(quant_config)
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config,
prefix=prefix)
model = Idefics2VisionTransformer(
config.vision_config,
quant_config=quant_config,
prefix=prefix,
use_data_parallel=self.use_data_parallel)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model


@ -461,6 +461,8 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
num_chunks_per_rank, ...]
vision_embeddings = vision_model(image_input_per_rank)
# Ensure tensor is contiguous before all_gather
vision_embeddings = vision_embeddings.contiguous()
vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
dim=0)
vision_embeddings = vision_embeddings[:num_chunks, ...]
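
The helper above pads the image batch to a multiple of the DP world size, runs each rank's chunk through the vision model, all-gathers the (now contiguous) embeddings, and trims back to the original chunk count. Below is a single-process simulation of that flow; `torch.cat` stands in for `tensor_model_parallel_all_gather`, and the function name and shapes are invented for the sketch.

```python
import math
import torch

# Single-process simulation of the batch-level DP flow: pad the batch to a
# multiple of the DP world size, run each rank's chunk, gather, then trim.
def dp_sharded_forward_sim(image_input: torch.Tensor, vision_model, world_size: int):
    num_chunks = image_input.shape[0]
    chunks_per_rank = math.ceil(num_chunks / world_size)
    pad = chunks_per_rank * world_size - num_chunks
    if pad:
        image_input = torch.cat(
            [image_input, image_input.new_zeros(pad, *image_input.shape[1:])], dim=0)

    per_rank_outputs = []
    for rank in range(world_size):  # each iteration plays the role of one DP rank
        shard = image_input[rank * chunks_per_rank:(rank + 1) * chunks_per_rank]
        per_rank_outputs.append(vision_model(shard).contiguous())

    gathered = torch.cat(per_rank_outputs, dim=0)  # all-gather stand-in
    return gathered[:num_chunks]

out = dp_sharded_forward_sim(torch.randn(5, 3, 32, 32), lambda x: x.flatten(1), 2)
print(out.shape)  # torch.Size([5, 3072])
```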