[Model] Cleanup InternViT's data parallel implementation (#25306)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent bf8b26cad1
commit 3c713a9711
@@ -25,7 +25,6 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -164,15 +163,6 @@ class InternParallelAttention(nn.Module):
                                               self.tp_size)
 
         self.scale = self.head_dim**-0.5
-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(
-                self.embed_dim,
-                3 * self.head_dim * self.num_heads,
-                bias=config.qkv_bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv",
-            )
-        else:
-            self.qkv = QKVParallelLinear(
-                self.embed_dim,
-                self.head_dim,
+        self.qkv = QKVParallelLinear(
+            self.embed_dim,
+            self.head_dim,
@@ -180,6 +170,7 @@ class InternParallelAttention(nn.Module):
-                bias=config.qkv_bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv",
-            )
+            bias=config.qkv_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel,
+        )
 
         self.qk_normalization = config.qk_normalization
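
The two hunks above replace the hand-written "if use_data_parallel: ReplicatedLinear(...) else: QKVParallelLinear(...)" branch with a single QKVParallelLinear that receives disable_tp=use_data_parallel. Below is a minimal, self-contained sketch of that idea; ToyColumnParallelLinear and the sizes are illustrative stand-ins, not vLLM's actual layers.

# Toy sketch only: vLLM's parallel linear layers also handle quantization,
# weight loading and rank-local sharding; this reduces the idea to
# "disable_tp collapses a TP-sharded linear into a replicated one".
import torch
import torch.nn as nn


class ToyColumnParallelLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int, *,
                 tp_size: int = 1, disable_tp: bool = False) -> None:
        super().__init__()
        # With disable_tp=True every rank keeps the full output width,
        # i.e. the layer degenerates into a plain replicated nn.Linear.
        shards = 1 if disable_tp else tp_size
        assert out_features % shards == 0
        self.linear = nn.Linear(in_features, out_features // shards)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# Old call-site pattern: pick ReplicatedLinear or a parallel class by hand.
# New call-site pattern: always build the parallel class, forward the flag.
use_data_parallel = True
qkv = ToyColumnParallelLinear(3200, 3 * 3200, tp_size=8,
                              disable_tp=use_data_parallel)
print(qkv(torch.randn(2, 3200)).shape)  # torch.Size([2, 9600]), unsharded
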
@@ -192,19 +183,12 @@ class InternParallelAttention(nn.Module):
                                   eps=config.layer_norm_eps,
                                   var_hidden_size=self.embed_dim)
 
-        if use_data_parallel:
-            self.proj = ReplicatedLinear(
-                self.dummy_dim,
-                self.embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
-        else:
-            self.proj = RowParallelLinear(
-                self.dummy_dim,
-                self.embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
+        self.proj = RowParallelLinear(
+            self.dummy_dim,
+            self.embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+            disable_tp=use_data_parallel,
+        )
 
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
@@ -236,72 +220,6 @@ class InternParallelAttention(nn.Module):
         return out
 
 
-class InternSdpaAttention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        *,
-        num_dummy_heads: int = 0,
-    ) -> None:
-        super().__init__()
-
-        self.config = config
-        self.embed_dim = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.embed_dim // self.num_heads
-        if self.head_dim * self.num_heads != self.embed_dim:
-            raise ValueError(
-                f'embed_dim must be divisible by num_heads '
-                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
-                f' {self.num_heads}).')
-
-        # Additional dummy heads are used to enable TP for common GPU counts.
-        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
-
-        self.scale = self.head_dim**-0.5
-        self.qkv = nn.Linear(self.embed_dim,
-                             3 * self.dummy_dim,
-                             bias=config.qkv_bias)
-
-        self.qk_normalization = config.qk_normalization
-
-        if self.qk_normalization:
-            self.q_norm = RMSNorm(self.dummy_dim,
-                                  eps=config.layer_norm_eps,
-                                  var_hidden_size=self.embed_dim)
-            self.k_norm = RMSNorm(self.dummy_dim,
-                                  eps=config.layer_norm_eps,
-                                  var_hidden_size=self.embed_dim)
-
-        self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
-
-        # Use unified MultiHeadAttention with automatic backend selection
-        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
-                                       self.scale)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, N, C = x.shape
-        qkv = self.qkv(x)
-        q, k, v = qkv.chunk(3, dim=-1)
-
-        q = q.view(B, N, self.num_heads, self.head_dim)
-        k = k.view(B, N, self.num_heads, self.head_dim)
-        v = v.view(B, N, self.num_heads, self.head_dim)
-
-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
-            k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
-
-        # Use unified MultiHeadAttention with automatic backend selection
-        x = self.attn(q, k, v)
-
-        x = self.proj(x)
-        return x
-
-
 class InternMLP(nn.Module):
 
     def __init__(
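
For reference, the removed InternSdpaAttention applied QK normalization over the flattened head dimension and then reshaped back to per-head layout before calling attention. A small stand-alone sketch of that flatten, normalize, view pattern follows; toy_rms_norm is a toy stand-in, not vLLM's RMSNorm, which additionally carries a learned weight and the var_hidden_size option used above.

import torch


def toy_rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMS normalization over the last dimension only.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


B, N, H, D = 2, 5, 25, 128           # batch, tokens, heads, head_dim
q = torch.randn(B, N, H, D)
# Normalize across all heads at once (H * D features per token), then
# restore the per-head layout expected by the attention call.
q = toy_rms_norm(q.flatten(-2, -1)).view(B, N, H, D)
print(q.shape)                       # torch.Size([2, 5, 25, 128])
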
@@ -315,20 +233,18 @@ class InternMLP(nn.Module):
 
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(config.hidden_size,
-                           config.intermediate_size,
-                           bias=True,
-                           quant_config=quant_config,
-                           prefix=f"{prefix}.fc1")
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.fc2 = cls_fc2(config.intermediate_size,
-                           config.hidden_size,
-                           bias=True,
-                           quant_config=quant_config,
-                           prefix=f"{prefix}.fc2")
+        self.fc1 = ColumnParallelLinear(config.hidden_size,
+                                        config.intermediate_size,
+                                        bias=True,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1",
+                                        disable_tp=use_data_parallel)
+        self.fc2 = RowParallelLinear(config.intermediate_size,
+                                     config.hidden_size,
+                                     bias=True,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2",
+                                     disable_tp=use_data_parallel)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -385,20 +301,20 @@ class InternVisionEncoderLayer(nn.Module):
         use_data_parallel: bool = False,
     ):
         # fallback to sdpa attention if tp unavailable
-        # tp_size = get_tensor_model_parallel_world_size()
         tp_size = (1 if use_data_parallel else
                    get_tensor_model_parallel_world_size())
         num_heads = config.num_attention_heads
 
-        if (num_heads + num_dummy_heads) % tp_size == 0:
-            return InternParallelAttention(config,
-                                           quant_config=quant_config,
-                                           num_dummy_heads=num_dummy_heads,
-                                           prefix=prefix,
-                                           use_data_parallel=use_data_parallel)
-
-        return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
+        # if the number of heads is not divisible by tp_size,
+        # we also disable Attention's TP
+        use_data_parallel = (use_data_parallel
+                             or (num_heads + num_dummy_heads) % tp_size != 0)
+        return InternParallelAttention(config,
+                                       quant_config=quant_config,
+                                       num_dummy_heads=num_dummy_heads,
+                                       prefix=prefix,
+                                       use_data_parallel=use_data_parallel)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
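
This last hunk drops the SDPA fallback entirely: when the padded head count does not divide evenly across tensor-parallel ranks, the builder now simply forces the data-parallel (TP-disabled) path of InternParallelAttention. Below is a hedged restatement of that decision as a standalone helper; the function name is made up for illustration, and the expression mirrors the added lines.

def resolve_use_data_parallel(num_heads: int, num_dummy_heads: int,
                              tp_size: int, use_data_parallel: bool) -> bool:
    # Mirrors the added logic: explicit DP mode wins, and a head count that
    # is not divisible by tp_size also disables attention TP.
    return use_data_parallel or (num_heads + num_dummy_heads) % tp_size != 0


# e.g. 25 heads padded with 7 dummy heads gives 32, divisible by 8 ranks:
assert resolve_use_data_parallel(25, 7, tp_size=8, use_data_parallel=False) is False
# without padding, 25 % 8 != 0 forces the TP-disabled path:
assert resolve_use_data_parallel(25, 0, tp_size=8, use_data_parallel=False) is True
# an explicit use_data_parallel=True is always honored:
assert resolve_use_data_parallel(32, 0, tp_size=8, use_data_parallel=True) is True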