[VLM] Migrate remain DP-supported ViT models to use disable_tp (#24363)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-09-12 02:30:41 +08:00 committed by GitHub
parent 361ae27f8a
commit bb2b5126da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 54 additions and 106 deletions

View File

@ -31,7 +31,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -139,37 +138,22 @@ class Idefics2VisionAttention(nn.Module):
assert self.num_heads % tp_size == 0
self.num_heads_per_partition = self.num_heads // tp_size
if use_data_parallel:
self.q_size = self.num_heads * self.head_dim
self.qkv_proj = ReplicatedLinear(
self.embed_dim,
3 * self.q_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = ReplicatedLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
else:
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
)
# Use unified MultiHeadAttention with Flash Attention support
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
@ -201,23 +185,21 @@ class Idefics2VisionMLP(nn.Module):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
self.fc1 = cls_fc1(
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
)
cls_fc2 = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.fc2 = cls_fc2(
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -389,30 +371,6 @@ class Idefics2VisionTransformer(nn.Module):
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
def _consolidate_qkv_weights(
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
qkv_idx_mappings = {
".self_attn.q_proj": 0,
".self_attn.k_proj": 1,
".self_attn.v_proj": 2,
}
qkv_weights = {}
for name, loaded_weight in weights:
for weight_name, idx in qkv_idx_mappings.items():
if weight_name not in name:
continue
new_name = name.replace(weight_name, ".self_attn.qkv_proj")
if new_name not in qkv_weights:
qkv_weights[new_name] = [None] * 3
qkv_weights[new_name][idx] = loaded_weight
break
else:
yield name, loaded_weight
for key, weight in qkv_weights.items():
qkv_weight = torch.cat(weight, dim=0)
yield key, qkv_weight
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@ -425,9 +383,6 @@ class Idefics2VisionTransformer(nn.Module):
loaded_params: set[str] = set()
layer_count = len(self.encoder.layers)
if self.use_data_parallel:
weights = self._consolidate_qkv_weights(weights)
for name, loaded_weight in weights:
# skip pooling header
if name.startswith("head."):

View File

@ -106,22 +106,21 @@ class Llama4VisionMLP(nn.Module):
use_data_parallel: bool = False,
):
super().__init__()
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
self.fc1 = cls_fc1(
self.fc1 = ColumnParallelLinear(
input_size=input_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
)
cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
self.fc2 = cls_fc2(
self.fc2 = RowParallelLinear(
input_size=intermediate_size,
output_size=output_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
)
self.activation_fn = nn.GELU()
self.output_activation = output_activation
@ -419,20 +418,15 @@ class Llama4UnfoldConvolution(nn.Module):
kernel_size = (kernel_size, kernel_size)
self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
stride=config.patch_size)
params = {
"input_size":
config.num_channels * kernel_size[0] * kernel_size[1],
"output_size": config.hidden_size,
"bias": False,
"quant_config": quant_config,
"prefix": f"{prefix}.linear",
}
if use_data_parallel:
cls = ReplicatedLinear
else:
cls = ColumnParallelLinear
params["gather_output"] = True
self.linear = cls(**params)
self.linear = ColumnParallelLinear(
input_size=config.num_channels * kernel_size[0] * kernel_size[1],
output_size=config.hidden_size,
bias=False,
gather_output=True,
quant_config=quant_config,
prefix=f"{prefix}.linear",
disable_tp=use_data_parallel,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.unfold(hidden_states)

View File

@ -49,7 +49,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -510,32 +509,32 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.ln_q = norm_layer(context_dim)
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
cls_fc2 = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.mlp = nn.ModuleList([
cls_fc1(self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.0"),
self.mlp = nn.Sequential(
ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.0",
return_bias=False,
disable_tp=use_data_parallel,
),
nn.GELU(),
cls_fc2(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.2"),
])
RowParallelLinear(
self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.2",
return_bias=False,
disable_tp=use_data_parallel,
),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.ln_q(x)
x = x.view(-1, self.hidden_size)
mlp_fc1, mlp_act, mlp_fc2 = self.mlp
x_parallel, _ = mlp_fc1(x)
x_parallel = mlp_act(x_parallel)
out, _ = mlp_fc2(x_parallel)
out = self.mlp(x)
return out