mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-27 01:59:36 +08:00
[Model] Add option to run Step3VisionEncoder in DP (#22697)
Signed-off-by: zzh142857 <chaorenzhaozhenghao@gmail.com>
This commit is contained in:
parent
6807af8f46
commit
d16aa3dae4
@ -21,6 +21,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
@ -33,6 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.utils import run_dp_sharded_vision_model
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
@ -650,7 +652,8 @@ class Step3VisionAttention(nn.Module):
|
||||
def __init__(self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -659,20 +662,42 @@ class Step3VisionAttention(nn.Module):
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_size = (1 if use_data_parallel else
|
||||
get_tensor_model_parallel_world_size())
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.qkv_proj = QKVParallelLinear(self.embed_dim,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
self.out_proj = RowParallelLinear(self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
|
||||
if use_data_parallel:
|
||||
self.qkv_proj = ReplicatedLinear(
|
||||
self.embed_dim,
|
||||
3 * self.q_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
self.out_proj = ReplicatedLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
else:
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads,
|
||||
@ -712,20 +737,25 @@ class Step3VisionMLP(nn.Module):
|
||||
def __init__(self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
cls_fc1 = (ReplicatedLinear
|
||||
if use_data_parallel else ColumnParallelLinear)
|
||||
self.fc1 = cls_fc1(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
cls_fc2 = (ReplicatedLinear
|
||||
if use_data_parallel else RowParallelLinear)
|
||||
self.fc2 = cls_fc2(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
@ -739,15 +769,22 @@ class Step3VisionEncoderLayer(nn.Module):
|
||||
def __init__(self,
|
||||
config: Step3VisionEncoderConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = Step3VisionAttention(config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.self_attn = Step3VisionAttention(
|
||||
config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
|
||||
self.mlp = Step3VisionMLP(config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@ -767,13 +804,16 @@ class Step3VisionEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
config: Step3VisionEncoderConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.layers = nn.ModuleList([
|
||||
Step3VisionEncoderLayer(config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.layers.{i}")
|
||||
prefix=f"{prefix}.layers.{i}",
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
for i in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
@ -792,21 +832,29 @@ class Step3VisionTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
config: Step3VisionEncoderConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.image_size = config.image_size
|
||||
self.embeddings = Step3VisionEmbeddings(config)
|
||||
self.transformer = Step3VisionEncoder(config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.transformer")
|
||||
self.transformer = Step3VisionEncoder(
|
||||
config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.transformer",
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
):
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.transformer(inputs_embeds=hidden_states)
|
||||
if self.use_data_parallel:
|
||||
hidden_states = run_dp_sharded_vision_model(
|
||||
hidden_states, self.transformer)
|
||||
else:
|
||||
hidden_states = self.transformer(inputs_embeds=hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -836,13 +884,15 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = (vllm_config.parallel_config.
|
||||
enable_multimodal_encoder_data_parallel)
|
||||
|
||||
if multimodal_config.get_limit_per_prompt("image"):
|
||||
self.vision_model = Step3VisionTransformer(config.vision_config,
|
||||
None,
|
||||
prefix=maybe_prefix(
|
||||
prefix,
|
||||
"vision_model"))
|
||||
self.vision_model = Step3VisionTransformer(
|
||||
config.vision_config,
|
||||
None,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
self.vit_downsampler = nn.Conv2d(
|
||||
config.vision_config.hidden_size,
|
||||
config.vision_config.output_hidden_size,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user