Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 11:56:00 +08:00)
[VLM] Migrate remaining DP-supported ViT models to use disable_tp (#24363)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
commit bb2b5126da
parent 361ae27f8a
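The hunks below apply one pattern across the remaining DP-supported vision encoders (the Idefics2 vision model, the Llama4 vision MLP and unfold convolution, and the Qwen2.5-VL patch merger): instead of swapping in ReplicatedLinear when the vision tower runs data-parallel, the tensor-parallel linear layers are always constructed and receive disable_tp=use_data_parallel, which makes them behave as if the TP size were 1 (no sharding, no collectives). A minimal sketch of the resulting module shape, assuming it is built inside an initialized vLLM worker; the class name, sizes, and the omitted quant_config are illustrative, not part of this diff:

# Sketch only: illustrates the disable_tp migration pattern, not vLLM source.
# Assumes construction inside a vLLM worker with distributed groups set up;
# hidden_size / intermediate_size / use_data_parallel are hypothetical args.
import torch
import torch.nn as nn

from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)


class VisionMLPSketch(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 use_data_parallel: bool = False,
                 prefix: str = ""):
        super().__init__()
        # Old pattern (removed by this commit):
        #   cls = ReplicatedLinear if use_data_parallel else ColumnParallelLinear
        # New pattern: always build the TP layer and let disable_tp make the
        # sharding and the gather/reduce in forward() a no-op.
        self.fc1 = ColumnParallelLinear(hidden_size,
                                        intermediate_size,
                                        bias=True,
                                        prefix=f"{prefix}.fc1",
                                        disable_tp=use_data_parallel)
        self.act = nn.GELU()
        self.fc2 = RowParallelLinear(intermediate_size,
                                     hidden_size,
                                     bias=True,
                                     prefix=f"{prefix}.fc2",
                                     disable_tp=use_data_parallel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # By default the vLLM linear layers return (output, output_bias).
        x, _ = self.fc1(x)
        x = self.act(x)
        x, _ = self.fc2(x)
        return x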
@@ -31,7 +31,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -139,37 +138,22 @@ class Idefics2VisionAttention(nn.Module):
         assert self.num_heads % tp_size == 0
         self.num_heads_per_partition = self.num_heads // tp_size

-        if use_data_parallel:
-            self.q_size = self.num_heads * self.head_dim
-            self.qkv_proj = ReplicatedLinear(
-                self.embed_dim,
-                3 * self.q_size,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv_proj",
-            )
-            self.out_proj = ReplicatedLinear(
-                self.embed_dim,
-                self.embed_dim,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.out_proj",
-            )
-        else:
-            self.qkv_proj = QKVParallelLinear(
-                self.embed_dim,
-                self.head_dim,
-                self.num_heads,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv_proj",
-            )
-            self.out_proj = RowParallelLinear(
-                self.embed_dim,
-                self.embed_dim,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.out_proj",
-            )
+        self.qkv_proj = QKVParallelLinear(
+            self.embed_dim,
+            self.head_dim,
+            self.num_heads,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.out_proj = RowParallelLinear(
+            self.embed_dim,
+            self.embed_dim,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+            disable_tp=use_data_parallel,
+        )
         # Use unified MultiHeadAttention with Flash Attention support
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
@@ -201,23 +185,21 @@ class Idefics2VisionMLP(nn.Module):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(
+        self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
+            disable_tp=use_data_parallel,
         )
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.fc2 = cls_fc2(
+        self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc2",
+            disable_tp=use_data_parallel,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -389,30 +371,6 @@ class Idefics2VisionTransformer(nn.Module):
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state

-    def _consolidate_qkv_weights(
-        self, weights: Iterable[tuple[str, torch.Tensor]]
-    ) -> Iterable[tuple[str, torch.Tensor]]:
-        qkv_idx_mappings = {
-            ".self_attn.q_proj": 0,
-            ".self_attn.k_proj": 1,
-            ".self_attn.v_proj": 2,
-        }
-        qkv_weights = {}
-        for name, loaded_weight in weights:
-            for weight_name, idx in qkv_idx_mappings.items():
-                if weight_name not in name:
-                    continue
-                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
-                if new_name not in qkv_weights:
-                    qkv_weights[new_name] = [None] * 3
-                qkv_weights[new_name][idx] = loaded_weight
-                break
-            else:
-                yield name, loaded_weight
-        for key, weight in qkv_weights.items():
-            qkv_weight = torch.cat(weight, dim=0)
-            yield key, qkv_weight
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -425,9 +383,6 @@ class Idefics2VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         layer_count = len(self.encoder.layers)

-        if self.use_data_parallel:
-            weights = self._consolidate_qkv_weights(weights)
-
         for name, loaded_weight in weights:
             # skip pooling header
             if name.startswith("head."):
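Note on the two removals above: with disable_tp=use_data_parallel the attention keeps a real QKVParallelLinear (merely unsharded), so the separate q_proj/k_proj/v_proj checkpoint weights load through the same stacked-parameter mapping the TP path already uses, and the DP-only pre-concatenation helper becomes dead code. A hedged sketch of that standard vLLM loading idiom; the function name and surrounding plumbing are illustrative:

# Sketch of the stacked-parameter loading path that replaces the manual
# q/k/v concatenation (based on the usual vLLM load_weights idiom).
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def load_weights_sketch(model, weights):
    params_dict = dict(model.named_parameters())
    loaded: set[str] = set()
    for name, loaded_weight in weights:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            name = name.replace(shard_name, param_name)
            param = params_dict[name]
            # QKVParallelLinear's weight_loader places the q/k/v shard at the
            # right offset of the fused parameter, with or without TP.
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded.add(name)
    return loaded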
@@ -106,22 +106,21 @@ class Llama4VisionMLP(nn.Module):
         use_data_parallel: bool = False,
     ):
         super().__init__()
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(
+        self.fc1 = ColumnParallelLinear(
             input_size=input_size,
             output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
+            disable_tp=use_data_parallel,
         )
-        cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
-        self.fc2 = cls_fc2(
+        self.fc2 = RowParallelLinear(
             input_size=intermediate_size,
             output_size=output_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.fc2",
+            disable_tp=use_data_parallel,
         )
         self.activation_fn = nn.GELU()
         self.output_activation = output_activation
@@ -419,20 +418,15 @@ class Llama4UnfoldConvolution(nn.Module):
             kernel_size = (kernel_size, kernel_size)
         self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
                                       stride=config.patch_size)
-        params = {
-            "input_size":
-            config.num_channels * kernel_size[0] * kernel_size[1],
-            "output_size": config.hidden_size,
-            "bias": False,
-            "quant_config": quant_config,
-            "prefix": f"{prefix}.linear",
-        }
-        if use_data_parallel:
-            cls = ReplicatedLinear
-        else:
-            cls = ColumnParallelLinear
-            params["gather_output"] = True
-        self.linear = cls(**params)
+        self.linear = ColumnParallelLinear(
+            input_size=config.num_channels * kernel_size[0] * kernel_size[1],
+            output_size=config.hidden_size,
+            bias=False,
+            gather_output=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear",
+            disable_tp=use_data_parallel,
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
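Note: gather_output=True can now be passed unconditionally because ReplicatedLinear (which has no such argument) is out of the picture, and when disable_tp is set the layer is unsharded, so the output gather has nothing to collect and is effectively a no-op.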
@@ -49,7 +49,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 # yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -510,32 +509,32 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         norm_layer = partial(nn.LayerNorm, eps=1e-6)
         self.ln_q = norm_layer(context_dim)

-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.mlp = nn.ModuleList([
-            cls_fc1(self.hidden_size,
-                    self.hidden_size,
-                    bias=True,
-                    quant_config=quant_config,
-                    prefix=f"{prefix}.mlp.0"),
+        self.mlp = nn.Sequential(
+            ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp.0",
+                return_bias=False,
+                disable_tp=use_data_parallel,
+            ),
             nn.GELU(),
-            cls_fc2(self.hidden_size,
-                    d_model,
-                    bias=True,
-                    quant_config=quant_config,
-                    prefix=f"{prefix}.mlp.2"),
-        ])
+            RowParallelLinear(
+                self.hidden_size,
+                d_model,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp.2",
+                return_bias=False,
+                disable_tp=use_data_parallel,
+            ),
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.ln_q(x)
         x = x.view(-1, self.hidden_size)

-        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
-        x_parallel, _ = mlp_fc1(x)
-        x_parallel = mlp_act(x_parallel)
-        out, _ = mlp_fc2(x_parallel)
+        out = self.mlp(x)
         return out
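The patch-merger rewrite above leans on return_bias=False: vLLM's parallel linear layers normally return an (output, output_bias) tuple, which is why the old forward unpacked each stage by hand; returning only the tensor lets the three stages compose as a plain nn.Sequential. A minimal sketch of such a builder, assuming construction inside an initialized vLLM worker; the function name and sizes are illustrative:

# Sketch (not vLLM source): return_bias=False makes each layer's forward
# return a single tensor, so nn.Sequential can chain the stages directly.
import torch.nn as nn

from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)


def build_merger_mlp(hidden_size: int,
                     out_dim: int,
                     use_data_parallel: bool = False) -> nn.Sequential:
    return nn.Sequential(
        ColumnParallelLinear(hidden_size,
                             hidden_size,
                             bias=True,
                             return_bias=False,
                             disable_tp=use_data_parallel),
        nn.GELU(),
        RowParallelLinear(hidden_size,
                          out_dim,
                          bias=True,
                          return_bias=False,
                          disable_tp=use_data_parallel),
    )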