[Model] Add option to run Step3VisionEncoder in data-parallel (DP) mode (#22697)

Signed-off-by: zzh142857 <chaorenzhaozhenghao@gmail.com>
This commit is contained in:
zzh142857 2025-08-13 03:09:13 -04:00 committed by GitHub
parent 6807af8f46
commit d16aa3dae4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -21,6 +21,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@ -33,6 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.utils import run_dp_sharded_vision_model
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -650,7 +652,8 @@ class Step3VisionAttention(nn.Module):
def __init__(self, def __init__(self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = "",
use_data_parallel: bool = False):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
@ -659,20 +662,42 @@ class Step3VisionAttention(nn.Module):
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
tp_size = get_tensor_model_parallel_world_size() tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tp_size == 0 assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size self.num_heads = self.total_num_heads // tp_size
self.qkv_proj = QKVParallelLinear(self.embed_dim,
self.head_dim, self.q_size = self.num_heads * self.head_dim
self.total_num_heads,
bias=True, if use_data_parallel:
quant_config=quant_config, self.qkv_proj = ReplicatedLinear(
prefix=prefix) self.embed_dim,
self.out_proj = RowParallelLinear(self.embed_dim, 3 * self.q_size,
self.embed_dim, bias=True,
bias=True, quant_config=quant_config,
quant_config=quant_config, prefix=prefix,
prefix=prefix) )
self.out_proj = ReplicatedLinear(
self.total_num_heads * self.head_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=prefix,
)
else:
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=prefix,
)
self.out_proj = RowParallelLinear(self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=prefix)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, return tensor.view(bsz, seq_len, self.num_heads,
@ -712,20 +737,25 @@ class Step3VisionMLP(nn.Module):
def __init__(self, def __init__(self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = "",
use_data_parallel: bool = False):
super().__init__() super().__init__()
self.config = config self.config = config
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size, cls_fc1 = (ReplicatedLinear
config.intermediate_size, if use_data_parallel else ColumnParallelLinear)
bias=True, self.fc1 = cls_fc1(config.hidden_size,
quant_config=quant_config, config.intermediate_size,
prefix=prefix) bias=True,
self.fc2 = RowParallelLinear(config.intermediate_size, quant_config=quant_config,
config.hidden_size, prefix=prefix)
bias=True, cls_fc2 = (ReplicatedLinear
quant_config=quant_config, if use_data_parallel else RowParallelLinear)
prefix=prefix) self.fc2 = cls_fc2(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=prefix)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states) hidden_states, _ = self.fc1(hidden_states)
@ -739,15 +769,22 @@ class Step3VisionEncoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: Step3VisionEncoderConfig, config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = "",
use_data_parallel: bool = False):
super().__init__() super().__init__()
self.use_data_parallel = use_data_parallel
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.self_attn = Step3VisionAttention(config, self.self_attn = Step3VisionAttention(
quant_config, config,
prefix=f"{prefix}.self_attn") quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=self.use_data_parallel)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp") self.mlp = Step3VisionMLP(config,
quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=self.use_data_parallel)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
@ -767,13 +804,16 @@ class Step3VisionEncoder(nn.Module):
def __init__(self, def __init__(self,
config: Step3VisionEncoderConfig, config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = "",
use_data_parallel: bool = False):
super().__init__() super().__init__()
self.config = config self.config = config
self.use_data_parallel = use_data_parallel
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Step3VisionEncoderLayer(config, Step3VisionEncoderLayer(config,
quant_config, quant_config,
prefix=f"{prefix}.layers.{i}") prefix=f"{prefix}.layers.{i}",
use_data_parallel=self.use_data_parallel)
for i in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
]) ])
@ -792,21 +832,29 @@ class Step3VisionTransformer(nn.Module):
def __init__(self, def __init__(self,
config: Step3VisionEncoderConfig, config: Step3VisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = "",
use_data_parallel: bool = False):
super().__init__() super().__init__()
self.config = config self.config = config
self.use_data_parallel = use_data_parallel
self.image_size = config.image_size self.image_size = config.image_size
self.embeddings = Step3VisionEmbeddings(config) self.embeddings = Step3VisionEmbeddings(config)
self.transformer = Step3VisionEncoder(config, self.transformer = Step3VisionEncoder(
quant_config, config,
prefix=f"{prefix}.transformer") quant_config,
prefix=f"{prefix}.transformer",
use_data_parallel=self.use_data_parallel)
def forward( def forward(
self, self,
pixel_values: torch.Tensor, pixel_values: torch.Tensor,
): ):
hidden_states = self.embeddings(pixel_values) hidden_states = self.embeddings(pixel_values)
hidden_states = self.transformer(inputs_embeds=hidden_states) if self.use_data_parallel:
hidden_states = run_dp_sharded_vision_model(
hidden_states, self.transformer)
else:
hidden_states = self.transformer(inputs_embeds=hidden_states)
return hidden_states return hidden_states
@ -836,13 +884,15 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = (vllm_config.parallel_config.
enable_multimodal_encoder_data_parallel)
if multimodal_config.get_limit_per_prompt("image"): if multimodal_config.get_limit_per_prompt("image"):
self.vision_model = Step3VisionTransformer(config.vision_config, self.vision_model = Step3VisionTransformer(
None, config.vision_config,
prefix=maybe_prefix( None,
prefix, prefix=maybe_prefix(prefix, "vision_model"),
"vision_model")) use_data_parallel=self.use_data_parallel)
self.vit_downsampler = nn.Conv2d( self.vit_downsampler = nn.Conv2d(
config.vision_config.hidden_size, config.vision_config.hidden_size,
config.vision_config.output_hidden_size, config.vision_config.output_hidden_size,