diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 118cce810a1f..892188c04722 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -25,7 +25,6 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -164,23 +163,15 @@ class InternParallelAttention(nn.Module): self.tp_size) self.scale = self.head_dim**-0.5 - if use_data_parallel: - self.qkv = ReplicatedLinear( - self.embed_dim, - 3 * self.head_dim * self.num_heads, - bias=config.qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv", - ) - else: - self.qkv = QKVParallelLinear( - self.embed_dim, - self.head_dim, - num_dummy_heads + self.num_heads, - bias=config.qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv", - ) + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + num_dummy_heads + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel, + ) self.qk_normalization = config.qk_normalization @@ -192,20 +183,13 @@ class InternParallelAttention(nn.Module): eps=config.layer_norm_eps, var_hidden_size=self.embed_dim) - if use_data_parallel: - self.proj = ReplicatedLinear( - self.dummy_dim, - self.embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - ) - else: - self.proj = RowParallelLinear( - self.dummy_dim, - self.embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - ) + self.proj = RowParallelLinear( + self.dummy_dim, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) self.attn = MultiHeadAttention(self.num_heads_per_partition, self.head_dim, self.scale) @@ -236,72 +220,6 @@ class InternParallelAttention(nn.Module): return out -class InternSdpaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config: PretrainedConfig, - *, - num_dummy_heads: int = 0, - ) -> None: - super().__init__() - - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f'embed_dim must be divisible by num_heads ' - f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' - f' {self.num_heads}).') - - # Additional dummy heads are used to enable TP for common GPU counts. - self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim - - self.scale = self.head_dim**-0.5 - self.qkv = nn.Linear(self.embed_dim, - 3 * self.dummy_dim, - bias=config.qkv_bias) - - self.qk_normalization = config.qk_normalization - - if self.qk_normalization: - self.q_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) - self.k_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) - - self.proj = nn.Linear(self.dummy_dim, self.embed_dim) - - # Use unified MultiHeadAttention with automatic backend selection - self.attn = MultiHeadAttention(self.num_heads, self.head_dim, - self.scale) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x) - q, k, v = qkv.chunk(3, dim=-1) - - q = q.view(B, N, self.num_heads, self.head_dim) - k = k.view(B, N, self.num_heads, self.head_dim) - v = v.view(B, N, self.num_heads, self.head_dim) - - if self.qk_normalization: - B_, N_, H_, D_ = q.shape - q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_) - k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_) - - # Use unified MultiHeadAttention with automatic backend selection - x = self.attn(q, k, v) - - x = self.proj(x) - return x - - class InternMLP(nn.Module): def __init__( @@ -315,20 +233,18 @@ class InternMLP(nn.Module): self.config = config self.activation_fn = get_act_fn(config.hidden_act) - cls_fc1 = (ReplicatedLinear - if use_data_parallel else ColumnParallelLinear) - self.fc1 = cls_fc1(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1") - cls_fc2 = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.fc2 = cls_fc2(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel) + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -385,19 +301,19 @@ class InternVisionEncoderLayer(nn.Module): use_data_parallel: bool = False, ): # fallback to sdpa attention if tp unavailable - # tp_size = get_tensor_model_parallel_world_size() tp_size = (1 if use_data_parallel else get_tensor_model_parallel_world_size()) num_heads = config.num_attention_heads - if (num_heads + num_dummy_heads) % tp_size == 0: - return InternParallelAttention(config, - quant_config=quant_config, - num_dummy_heads=num_dummy_heads, - prefix=prefix, - use_data_parallel=use_data_parallel) - - return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads) + # if the number of heads is not divisible by tp_size, + # we also disable Attention's TP + use_data_parallel = (use_data_parallel + or (num_heads + num_dummy_heads) % tp_size != 0) + return InternParallelAttention(config, + quant_config=quant_config, + num_dummy_heads=num_dummy_heads, + prefix=prefix, + use_data_parallel=use_data_parallel) def forward( self,