[Model] Support DP for ViT on MiniCPM-V-4 (#23327)
Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
This commit is contained in:
parent add1adfec7
commit 23c939fd30
@@ -172,6 +172,7 @@ The availability of batch-level DP is based on model implementation.
 Currently, the following models support `mm_encoder_tp_mode="data"`:
 
 - Llama4 (<gh-pr:18368>)
+- MiniCPM-V-4 (<gh-pr:23327>)
 - Qwen2.5-VL (<gh-pr:22742>)
 - Step3 (<gh-pr:22697>)
 
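A minimal usage sketch for enabling this mode. It assumes the `mm_encoder_tp_mode` engine argument documented above is forwarded through `vllm.LLM`'s keyword arguments, and uses a placeholder checkpoint name:

from vllm import LLM

# Placeholder model id; any checkpoint served by MiniCPMV4_0 applies.
llm = LLM(
    model="openbmb/MiniCPM-V-4",
    trust_remote_code=True,
    tensor_parallel_size=2,
    # "data": replicate the ViT weights on every TP rank and shard the
    # image batch across ranks, instead of sharding the ViT weights.
    mm_encoder_tp_mode="data",
)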
@@ -27,13 +27,15 @@ from transformers.models.idefics2.configuration_idefics2 import (
     Idefics2Config, Idefics2VisionConfig)
 
 from vllm.attention.layer import MultiHeadAttention
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.multimodal.utils import run_dp_sharded_vision_model
 
 
 class Idefics2VisionEmbeddings(nn.Module):
@@ -118,6 +120,7 @@ class Idefics2VisionAttention(nn.Module):
         config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -130,22 +133,43 @@ class Idefics2VisionAttention(nn.Module):
                 f" {self.num_heads}).")
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
-        self.qkv_proj = QKVParallelLinear(
-            self.embed_dim,
-            self.head_dim,
-            self.num_heads,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
-        self.out_proj = RowParallelLinear(
-            self.embed_dim,
-            self.embed_dim,
-            bias=True,
-            quant_config=quant_config,
-            prefix=f"{prefix}.out_proj",
-        )
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+
+        tp_size = (1 if use_data_parallel else
+                   get_tensor_model_parallel_world_size())
+        assert self.num_heads % tp_size == 0
+        self.num_heads_per_partition = self.num_heads // tp_size
+
+        if use_data_parallel:
+            self.q_size = self.num_heads * self.head_dim
+            self.qkv_proj = ReplicatedLinear(
+                self.embed_dim,
+                3 * self.q_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+            self.out_proj = ReplicatedLinear(
+                self.embed_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.out_proj",
+            )
+        else:
+            self.qkv_proj = QKVParallelLinear(
+                self.embed_dim,
+                self.head_dim,
+                self.num_heads,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+            self.out_proj = RowParallelLinear(
+                self.embed_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.out_proj",
+            )
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
 
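The shape logic behind this change: with TP the fused QKV projection is column-sharded, so each rank computes only num_heads / tp_size heads, while with batch-level DP every rank keeps the full projection and parallelism comes from splitting the image batch instead. A toy shape check in plain PyTorch (illustrative names only, not vLLM's layer classes):

import torch

embed_dim, num_heads, head_dim = 1152, 16, 72  # embed_dim == num_heads * head_dim
tp_size = 4

# Tensor-parallel: each rank holds a column shard of the fused QKV weight
# and produces only its local heads.
qkv_tp = torch.nn.Linear(embed_dim, 3 * embed_dim // tp_size)
x = torch.randn(2, 10, embed_dim)  # (batch, seq, hidden)
q, k, v = qkv_tp(x).chunk(3, dim=-1)
assert q.shape[-1] == (num_heads // tp_size) * head_dim

# Data-parallel: the full weight is replicated on every rank; each rank
# would instead see only a slice of the image batch.
qkv_dp = torch.nn.Linear(embed_dim, 3 * embed_dim)
q, k, v = qkv_dp(x).chunk(3, dim=-1)
assert q.shape[-1] == num_heads * head_dim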
@@ -169,18 +193,23 @@ class Idefics2VisionMLP(nn.Module):
         config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        self.fc1 = ColumnParallelLinear(
+        cls_fc1 = (ReplicatedLinear
+                   if use_data_parallel else ColumnParallelLinear)
+        self.fc1 = cls_fc1(
             config.hidden_size,
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
         )
-        self.fc2 = RowParallelLinear(
+        cls_fc2 = (ReplicatedLinear
+                   if use_data_parallel else RowParallelLinear)
+        self.fc2 = cls_fc2(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
@@ -202,17 +231,21 @@ class Idefics2EncoderLayer(nn.Module):
         config: Idefics2Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = Idefics2VisionAttention(config,
-                                                 quant_config=quant_config,
-                                                 prefix=f"{prefix}.self_attn")
+        self.self_attn = Idefics2VisionAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_data_parallel=use_data_parallel)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = Idefics2VisionMLP(config,
                                      quant_config=quant_config,
-                                     prefix=f"{prefix}.mlp")
+                                     prefix=f"{prefix}.mlp",
+                                     use_data_parallel=use_data_parallel)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
 
@@ -254,6 +287,7 @@ class Idefics2Encoder(nn.Module):
         *,
         num_hidden_layers_override: Optional[int] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -267,7 +301,8 @@
         self.layers = nn.ModuleList([
             Idefics2EncoderLayer(config,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.layers.{layer_idx}")
+                                 prefix=f"{prefix}.layers.{layer_idx}",
+                                 use_data_parallel=use_data_parallel)
             for layer_idx in range(num_hidden_layers)
         ])
 
@@ -301,17 +336,20 @@ class Idefics2VisionTransformer(nn.Module):
         num_hidden_layers_override: Optional[int] = None,
         require_post_norm: bool = True,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
         embed_dim = config.hidden_size
         self.config = config
+        self.use_data_parallel = use_data_parallel
         self.embeddings = Idefics2VisionEmbeddings(config)
         self.encoder = Idefics2Encoder(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
-            prefix=f"{prefix}.encoder")
+            prefix=f"{prefix}.encoder",
+            use_data_parallel=use_data_parallel)
 
         num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
@@ -340,10 +378,38 @@ class Idefics2VisionTransformer(nn.Module):
             patch_attention_mask=patch_attention_mask,
             tgt_sizes=tgt_sizes,
         )
-        encoder_outputs = self.encoder(hidden_states)
+        if self.use_data_parallel:
+            encoder_outputs = run_dp_sharded_vision_model(
+                hidden_states, self.encoder)
+        else:
+            encoder_outputs = self.encoder(hidden_states)
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
 
+    def _consolidate_qkv_weights(
+        self, weights: Iterable[tuple[str, torch.Tensor]]
+    ) -> Iterable[tuple[str, torch.Tensor]]:
+        qkv_idx_mappings = {
+            ".self_attn.q_proj": 0,
+            ".self_attn.k_proj": 1,
+            ".self_attn.v_proj": 2,
+        }
+        qkv_weights = {}
+        for name, loaded_weight in weights:
+            for weight_name, idx in qkv_idx_mappings.items():
+                if weight_name not in name:
+                    continue
+                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
+                if new_name not in qkv_weights:
+                    qkv_weights[new_name] = [None] * 3
+                qkv_weights[new_name][idx] = loaded_weight
+                break
+            else:
+                yield name, loaded_weight
+        for key, weight in qkv_weights.items():
+            qkv_weight = torch.cat(weight, dim=0)
+            yield key, qkv_weight
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
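`_consolidate_qkv_weights` exists because under DP the attention projections are `ReplicatedLinear`, which expects one fused qkv_proj weight, while checkpoints store q_proj/k_proj/v_proj separately; under TP, `QKVParallelLinear`'s stacked_params_mapping handles the fusion instead, hence the `use_data_parallel` guards in the following hunks. A standalone sketch of the same for/else generator pattern on dummy tensors:

import torch

# Dummy per-projection weights, as a checkpoint iterator would yield them.
weights = [
    ("encoder.layers.0.self_attn.q_proj.weight", torch.zeros(8, 8)),
    ("encoder.layers.0.self_attn.k_proj.weight", torch.ones(8, 8)),
    ("encoder.layers.0.self_attn.v_proj.weight", torch.full((8, 8), 2.0)),
    ("encoder.layers.0.mlp.fc1.weight", torch.zeros(32, 8)),
]

def consolidate(weights):
    mappings = {".q_proj": 0, ".k_proj": 1, ".v_proj": 2}
    pending = {}
    for name, w in weights:
        for key, idx in mappings.items():
            if key in name:
                new_name = name.replace(key, ".qkv_proj")
                pending.setdefault(new_name, [None] * 3)[idx] = w
                break
        else:  # no q/k/v substring matched: pass the weight through unchanged
            yield name, w
    for name, parts in pending.items():
        yield name, torch.cat(parts, dim=0)  # fuse along the output dim

for name, w in consolidate(weights):
    print(name, tuple(w.shape))
# The fused ...self_attn.qkv_proj.weight comes out with shape (24, 8).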
@@ -356,6 +422,9 @@ class Idefics2VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         layer_count = len(self.encoder.layers)
 
+        if self.use_data_parallel:
+            weights = self._consolidate_qkv_weights(weights)
+
         for name, loaded_weight in weights:
             # skip pooling header
             if name.startswith("head."):
@@ -373,7 +442,7 @@
                 continue
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
+                if weight_name not in name or self.use_data_parallel:
                     continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
@@ -778,6 +778,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         # and config class
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
         self.version = get_version_by_config(self.config)
         self.llm = self.init_llm(vllm_config=vllm_config,
@@ -1325,9 +1326,11 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
         prefix: str = "",
     ) -> nn.Module:
         quant_config = self._maybe_ignore_quant_config(quant_config)
-        model = Idefics2VisionTransformer(config.vision_config,
-                                          quant_config=quant_config,
-                                          prefix=prefix)
+        model = Idefics2VisionTransformer(
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=prefix,
+            use_data_parallel=self.use_data_parallel)
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
         return model
@@ -461,6 +461,8 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                               num_chunks_per_rank, ...]
 
     vision_embeddings = vision_model(image_input_per_rank)
+    # Ensure tensor is contiguous before all_gather
+    vision_embeddings = vision_embeddings.contiguous()
     vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
                                                          dim=0)
     vision_embeddings = vision_embeddings[:num_chunks, ...]
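The added `.contiguous()` matters because the per-rank output can be a non-contiguous slice, and collective all-gather kernels generally expect contiguous buffers. For intuition, the surrounding function pads the chunk count to a multiple of the world size, runs each rank on its slice, all-gathers, then trims the padding. A single-process sketch of that arithmetic (illustrative names, no distributed calls):

import torch

def dp_shard_round_trip(image_input: torch.Tensor, world_size: int):
    num_chunks = image_input.shape[0]
    # Pad so every rank gets an equal number of chunks.
    per_rank = -(-num_chunks // world_size)  # ceil division
    pad = per_rank * world_size - num_chunks
    padded = torch.cat(
        [image_input, image_input.new_zeros((pad, *image_input.shape[1:]))])
    # Each rank would run the vision model on its own contiguous slice...
    outs = [padded[r * per_rank:(r + 1) * per_rank].contiguous()
            for r in range(world_size)]
    # ...then an all_gather reassembles them and the padding is trimmed.
    gathered = torch.cat(outs, dim=0)
    return gathered[:num_chunks]

x = torch.randn(10, 3, 14, 14)  # 10 image chunks
assert dp_shard_round_trip(x, world_size=4).shape[0] == 10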