[Model] enable data parallel for InternVL vision encoder (#23909)

Signed-off-by: Yiwen Chen <yiwen66@berkeley.edu>
Signed-off-by: YiwenC <54658925+666even666@users.noreply.github.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
YiwenC 2025-09-17 21:11:46 -07:00 committed by GitHub
parent dc2979c585
commit 52bc9d5b3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 80 additions and 33 deletions

View File

@ -175,6 +175,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u
Known supported models:
- GLM-4.5V GLM-4.1V (<gh-pr:23168>)
- InternVL (<gh-pr:23909>)
- Kimi-VL (<gh-pr:23817>)
- Llama4 (<gh-pr:18368>)
- MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)

View File

@ -25,9 +25,11 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import run_dp_sharded_vision_model
NORM2FN = {
'rms_norm': RMSNorm,
@ -137,6 +139,7 @@ class InternParallelAttention(nn.Module):
*,
num_dummy_heads: int = 0,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
@ -150,8 +153,10 @@ class InternParallelAttention(nn.Module):
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
self.tp_rank = (0 if use_data_parallel else
get_tensor_model_parallel_rank())
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
@ -159,14 +164,23 @@ class InternParallelAttention(nn.Module):
self.tp_size)
self.scale = self.head_dim**-0.5
self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
num_dummy_heads + self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
if use_data_parallel:
self.qkv = ReplicatedLinear(
self.embed_dim,
3 * self.head_dim * self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
else:
self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
num_dummy_heads + self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
self.qk_normalization = config.qk_normalization
@ -178,12 +192,20 @@ class InternParallelAttention(nn.Module):
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = RowParallelLinear(
self.dummy_dim,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
if use_data_parallel:
self.proj = ReplicatedLinear(
self.dummy_dim,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
else:
self.proj = RowParallelLinear(
self.dummy_dim,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
@ -287,21 +309,26 @@ class InternMLP(nn.Module):
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
self.fc1 = cls_fc1(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
cls_fc2 = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.fc2 = cls_fc2(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
@ -320,6 +347,7 @@ class InternVisionEncoderLayer(nn.Module):
*,
num_dummy_heads: int = 0,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
@ -330,11 +358,13 @@ class InternVisionEncoderLayer(nn.Module):
self.attn = self._init_attn(config,
quant_config,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.attn")
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel)
self.mlp = InternMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)
self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
@ -352,16 +382,20 @@ class InternVisionEncoderLayer(nn.Module):
*,
num_dummy_heads: int,
prefix: str = "",
use_data_parallel: bool = False,
):
# fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
# tp_size = get_tensor_model_parallel_world_size()
tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
num_heads = config.num_attention_heads
if (num_heads + num_dummy_heads) % tp_size == 0:
return InternParallelAttention(config,
quant_config=quant_config,
num_dummy_heads=num_dummy_heads,
prefix=prefix)
prefix=prefix,
use_data_parallel=use_data_parallel)
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
@ -388,6 +422,7 @@ class InternVisionEncoder(nn.Module):
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
@ -402,7 +437,8 @@ class InternVisionEncoder(nn.Module):
InternVisionEncoderLayer(config,
quant_config,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.layers.{layer_idx}")
prefix=f"{prefix}.layers.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(num_hidden_layers)
])
@ -429,10 +465,12 @@ class InternVisionModel(nn.Module):
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.config = config
self.use_data_parallel = use_data_parallel
self.embeddings = InternVisionEmbeddings(config)
self.encoder = InternVisionEncoder(
@ -441,6 +479,7 @@ class InternVisionModel(nn.Module):
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel,
)
def get_input_embeddings(self):
@ -464,7 +503,11 @@ class InternVisionModel(nn.Module):
raise ValueError(
f'wrong pixel_values size: {pixel_values.shape}')
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
if self.use_data_parallel:
encoder_outputs = run_dp_sharded_vision_model(
hidden_states, self.encoder)
else:
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
return encoder_outputs

View File

@ -1035,6 +1035,8 @@ class InternVLMultiModalProcessor(
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
supports_encoder_tp_data = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
@ -1053,6 +1055,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self._patch_quant_config(config, quant_config)
image_size = config.force_image_size or config.vision_config.image_size
@ -1120,7 +1123,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
use_data_parallel=self.use_data_parallel)
else:
return InternVisionPatchModel(config.vision_config)