Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Model] enable data parallel for Llama4 vision encoder (#18368)
Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Co-authored-by: yZhen <yZhen@fb.com>
Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
This commit is contained in:
parent 5b168b6d7a
commit ebb1ec9318
@@ -1790,6 +1790,10 @@ class ParallelConfig:
     rank: int = 0
     """Global rank in distributed setup."""
 
+    enable_multimodal_encoder_data_parallel: bool = False
+    """ Use data parallelism instead of tensor parallelism for vision encoder.
+    Only support LLama4 for now"""
+
     @property
     def world_size_across_dp(self) -> int:
         """world_size_across_dp is TPxPPxDP, it is the size of the world
@@ -423,6 +423,9 @@ class EngineArgs:
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
     pt_load_map_location: str = LoadConfig.pt_load_map_location
 
+    enable_multimodal_encoder_data_parallel: bool = \
+        ParallelConfig.enable_multimodal_encoder_data_parallel
+
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
         # without having to manually construct a
@@ -637,6 +640,9 @@ class EngineArgs:
                                     **parallel_kwargs["worker_cls"])
         parallel_group.add_argument("--worker-extension-cls",
                                     **parallel_kwargs["worker_extension_cls"])
+        parallel_group.add_argument(
+            "--enable-multimodal-encoder-data-parallel",
+            **parallel_kwargs["enable_multimodal_encoder_data_parallel"])
 
         # KV cache arguments
         cache_kwargs = get_kwargs(CacheConfig)
@@ -1078,6 +1084,8 @@ class EngineArgs:
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
+            enable_multimodal_encoder_data_parallel=self.
+            enable_multimodal_encoder_data_parallel,
         )
 
         speculative_config = self.create_speculative_config(
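For reference, a minimal sketch (not part of the commit) of switching the new flag on from Python. The keyword is the EngineArgs field added above, the model string is a placeholder, and the equivalent CLI switch added here is `--enable-multimodal-encoder-data-parallel`.

```python
# Illustrative only: construct EngineArgs with the new encoder-DP flag enabled.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="<llama4-multimodal-checkpoint>",  # placeholder, not a tested name
    tensor_parallel_size=8,
    enable_multimodal_encoder_data_parallel=True,
)
# create_engine_config() (modified in the hunk above) copies the flag into
# ParallelConfig, so model code can read
# vllm_config.parallel_config.enable_multimodal_encoder_data_parallel.
```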
@@ -34,6 +34,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import InputProcessingContext
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -49,6 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_vision_model
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -84,23 +86,29 @@ class Llama4ImagePatchInputs(TypedDict):
 
 class Llama4VisionMLP(nn.Module):
 
-    def __init__(self,
-                 input_size: int,
-                 intermediate_size: int,
-                 output_size: int,
-                 bias: bool,
-                 output_activation: bool,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        input_size: int,
+        intermediate_size: int,
+        output_size: int,
+        bias: bool,
+        output_activation: bool,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
         super().__init__()
-        self.fc1 = ColumnParallelLinear(
+        cls_fc1 = (ReplicatedLinear
+                   if use_data_parallel else ColumnParallelLinear)
+        self.fc1 = cls_fc1(
             input_size=input_size,
             output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
         )
-        self.fc2 = RowParallelLinear(
+        cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
+        self.fc2 = cls_fc2(
             input_size=intermediate_size,
             output_size=output_size,
             bias=bias,
@@ -155,10 +163,12 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
                                            int(channels / shuffle_ratio))
     reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
 
-    reshaped_tensor = reshaped_tensor.view(batch_size,
-                                           int(height * shuffle_ratio),
-                                           int(width * shuffle_ratio),
-                                           int(channels / (shuffle_ratio**2)))
+    reshaped_tensor = reshaped_tensor.view(
+        batch_size,
+        int(height * shuffle_ratio),
+        int(width * shuffle_ratio),
+        int(channels / (shuffle_ratio**2)),
+    )
     reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
 
     output_tensor = reshaped_tensor.view(batch_size, -1,
@@ -173,6 +183,7 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
         config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
@@ -186,7 +197,9 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
             bias=config.multi_modal_projector_bias,
             output_activation=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp")
+            prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
+        )
 
     def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
         encoded_patches = pixel_shuffle(encoded_patches,
@@ -201,10 +214,12 @@ class Llama4VisionAttention(nn.Module):
         config: Llama4VisionConfig,
         quant_config: Optional[QuantizationConfig],
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = (1 if use_data_parallel else
+                        get_tensor_model_parallel_world_size())
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = config.hidden_size // self.num_heads
@@ -217,22 +232,39 @@ class Llama4VisionAttention(nn.Module):
 
         self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
                                        self.scaling)
-        self.qkv_proj = QKVParallelLinear(
-            self.embed_dim,
-            self.head_dim,
-            self.num_heads,
-            bias=True,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.head_dim,
-            self.embed_dim,
-            bias=True,
-            input_is_parallel=True,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
-        )
+
+        if use_data_parallel:
+            self.qkv_proj = ReplicatedLinear(
+                self.embed_dim,
+                self.q_size + 2 * self.kv_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+            self.o_proj = ReplicatedLinear(
+                self.num_heads * self.head_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
+        else:
+            self.qkv_proj = QKVParallelLinear(
+                self.embed_dim,
+                self.head_dim,
+                self.num_heads,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+            self.o_proj = RowParallelLinear(
+                self.num_heads * self.head_dim,
+                self.embed_dim,
+                bias=True,
+                input_is_parallel=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
 
         self.rotary_emb = get_rope(
             head_size=self.head_dim,
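A quick arithmetic sanity check for the ReplicatedLinear branch above, with illustrative numbers only (the real Llama4VisionConfig values and the exact definitions of `q_size`/`kv_size` are assumptions here): with data parallelism the module pins `tp_size` to 1, so for plain multi-head attention the fused width `q_size + 2 * kv_size` equals `3 * hidden_size`, the same total width QKVParallelLinear would produce at TP=1.

```python
# Hypothetical sizes for illustration; not taken from the real config.
hidden_size = 1408
num_attention_heads = 16
head_dim = hidden_size // num_attention_heads

tp_size = 1                                   # forced when use_data_parallel=True
num_local_heads = num_attention_heads // tp_size
q_size = num_local_heads * head_dim           # assumed definition of q_size
kv_size = num_local_heads * head_dim          # vision attention treated as plain MHA

assert q_size + 2 * kv_size == 3 * hidden_size
```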
@@ -275,22 +307,29 @@ class Llama4VisionEncoderLayer(nn.Module):
         config: Llama4VisionConfig,
         quant_config: Optional[QuantizationConfig],
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.num_attention_heads = config.num_attention_heads
         self.intermediate_size = config.intermediate_size
 
-        self.self_attn = Llama4VisionAttention(config,
-                                               quant_config=quant_config,
-                                               prefix=f"{prefix}.self_attn")
-        self.mlp = Llama4VisionMLP(input_size=config.hidden_size,
-                                   intermediate_size=config.intermediate_size,
-                                   output_size=config.hidden_size,
-                                   bias=True,
-                                   output_activation=False,
-                                   quant_config=quant_config,
-                                   prefix=f"{prefix}.mlp")
+        self.self_attn = Llama4VisionAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_data_parallel=use_data_parallel,
+        )
+        self.mlp = Llama4VisionMLP(
+            input_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            output_size=config.hidden_size,
+            bias=True,
+            output_activation=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
+        )
 
         self.input_layernorm = nn.LayerNorm(config.hidden_size)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
@@ -322,6 +361,7 @@ class Llama4VisionEncoder(nn.Module):
         config: Llama4VisionConfig,
         quant_config: Optional[QuantizationConfig],
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
@@ -330,6 +370,7 @@ class Llama4VisionEncoder(nn.Module):
                 config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.layers.{layer_idx}",
+                use_data_parallel=use_data_parallel,
             ) for layer_idx in range(config.num_hidden_layers)
         ])
 
@@ -357,23 +398,33 @@ class Llama4VisionEncoder(nn.Module):
 
 class Llama4UnfoldConvolution(nn.Module):
 
-    def __init__(self,
-                 config: Llama4VisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        config: Llama4VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
         super().__init__()
         kernel_size = config.patch_size
         if isinstance(kernel_size, int):
             kernel_size = (kernel_size, kernel_size)
         self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
                                       stride=config.patch_size)
-        self.linear = ColumnParallelLinear(config.num_channels *
-                                           kernel_size[0] * kernel_size[1],
-                                           config.hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           gather_output=True,
-                                           prefix=f"{prefix}.linear")
+        params = {
+            "input_size":
+            config.num_channels * kernel_size[0] * kernel_size[1],
+            "output_size": config.hidden_size,
+            "bias": False,
+            "quant_config": quant_config,
+            "prefix": f"{prefix}.linear",
+        }
+        if use_data_parallel:
+            cls = ReplicatedLinear
+        else:
+            cls = ColumnParallelLinear
+            params["gather_output"] = True
+        self.linear = cls(**params)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
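As a shape walk-through of the unfold-based patch embedding, here is a sketch with made-up sizes (`patch_size=14`, `hidden_size=1408`, 336x336 input). A plain `torch.nn.Linear` stands in for the Column/ReplicatedLinear above so the snippet runs without initializing any distributed state.

```python
import torch

batch, channels, height, width = 2, 3, 336, 336   # illustrative image size
patch_size, hidden_size = 14, 1408                # illustrative config values

unfold = torch.nn.Unfold(kernel_size=(patch_size, patch_size), stride=patch_size)
linear = torch.nn.Linear(channels * patch_size * patch_size, hidden_size, bias=False)

pixels = torch.randn(batch, channels, height, width)
patches = unfold(pixels)            # (2, 588, 576): channels*patch*patch values per column
patches = patches.permute(0, 2, 1)  # (2, 576, 588): one row per image patch
embeddings = linear(patches)        # (2, 576, 1408): one hidden vector per patch
print(embeddings.shape)
```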
@@ -389,6 +440,7 @@ class Llama4VisionModel(nn.Module):
         config: Llama4VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
@@ -403,7 +455,9 @@ class Llama4VisionModel(nn.Module):
         self.patch_embedding = Llama4UnfoldConvolution(
             config,
             quant_config=quant_config,
-            prefix=f"{prefix}.patch_embedding")
+            prefix=f"{prefix}.patch_embedding",
+            use_data_parallel=use_data_parallel,
+        )
 
         self.class_embedding = nn.Parameter(self.scale *
                                             torch.randn(self.hidden_size))
@@ -415,11 +469,18 @@ class Llama4VisionModel(nn.Module):
         self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)
 
         # encoders
-        self.model = Llama4VisionEncoder(config,
-                                         quant_config=quant_config,
-                                         prefix=f"{prefix}.model")
+        self.model = Llama4VisionEncoder(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.model",
+            use_data_parallel=use_data_parallel,
+        )
         self.vision_adapter = Llama4VisionPixelShuffleMLP(
-            config, quant_config, prefix=f"{prefix}.vision_adapter")
+            config,
+            quant_config,
+            prefix=f"{prefix}.vision_adapter",
+            use_data_parallel=use_data_parallel,
+        )
 
     def forward(
         self,
@@ -528,8 +589,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
         vision_config = self.info.get_hf_config().vision_config
 
         if processed_outputs.get("pixel_values") is not None:
-            assert "images" in mm_data, \
-                "images expected to be in mm_data when pixel_values is present"
+            assert (
+                "images" in mm_data
+            ), "images expected to be in mm_data when pixel_values is present"
 
             images = mm_data["images"]
             parsed_images = (self._get_data_parser().parse_mm_data({
@@ -546,8 +608,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
                 get_best_fit(
                     (image.size[1], image.size[0]),
                     torch.tensor(possible_resolutions),
-                    resize_to_max_canvas=image_processor.resize_to_max_canvas)
-                for image in parsed_images
+                    resize_to_max_canvas=image_processor.resize_to_max_canvas,
+                ) for image in parsed_images
             ]
             # TODO tile height/width do not necessarily need to match
             aspect_ratios = [(image_size[0] // tile_size,
@@ -659,13 +721,17 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
+        self.use_data_parallel = (vllm_config.parallel_config.
+                                  enable_multimodal_encoder_data_parallel)
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        self.vision_model = Llama4VisionModel(config.vision_config,
-                                              None,
-                                              prefix=maybe_prefix(
-                                                  prefix, "vision_model"))
+        self.vision_model = Llama4VisionModel(
+            config.vision_config,
+            None,
+            prefix=maybe_prefix(prefix, "vision_model"),
+            use_data_parallel=self.use_data_parallel,
+        )
         self.multi_modal_projector = Llama4MultiModalProjector(
             self.config,
             None,
@@ -709,7 +775,13 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         flat_data = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"].tolist()
 
-        vision_embeddings_flat = self.vision_model(flat_data)
+        # shard image input
+        if self.use_data_parallel:
+            vision_embeddings_flat = run_dp_sharded_vision_model(
+                flat_data, self.vision_model)
+        else:
+            vision_embeddings_flat = self.vision_model(flat_data)
 
         vision_embeddings_flat = self.multi_modal_projector(
             vision_embeddings_flat)
@@ -796,6 +868,30 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return get_prefix_weights(), get_other_weights()
 
+    def _consolidate_qkv_weights(
+        self, weights: Iterable[tuple[str, torch.Tensor]]
+    ) -> Iterable[tuple[str, torch.Tensor]]:
+        qkv_idx_mappings = {
+            ".self_attn.q_proj": 0,
+            ".self_attn.k_proj": 1,
+            ".self_attn.v_proj": 2,
+        }
+        qkv_weights = {}
+        for name, loaded_weight in weights:
+            for weight_name, idx in qkv_idx_mappings.items():
+                if weight_name not in name:
+                    continue
+                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
+                if new_name not in qkv_weights:
+                    qkv_weights[new_name] = [None] * 3
+                qkv_weights[new_name][idx] = loaded_weight
+                break
+            else:
+                yield name, loaded_weight
+        for key, weight in qkv_weights.items():
+            qkv_weight = torch.cat(weight, dim=0)
+            yield key, qkv_weight
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
 
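To see what `_consolidate_qkv_weights` does to a checkpoint stream, here is a standalone rerun of the same idea on dummy tensors. The names and shapes are invented for the demo; only the fusing behaviour mirrors the method above.

```python
import torch

def consolidate_qkv(weights):
    """Fuse separate q/k/v projection weights into one qkv_proj tensor."""
    qkv_idx_mappings = {
        ".self_attn.q_proj": 0,
        ".self_attn.k_proj": 1,
        ".self_attn.v_proj": 2,
    }
    qkv_weights = {}
    for name, w in weights:
        for weight_name, idx in qkv_idx_mappings.items():
            if weight_name not in name:
                continue
            new_name = name.replace(weight_name, ".self_attn.qkv_proj")
            qkv_weights.setdefault(new_name, [None] * 3)[idx] = w
            break
        else:
            # Non-QKV weights pass through untouched.
            yield name, w
    for key, parts in qkv_weights.items():
        yield key, torch.cat(parts, dim=0)

hidden = 8
dummy = [
    ("layers.0.self_attn.q_proj.weight", torch.randn(hidden, hidden)),
    ("layers.0.self_attn.k_proj.weight", torch.randn(hidden, hidden)),
    ("layers.0.self_attn.v_proj.weight", torch.randn(hidden, hidden)),
    ("layers.0.mlp.fc1.weight", torch.randn(4 * hidden, hidden)),
]
fused = dict(consolidate_qkv(dummy))
assert fused["layers.0.self_attn.qkv_proj.weight"].shape == (3 * hidden, hidden)
assert "layers.0.mlp.fc1.weight" in fused
```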
@@ -818,9 +914,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
             assert loaded_language_model_params is not None
             updated_params.update(loaded_language_model_params)
 
+        if self.use_data_parallel:
+            other_weights = self._consolidate_qkv_weights(other_weights)
+
         for name, loaded_weight in other_weights:
             for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
+                if weight_name not in name or self.use_data_parallel:
                     continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
@@ -12,6 +12,9 @@ from PIL import Image
 
 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 
 from .audio import AudioMediaIO
 from .base import MediaIO
@@ -390,3 +393,35 @@ def group_mm_inputs_by_modality(
     return [
         list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
     ]
+
+
+def run_dp_sharded_vision_model(image_input: torch.Tensor,
+                                vision_model: torch.nn.Module) -> torch.Tensor:
+    """Run a vision model with data parallelism (DP) sharding. The function
+    will shard the input image tensor on the first dimension and run the vision
+    model
+
+    Args:
+        image_input (torch.Tensor): Image input tensor.
+        vision_model (torch.nn.Module): Vision model.
+
+    Returns:
+        torch.Tensor: Output image embeddings
+    """
+
+    num_chunks = image_input.shape[0]
+    mp_world_size = get_tensor_model_parallel_world_size()
+    num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
+    num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks
+    pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
+    image_input_padded = torch.nn.functional.pad(image_input, pad)
+    rank = get_tensor_model_parallel_rank()
+    image_input_per_rank = image_input_padded[rank *
+                                              num_chunks_per_rank:(rank + 1) *
+                                              num_chunks_per_rank, ...]
+
+    vision_embeddings = vision_model(image_input_per_rank)
+    vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
+                                                         dim=0)
+    vision_embeddings = vision_embeddings[:num_chunks, ...]
+    return vision_embeddings
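Finally, a single-process walk-through of the padding and sharding arithmetic used by run_dp_sharded_vision_model, assuming a hypothetical 4-way group and 10 image chunks. No collective calls are involved, so the numbers can be checked directly; in the real helper each rank runs the vision model on its shard, the results are all-gathered along dim 0, and the first num_chunks rows are kept.

```python
import torch

image_input = torch.randn(10, 3, 336, 336)  # 10 chunks, illustrative shape
mp_world_size = 4                           # pretend group size

num_chunks = image_input.shape[0]
num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size   # 3
num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks      # 2
pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
image_input_padded = torch.nn.functional.pad(image_input, pad)            # 12 chunks

for rank in range(mp_world_size):
    shard = image_input_padded[rank * num_chunks_per_rank:(rank + 1) *
                               num_chunks_per_rank]
    assert shard.shape[0] == num_chunks_per_rank   # every rank gets 3 chunks
```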