Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-18 01:25:01 +08:00)
[VLM] Migrate remaining DP-supported ViT models to use disable_tp (#24363)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Parent: 361ae27f8a
Commit: bb2b5126da
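Every hunk below applies the same pattern: instead of swapping in ReplicatedLinear when the vision tower runs data-parallel, the models keep their usual parallel linear layers and pass disable_tp=use_data_parallel. A toy, torch-only sketch of what that flag amounts to (illustration only; the real classes live in vllm.model_executor.layers.linear and are not reimplemented here):

import torch
import torch.nn as nn

# Toy stand-in for the idea behind disable_tp: one class that either shards its
# output dimension across tensor-parallel ranks or, when disable_tp=True, keeps
# the full (replicated) weight on every rank. Not vLLM's implementation.
class ToyColumnParallelLinear(nn.Module):
    def __init__(self, input_size: int, output_size: int, tp_size: int,
                 disable_tp: bool = False):
        super().__init__()
        shard = output_size if disable_tp else output_size // tp_size
        self.weight = nn.Parameter(torch.randn(shard, input_size) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight.t()

# Old style: layer_cls = ReplicatedLinear if use_data_parallel else ColumnParallelLinear
# New style (this commit): ColumnParallelLinear(..., disable_tp=use_data_parallel)
layer = ToyColumnParallelLinear(16, 32, tp_size=4, disable_tp=True)
print(layer(torch.randn(2, 16)).shape)  # full 32 output features on every rank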
@@ -31,7 +31,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -139,29 +138,13 @@ class Idefics2VisionAttention(nn.Module):
         assert self.num_heads % tp_size == 0
         self.num_heads_per_partition = self.num_heads // tp_size

-        if use_data_parallel:
-            self.q_size = self.num_heads * self.head_dim
-            self.qkv_proj = ReplicatedLinear(
-                self.embed_dim,
-                3 * self.q_size,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv_proj",
-            )
-            self.out_proj = ReplicatedLinear(
-                self.embed_dim,
-                self.embed_dim,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.out_proj",
-            )
-        else:
         self.qkv_proj = QKVParallelLinear(
             self.embed_dim,
             self.head_dim,
             self.num_heads,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
+            disable_tp=use_data_parallel,
         )
         self.out_proj = RowParallelLinear(
             self.embed_dim,
@@ -169,6 +152,7 @@ class Idefics2VisionAttention(nn.Module):
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.out_proj",
+            disable_tp=use_data_parallel,
         )
         # Use unified MultiHeadAttention with Flash Attention support
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
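After this change the attention always builds a fused QKV projection and a row-parallel output projection; with disable_tp a DP rank simply behaves like a single TP rank, so the head partition is all heads. A toy, torch-only sketch of that fused projection (hypothetical sizes, plain nn.Linear instead of vLLM's QKVParallelLinear):

import torch
import torch.nn as nn

# Toy fused QKV projection with per-partition heads (illustration only).
embed_dim, num_heads, head_dim = 64, 8, 8
tp_size = 1  # disable_tp=True makes a DP rank behave like tp_size == 1
num_heads_per_partition = num_heads // tp_size

qkv_proj = nn.Linear(embed_dim, 3 * num_heads_per_partition * head_dim, bias=True)
out_proj = nn.Linear(num_heads_per_partition * head_dim, embed_dim, bias=True)

x = torch.randn(2, 10, embed_dim)
q, k, v = qkv_proj(x).chunk(3, dim=-1)   # each (2, 10, heads_per_partition * head_dim)
attn_out = v                             # attention math elided; shapes are the point
print(out_proj(attn_out).shape)          # torch.Size([2, 10, 64])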
@@ -201,23 +185,21 @@ class Idefics2VisionMLP(nn.Module):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(
+        self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
+            disable_tp=use_data_parallel,
         )
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.fc2 = cls_fc2(
+        self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc2",
+            disable_tp=use_data_parallel,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
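The MLP call sites stay untouched because vLLM's parallel linear layers return an (output, bias) tuple by default, whether or not TP is disabled. A toy layer with the same calling convention (illustration only, not vLLM code):

import torch
import torch.nn as nn

# Toy layer mimicking the (output, bias) return convention that the unchanged
# MLP forward unpacks.
class TupleReturningLinear(nn.Linear):
    def forward(self, x: torch.Tensor):
        return super().forward(x), None  # bias already applied; placeholder second value

fc1 = TupleReturningLinear(64, 256)
act = nn.GELU()
fc2 = TupleReturningLinear(256, 64)

hidden_states = torch.randn(2, 64)
hidden_states, _ = fc1(hidden_states)
hidden_states = act(hidden_states)
hidden_states, _ = fc2(hidden_states)
print(hidden_states.shape)  # torch.Size([2, 64])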
@@ -389,30 +371,6 @@ class Idefics2VisionTransformer(nn.Module):
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state

-    def _consolidate_qkv_weights(
-        self, weights: Iterable[tuple[str, torch.Tensor]]
-    ) -> Iterable[tuple[str, torch.Tensor]]:
-        qkv_idx_mappings = {
-            ".self_attn.q_proj": 0,
-            ".self_attn.k_proj": 1,
-            ".self_attn.v_proj": 2,
-        }
-        qkv_weights = {}
-        for name, loaded_weight in weights:
-            for weight_name, idx in qkv_idx_mappings.items():
-                if weight_name not in name:
-                    continue
-                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
-                if new_name not in qkv_weights:
-                    qkv_weights[new_name] = [None] * 3
-                qkv_weights[new_name][idx] = loaded_weight
-                break
-            else:
-                yield name, loaded_weight
-        for key, weight in qkv_weights.items():
-            qkv_weight = torch.cat(weight, dim=0)
-            yield key, qkv_weight
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -425,9 +383,6 @@ class Idefics2VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         layer_count = len(self.encoder.layers)

-        if self.use_data_parallel:
-            weights = self._consolidate_qkv_weights(weights)
-
         for name, loaded_weight in weights:
             # skip pooling header
             if name.startswith("head."):
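With disable_tp the vision attention always exposes a fused qkv_proj whose weight loader accepts the separate q/k/v checkpoint tensors through stacked_params_mapping, so the DP-only _consolidate_qkv_weights pre-concatenation can be deleted. A toy, torch-only sketch of that stacked-loading idea (not vLLM's actual weight_loader; checkpoint names below are hypothetical):

import torch

# Separate q/k/v checkpoint tensors are copied into slices of one fused qkv weight.
stacked_params_mapping = [
    # (fused param name, checkpoint name, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

embed_dim = 8
qkv_weight = torch.zeros(3 * embed_dim, embed_dim)
row_offset = {"q": 0, "k": embed_dim, "v": 2 * embed_dim}

checkpoint = {  # hypothetical checkpoint entries
    "encoder.layers.0.self_attn.q_proj.weight": torch.randn(embed_dim, embed_dim),
    "encoder.layers.0.self_attn.k_proj.weight": torch.randn(embed_dim, embed_dim),
    "encoder.layers.0.self_attn.v_proj.weight": torch.randn(embed_dim, embed_dim),
}

for name, loaded_weight in checkpoint.items():
    for fused_name, ckpt_name, shard_id in stacked_params_mapping:
        if ckpt_name in name:
            start = row_offset[shard_id]
            qkv_weight[start:start + embed_dim].copy_(loaded_weight)
            break

print(qkv_weight.shape)  # one fused (24, 8) weight built from three shards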
@@ -106,22 +106,21 @@ class Llama4VisionMLP(nn.Module):
         use_data_parallel: bool = False,
     ):
         super().__init__()
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(
+        self.fc1 = ColumnParallelLinear(
             input_size=input_size,
             output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
+            disable_tp=use_data_parallel,
         )
-        cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
-        self.fc2 = cls_fc2(
+        self.fc2 = RowParallelLinear(
             input_size=intermediate_size,
             output_size=output_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.fc2",
+            disable_tp=use_data_parallel,
         )
         self.activation_fn = nn.GELU()
         self.output_activation = output_activation
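A torch-only sketch of the MLP shape this hunk configures, including a plausible use of the output_activation flag it stores (plain nn.Linear stand-ins; the real forward is not part of this diff):

import torch
import torch.nn as nn

# Toy model of the fc1 -> GELU -> fc2 structure with an optional output activation.
class ToyLlama4VisionMLP(nn.Module):
    def __init__(self, input_size: int, intermediate_size: int, output_size: int,
                 bias: bool = True, output_activation: bool = False):
        super().__init__()
        self.fc1 = nn.Linear(input_size, intermediate_size, bias=bias)
        self.fc2 = nn.Linear(intermediate_size, output_size, bias=bias)
        self.activation_fn = nn.GELU()
        self.output_activation = output_activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc2(self.activation_fn(self.fc1(x)))
        return self.activation_fn(x) if self.output_activation else x

mlp = ToyLlama4VisionMLP(32, 128, 32, output_activation=True)
print(mlp(torch.randn(2, 32)).shape)  # torch.Size([2, 32])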
@@ -419,20 +418,15 @@ class Llama4UnfoldConvolution(nn.Module):
             kernel_size = (kernel_size, kernel_size)
         self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
                                       stride=config.patch_size)
-        params = {
-            "input_size":
-            config.num_channels * kernel_size[0] * kernel_size[1],
-            "output_size": config.hidden_size,
-            "bias": False,
-            "quant_config": quant_config,
-            "prefix": f"{prefix}.linear",
-        }
-        if use_data_parallel:
-            cls = ReplicatedLinear
-        else:
-            cls = ColumnParallelLinear
-            params["gather_output"] = True
-        self.linear = cls(**params)
+        self.linear = ColumnParallelLinear(
+            input_size=config.num_channels * kernel_size[0] * kernel_size[1],
+            output_size=config.hidden_size,
+            bias=False,
+            gather_output=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear",
+            disable_tp=use_data_parallel,
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
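The surrounding class turns an image into patch embeddings with Unfold followed by a single projection; the hunk only changes how that projection is constructed. A torch-only sketch of the unfold-then-project idea (hypothetical sizes; plain nn.Linear in place of vLLM's ColumnParallelLinear):

import torch
import torch.nn as nn

# Each non-overlapping patch becomes a column, then a linear layer projects it.
num_channels, patch_size, hidden_size = 3, 14, 32
unfold = nn.Unfold(kernel_size=(patch_size, patch_size), stride=patch_size)
linear = nn.Linear(num_channels * patch_size * patch_size, hidden_size, bias=False)

images = torch.randn(2, num_channels, 336, 336)
patches = unfold(images)            # (2, C * P * P, num_patches)
patches = patches.permute(0, 2, 1)  # (2, num_patches, C * P * P)
embeddings = linear(patches)        # (2, num_patches, hidden_size)
print(embeddings.shape)             # torch.Size([2, 576, 32])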
@@ -49,7 +49,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 # yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -510,32 +509,32 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         norm_layer = partial(nn.LayerNorm, eps=1e-6)
         self.ln_q = norm_layer(context_dim)

-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.mlp = nn.ModuleList([
-            cls_fc1(self.hidden_size,
+        self.mlp = nn.Sequential(
+            ColumnParallelLinear(
+                self.hidden_size,
                 self.hidden_size,
                 bias=True,
                 quant_config=quant_config,
-                    prefix=f"{prefix}.mlp.0"),
+                prefix=f"{prefix}.mlp.0",
+                return_bias=False,
+                disable_tp=use_data_parallel,
+            ),
             nn.GELU(),
-            cls_fc2(self.hidden_size,
+            RowParallelLinear(
+                self.hidden_size,
                 d_model,
                 bias=True,
                 quant_config=quant_config,
-                    prefix=f"{prefix}.mlp.2"),
-        ])
+                prefix=f"{prefix}.mlp.2",
+                return_bias=False,
+                disable_tp=use_data_parallel,
+            ),
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.ln_q(x)
         x = x.view(-1, self.hidden_size)
-        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
-        x_parallel, _ = mlp_fc1(x)
-        x_parallel = mlp_act(x_parallel)
-        out, _ = mlp_fc2(x_parallel)
+        out = self.mlp(x)
         return out
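Switching the patch merger from nn.ModuleList to nn.Sequential is what requires return_bias=False: every module in a Sequential must return a single tensor, while vLLM's parallel linears return an (output, bias) tuple by default. A torch-only sketch of the resulting merger (toy sizes, plain nn.Linear stand-ins):

import torch
import torch.nn as nn

# Toy patch merger mirroring the new nn.Sequential structure (illustration only).
spatial_merge_size, context_dim, d_model = 2, 16, 32
hidden_size = context_dim * (spatial_merge_size ** 2)  # 64

mlp = nn.Sequential(
    nn.Linear(hidden_size, hidden_size, bias=True),  # stands in for ColumnParallelLinear(..., return_bias=False, disable_tp=...)
    nn.GELU(),
    nn.Linear(hidden_size, d_model, bias=True),      # stands in for RowParallelLinear(..., return_bias=False, disable_tp=...)
)

x = torch.randn(100, context_dim)   # 100 patch features
x = x.view(-1, hidden_size)         # merge spatial_merge_size**2 neighbors -> (25, 64)
out = mlp(x)
print(out.shape)                    # torch.Size([25, 32])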