mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 03:02:15 +08:00
[Model] Add option to run Step3VisionEncoder in DP (#22697)
Signed-off-by: zzh142857 <chaorenzhaozhenghao@gmail.com>
This commit is contained in:
parent
6807af8f46
commit
d16aa3dae4
@ -21,6 +21,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
|||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
@ -33,6 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
BaseProcessingInfo, PromptReplacement,
|
BaseProcessingInfo, PromptReplacement,
|
||||||
PromptUpdate, PromptUpdateDetails)
|
PromptUpdate, PromptUpdateDetails)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
|
from vllm.multimodal.utils import run_dp_sharded_vision_model
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
|
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
@ -650,7 +652,8 @@ class Step3VisionAttention(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config,
|
config,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = ""):
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
@ -659,20 +662,42 @@ class Step3VisionAttention(nn.Module):
|
|||||||
|
|
||||||
self.scale = self.head_dim**-0.5
|
self.scale = self.head_dim**-0.5
|
||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = (1 if use_data_parallel else
|
||||||
|
get_tensor_model_parallel_world_size())
|
||||||
assert self.total_num_heads % tp_size == 0
|
assert self.total_num_heads % tp_size == 0
|
||||||
self.num_heads = self.total_num_heads // tp_size
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
self.qkv_proj = QKVParallelLinear(self.embed_dim,
|
|
||||||
self.head_dim,
|
self.q_size = self.num_heads * self.head_dim
|
||||||
self.total_num_heads,
|
|
||||||
bias=True,
|
if use_data_parallel:
|
||||||
quant_config=quant_config,
|
self.qkv_proj = ReplicatedLinear(
|
||||||
prefix=prefix)
|
self.embed_dim,
|
||||||
self.out_proj = RowParallelLinear(self.embed_dim,
|
3 * self.q_size,
|
||||||
self.embed_dim,
|
bias=True,
|
||||||
bias=True,
|
quant_config=quant_config,
|
||||||
quant_config=quant_config,
|
prefix=prefix,
|
||||||
prefix=prefix)
|
)
|
||||||
|
self.out_proj = ReplicatedLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
self.embed_dim,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
self.embed_dim,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix,
|
||||||
|
)
|
||||||
|
self.out_proj = RowParallelLinear(self.embed_dim,
|
||||||
|
self.embed_dim,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix)
|
||||||
|
|
||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
return tensor.view(bsz, seq_len, self.num_heads,
|
return tensor.view(bsz, seq_len, self.num_heads,
|
||||||
@ -712,20 +737,25 @@ class Step3VisionMLP(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config,
|
config,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = ""):
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.activation_fn = get_act_fn(config.hidden_act)
|
self.activation_fn = get_act_fn(config.hidden_act)
|
||||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
cls_fc1 = (ReplicatedLinear
|
||||||
config.intermediate_size,
|
if use_data_parallel else ColumnParallelLinear)
|
||||||
bias=True,
|
self.fc1 = cls_fc1(config.hidden_size,
|
||||||
quant_config=quant_config,
|
config.intermediate_size,
|
||||||
prefix=prefix)
|
bias=True,
|
||||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
quant_config=quant_config,
|
||||||
config.hidden_size,
|
prefix=prefix)
|
||||||
bias=True,
|
cls_fc2 = (ReplicatedLinear
|
||||||
quant_config=quant_config,
|
if use_data_parallel else RowParallelLinear)
|
||||||
prefix=prefix)
|
self.fc2 = cls_fc2(config.intermediate_size,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states, _ = self.fc1(hidden_states)
|
hidden_states, _ = self.fc1(hidden_states)
|
||||||
@ -739,15 +769,22 @@ class Step3VisionEncoderLayer(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: Step3VisionEncoderConfig,
|
config: Step3VisionEncoderConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = ""):
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.use_data_parallel = use_data_parallel
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
self.self_attn = Step3VisionAttention(config,
|
self.self_attn = Step3VisionAttention(
|
||||||
quant_config,
|
config,
|
||||||
prefix=f"{prefix}.self_attn")
|
quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
use_data_parallel=self.use_data_parallel)
|
||||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
|
self.mlp = Step3VisionMLP(config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.mlp",
|
||||||
|
use_data_parallel=self.use_data_parallel)
|
||||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
|
|
||||||
@ -767,13 +804,16 @@ class Step3VisionEncoder(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: Step3VisionEncoderConfig,
|
config: Step3VisionEncoderConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = ""):
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.use_data_parallel = use_data_parallel
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
Step3VisionEncoderLayer(config,
|
Step3VisionEncoderLayer(config,
|
||||||
quant_config,
|
quant_config,
|
||||||
prefix=f"{prefix}.layers.{i}")
|
prefix=f"{prefix}.layers.{i}",
|
||||||
|
use_data_parallel=self.use_data_parallel)
|
||||||
for i in range(config.num_hidden_layers)
|
for i in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -792,21 +832,29 @@ class Step3VisionTransformer(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: Step3VisionEncoderConfig,
|
config: Step3VisionEncoderConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = ""):
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.use_data_parallel = use_data_parallel
|
||||||
self.image_size = config.image_size
|
self.image_size = config.image_size
|
||||||
self.embeddings = Step3VisionEmbeddings(config)
|
self.embeddings = Step3VisionEmbeddings(config)
|
||||||
self.transformer = Step3VisionEncoder(config,
|
self.transformer = Step3VisionEncoder(
|
||||||
quant_config,
|
config,
|
||||||
prefix=f"{prefix}.transformer")
|
quant_config,
|
||||||
|
prefix=f"{prefix}.transformer",
|
||||||
|
use_data_parallel=self.use_data_parallel)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.Tensor,
|
pixel_values: torch.Tensor,
|
||||||
):
|
):
|
||||||
hidden_states = self.embeddings(pixel_values)
|
hidden_states = self.embeddings(pixel_values)
|
||||||
hidden_states = self.transformer(inputs_embeds=hidden_states)
|
if self.use_data_parallel:
|
||||||
|
hidden_states = run_dp_sharded_vision_model(
|
||||||
|
hidden_states, self.transformer)
|
||||||
|
else:
|
||||||
|
hidden_states = self.transformer(inputs_embeds=hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@ -836,13 +884,15 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.multimodal_config = multimodal_config
|
self.multimodal_config = multimodal_config
|
||||||
|
self.use_data_parallel = (vllm_config.parallel_config.
|
||||||
|
enable_multimodal_encoder_data_parallel)
|
||||||
|
|
||||||
if multimodal_config.get_limit_per_prompt("image"):
|
if multimodal_config.get_limit_per_prompt("image"):
|
||||||
self.vision_model = Step3VisionTransformer(config.vision_config,
|
self.vision_model = Step3VisionTransformer(
|
||||||
None,
|
config.vision_config,
|
||||||
prefix=maybe_prefix(
|
None,
|
||||||
prefix,
|
prefix=maybe_prefix(prefix, "vision_model"),
|
||||||
"vision_model"))
|
use_data_parallel=self.use_data_parallel)
|
||||||
self.vit_downsampler = nn.Conv2d(
|
self.vit_downsampler = nn.Conv2d(
|
||||||
config.vision_config.hidden_size,
|
config.vision_config.hidden_size,
|
||||||
config.vision_config.output_hidden_size,
|
config.vision_config.output_hidden_size,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user