[Model] Add option to run Step3VisionEncoder in DP (#22697)
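
With parallel_config.enable_multimodal_encoder_data_parallel set, the Step3 vision encoder is built from ReplicatedLinear layers instead of tensor-parallel-sharded ones, and the incoming image batch is split across ranks with run_dp_sharded_vision_model rather than splitting the weights. The sketch below only illustrates the general batch-shard-then-all-gather pattern; it is not vLLM's run_dp_sharded_vision_model, and the padding and process-group details are assumptions.

```python
# Illustrative sketch of DP-sharded vision-encoder dispatch (not vLLM's
# run_dp_sharded_vision_model): every rank holds a full encoder replica,
# the batch is split across ranks, and the outputs are all-gathered.
# Assumes torch.distributed is already initialised.
import torch
import torch.distributed as dist


def dp_sharded_forward(batch: torch.Tensor,
                       encoder: torch.nn.Module) -> torch.Tensor:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    num_items = batch.shape[0]

    # Pad so the batch splits evenly across ranks.
    pad = (-num_items) % world_size
    if pad:
        batch = torch.cat([batch, batch.new_zeros(pad, *batch.shape[1:])])

    # Each rank encodes only its own shard with the replicated weights.
    local_out = encoder(batch.chunk(world_size, dim=0)[rank])

    # Re-assemble the full batch on every rank and drop the padding.
    gathered = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(gathered, local_out)
    return torch.cat(gathered, dim=0)[:num_items]
```

The trade-off is higher memory per rank (full encoder weights everywhere) in exchange for avoiding the per-layer communication of tensor parallelism during the vision forward pass.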

Signed-off-by: zzh142857 <chaorenzhaozhenghao@gmail.com>
Authored by zzh142857 on 2025-08-13 03:09:13 -04:00, committed by GitHub
parent 6807af8f46
commit d16aa3dae4
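
The new behaviour is gated on vllm_config.parallel_config.enable_multimodal_encoder_data_parallel (see the Step3VLForConditionalGeneration hunk below). Assuming the usual EngineArgs passthrough for ParallelConfig fields, turning it on from Python would look roughly like this; the keyword argument name is an assumption, only the ParallelConfig field itself appears in this diff.

```python
# Hypothetical usage sketch: the exact engine-arg plumbing is assumed; only
# ParallelConfig.enable_multimodal_encoder_data_parallel appears in this diff.
from vllm import LLM

llm = LLM(
    model="stepfun-ai/step3",  # illustrative model id
    tensor_parallel_size=4,
    # Run the vision encoder data-parallel across the ranks (assumed kwarg).
    enable_multimodal_encoder_data_parallel=True,
)
```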


@@ -21,6 +21,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
+ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -33,6 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_vision_model
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -650,7 +652,8 @@ class Step3VisionAttention(nn.Module):
def __init__(self,
config,
quant_config: Optional[QuantizationConfig] = None,
-prefix: str = ""):
+prefix: str = "",
+use_data_parallel: bool = False):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -659,20 +662,42 @@
self.scale = self.head_dim**-0.5
-tp_size = get_tensor_model_parallel_world_size()
+tp_size = (1 if use_data_parallel else
+get_tensor_model_parallel_world_size())
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
-self.qkv_proj = QKVParallelLinear(self.embed_dim,
-self.head_dim,
-self.total_num_heads,
-bias=True,
-quant_config=quant_config,
-prefix=prefix)
-self.out_proj = RowParallelLinear(self.embed_dim,
-self.embed_dim,
-bias=True,
-quant_config=quant_config,
-prefix=prefix)
self.q_size = self.num_heads * self.head_dim
+if use_data_parallel:
+self.qkv_proj = ReplicatedLinear(
+self.embed_dim,
+3 * self.q_size,
+bias=True,
+quant_config=quant_config,
+prefix=prefix,
+)
+self.out_proj = ReplicatedLinear(
+self.total_num_heads * self.head_dim,
+self.embed_dim,
+bias=True,
+quant_config=quant_config,
+prefix=prefix,
+)
+else:
+self.qkv_proj = QKVParallelLinear(
+self.embed_dim,
+self.head_dim,
+self.total_num_heads,
+bias=True,
+quant_config=quant_config,
+prefix=prefix,
+)
+self.out_proj = RowParallelLinear(self.embed_dim,
+self.embed_dim,
+bias=True,
+quant_config=quant_config,
+prefix=prefix)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
@@ -712,20 +737,25 @@ class Step3VisionMLP(nn.Module):
def __init__(self,
config,
quant_config: Optional[QuantizationConfig] = None,
-prefix: str = ""):
+prefix: str = "",
+use_data_parallel: bool = False):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
-self.fc1 = ColumnParallelLinear(config.hidden_size,
-config.intermediate_size,
-bias=True,
-quant_config=quant_config,
-prefix=prefix)
-self.fc2 = RowParallelLinear(config.intermediate_size,
-config.hidden_size,
-bias=True,
-quant_config=quant_config,
-prefix=prefix)
+cls_fc1 = (ReplicatedLinear
+if use_data_parallel else ColumnParallelLinear)
+self.fc1 = cls_fc1(config.hidden_size,
+config.intermediate_size,
+bias=True,
+quant_config=quant_config,
+prefix=prefix)
+cls_fc2 = (ReplicatedLinear
+if use_data_parallel else RowParallelLinear)
+self.fc2 = cls_fc2(config.intermediate_size,
+config.hidden_size,
+bias=True,
+quant_config=quant_config,
+prefix=prefix)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
@@ -739,15 +769,22 @@ class Step3VisionEncoderLayer(nn.Module):
def __init__(self,
config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
-prefix: str = ""):
+prefix: str = "",
+use_data_parallel: bool = False):
super().__init__()
+self.use_data_parallel = use_data_parallel
self.embed_dim = config.hidden_size
-self.self_attn = Step3VisionAttention(config,
-quant_config,
-prefix=f"{prefix}.self_attn")
+self.self_attn = Step3VisionAttention(
+config,
+quant_config,
+prefix=f"{prefix}.self_attn",
+use_data_parallel=self.use_data_parallel)
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
-self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
+self.mlp = Step3VisionMLP(config,
+quant_config,
+prefix=f"{prefix}.mlp",
+use_data_parallel=self.use_data_parallel)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
@@ -767,13 +804,16 @@ class Step3VisionEncoder(nn.Module):
def __init__(self,
config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
-prefix: str = ""):
+prefix: str = "",
+use_data_parallel: bool = False):
super().__init__()
self.config = config
+self.use_data_parallel = use_data_parallel
self.layers = nn.ModuleList([
Step3VisionEncoderLayer(config,
quant_config,
-prefix=f"{prefix}.layers.{i}")
+prefix=f"{prefix}.layers.{i}",
+use_data_parallel=self.use_data_parallel)
for i in range(config.num_hidden_layers)
])
@@ -792,21 +832,29 @@ class Step3VisionTransformer(nn.Module):
def __init__(self,
config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
-prefix: str = ""):
+prefix: str = "",
+use_data_parallel: bool = False):
super().__init__()
self.config = config
+self.use_data_parallel = use_data_parallel
self.image_size = config.image_size
self.embeddings = Step3VisionEmbeddings(config)
-self.transformer = Step3VisionEncoder(config,
-quant_config,
-prefix=f"{prefix}.transformer")
+self.transformer = Step3VisionEncoder(
+config,
+quant_config,
+prefix=f"{prefix}.transformer",
+use_data_parallel=self.use_data_parallel)
def forward(
self,
pixel_values: torch.Tensor,
):
hidden_states = self.embeddings(pixel_values)
-hidden_states = self.transformer(inputs_embeds=hidden_states)
+if self.use_data_parallel:
+hidden_states = run_dp_sharded_vision_model(
+hidden_states, self.transformer)
+else:
+hidden_states = self.transformer(inputs_embeds=hidden_states)
return hidden_states
@@ -836,13 +884,15 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
+self.use_data_parallel = (vllm_config.parallel_config.
+enable_multimodal_encoder_data_parallel)
if multimodal_config.get_limit_per_prompt("image"):
-self.vision_model = Step3VisionTransformer(config.vision_config,
-None,
-prefix=maybe_prefix(
-prefix,
-"vision_model"))
+self.vision_model = Step3VisionTransformer(
+config.vision_config,
+None,
+prefix=maybe_prefix(prefix, "vision_model"),
+use_data_parallel=self.use_data_parallel)
self.vit_downsampler = nn.Conv2d(
config.vision_config.hidden_size,
config.vision_config.output_hidden_size,