[Model] enable data parallel for Llama4 vision encoder (#18368)

Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Co-authored-by: yZhen <yZhen@fb.com>
Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
jennyyyyzhen 2025-06-02 04:22:54 -07:00 committed by GitHub
parent 5b168b6d7a
commit ebb1ec9318
4 changed files with 214 additions and 68 deletions


@@ -1790,6 +1790,10 @@ class ParallelConfig:
rank: int = 0
"""Global rank in distributed setup."""
enable_multimodal_encoder_data_parallel: bool = False
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only Llama4 is supported for now."""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world


@@ -423,6 +423,9 @@ class EngineArgs:
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
pt_load_map_location: str = LoadConfig.pt_load_map_location
enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
@@ -637,6 +640,9 @@ class EngineArgs:
**parallel_kwargs["worker_cls"])
parallel_group.add_argument("--worker-extension-cls",
**parallel_kwargs["worker_extension_cls"])
parallel_group.add_argument(
"--enable-multimodal-encoder-data-parallel",
**parallel_kwargs["enable_multimodal_encoder_data_parallel"])
# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
@@ -1078,6 +1084,8 @@ class EngineArgs:
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
enable_multimodal_encoder_data_parallel=self.
enable_multimodal_encoder_data_parallel,
)
speculative_config = self.create_speculative_config(

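For reference, a minimal usage sketch (not part of this commit): with the EngineArgs field and the --enable-multimodal-encoder-data-parallel flag added above, the vision-encoder DP path should be reachable from the offline API as well, assuming the usual behaviour that LLM() forwards extra keyword arguments to EngineArgs. The model id and tensor-parallel size below are illustrative.

# Sketch only, not from this commit: assumes LLM() forwards extra keyword
# arguments to EngineArgs, which then populate ParallelConfig as shown above.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # illustrative model id
    tensor_parallel_size=8,                             # illustrative TP size
    # Run the vision encoder data-parallel instead of tensor-parallel.
    enable_multimodal_encoder_data_parallel=True,
)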

@@ -34,6 +34,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -49,6 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.utils import run_dp_sharded_vision_model
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -84,23 +86,29 @@ class Llama4ImagePatchInputs(TypedDict):
class Llama4VisionMLP(nn.Module):
def __init__(self,
input_size: int,
intermediate_size: int,
output_size: int,
bias: bool,
output_activation: bool,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
input_size: int,
intermediate_size: int,
output_size: int,
bias: bool,
output_activation: bool,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.fc1 = ColumnParallelLinear(
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
self.fc1 = cls_fc1(
input_size=input_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
self.fc2 = cls_fc2(
input_size=intermediate_size,
output_size=output_size,
bias=bias,
@@ -155,10 +163,12 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
int(channels / shuffle_ratio))
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
reshaped_tensor = reshaped_tensor.view(batch_size,
int(height * shuffle_ratio),
int(width * shuffle_ratio),
int(channels / (shuffle_ratio**2)))
reshaped_tensor = reshaped_tensor.view(
batch_size,
int(height * shuffle_ratio),
int(width * shuffle_ratio),
int(channels / (shuffle_ratio**2)),
)
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
output_tensor = reshaped_tensor.view(batch_size, -1,
@@ -173,6 +183,7 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
@@ -186,7 +197,9 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
bias=config.multi_modal_projector_bias,
output_activation=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
encoded_patches = pixel_shuffle(encoded_patches,
@@ -201,10 +214,12 @@ class Llama4VisionAttention(nn.Module):
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // self.num_heads
@@ -217,22 +232,39 @@ class Llama4VisionAttention(nn.Module):
self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
self.scaling)
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=True,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
if use_data_parallel:
self.qkv_proj = ReplicatedLinear(
self.embed_dim,
self.q_size + 2 * self.kv_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = ReplicatedLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
else:
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=True,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
head_size=self.head_dim,
@@ -275,22 +307,29 @@ class Llama4VisionEncoderLayer(nn.Module):
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.intermediate_size = config.intermediate_size
self.self_attn = Llama4VisionAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Llama4VisionMLP(input_size=config.hidden_size,
intermediate_size=config.intermediate_size,
output_size=config.hidden_size,
bias=True,
output_activation=False,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.self_attn = Llama4VisionAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
)
self.mlp = Llama4VisionMLP(
input_size=config.hidden_size,
intermediate_size=config.intermediate_size,
output_size=config.hidden_size,
bias=True,
output_activation=False,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
self.input_layernorm = nn.LayerNorm(config.hidden_size)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
@@ -322,6 +361,7 @@ class Llama4VisionEncoder(nn.Module):
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
@@ -330,6 +370,7 @@ class Llama4VisionEncoder(nn.Module):
config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
use_data_parallel=use_data_parallel,
) for layer_idx in range(config.num_hidden_layers)
])
@@ -357,23 +398,33 @@ class Llama4VisionEncoder(nn.Module):
class Llama4UnfoldConvolution(nn.Module):
def __init__(self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
kernel_size = config.patch_size
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
stride=config.patch_size)
self.linear = ColumnParallelLinear(config.num_channels *
kernel_size[0] * kernel_size[1],
config.hidden_size,
bias=False,
quant_config=quant_config,
gather_output=True,
prefix=f"{prefix}.linear")
params = {
"input_size":
config.num_channels * kernel_size[0] * kernel_size[1],
"output_size": config.hidden_size,
"bias": False,
"quant_config": quant_config,
"prefix": f"{prefix}.linear",
}
if use_data_parallel:
cls = ReplicatedLinear
else:
cls = ColumnParallelLinear
params["gather_output"] = True
self.linear = cls(**params)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.unfold(hidden_states)
@@ -389,6 +440,7 @@ class Llama4VisionModel(nn.Module):
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
@@ -403,7 +455,9 @@ class Llama4VisionModel(nn.Module):
self.patch_embedding = Llama4UnfoldConvolution(
config,
quant_config=quant_config,
prefix=f"{prefix}.patch_embedding")
prefix=f"{prefix}.patch_embedding",
use_data_parallel=use_data_parallel,
)
self.class_embedding = nn.Parameter(self.scale *
torch.randn(self.hidden_size))
@@ -415,11 +469,18 @@ class Llama4VisionModel(nn.Module):
self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)
# encoders
self.model = Llama4VisionEncoder(config,
quant_config=quant_config,
prefix=f"{prefix}.model")
self.model = Llama4VisionEncoder(
config,
quant_config=quant_config,
prefix=f"{prefix}.model",
use_data_parallel=use_data_parallel,
)
self.vision_adapter = Llama4VisionPixelShuffleMLP(
config, quant_config, prefix=f"{prefix}.vision_adapter")
config,
quant_config,
prefix=f"{prefix}.vision_adapter",
use_data_parallel=use_data_parallel,
)
def forward(
self,
@@ -528,8 +589,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
vision_config = self.info.get_hf_config().vision_config
if processed_outputs.get("pixel_values") is not None:
assert "images" in mm_data, \
"images expected to be in mm_data when pixel_values is present"
assert (
"images" in mm_data
), "images expected to be in mm_data when pixel_values is present"
images = mm_data["images"]
parsed_images = (self._get_data_parser().parse_mm_data({
@@ -546,8 +608,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
get_best_fit(
(image.size[1], image.size[0]),
torch.tensor(possible_resolutions),
resize_to_max_canvas=image_processor.resize_to_max_canvas)
for image in parsed_images
resize_to_max_canvas=image_processor.resize_to_max_canvas,
) for image in parsed_images
]
# TODO tile height/width do not necessarily need to match
aspect_ratios = [(image_size[0] // tile_size,
@@ -659,13 +721,17 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.use_data_parallel = (vllm_config.parallel_config.
enable_multimodal_encoder_data_parallel)
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.vision_model = Llama4VisionModel(config.vision_config,
None,
prefix=maybe_prefix(
prefix, "vision_model"))
self.vision_model = Llama4VisionModel(
config.vision_config,
None,
prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
)
self.multi_modal_projector = Llama4MultiModalProjector(
self.config,
None,
@@ -709,7 +775,13 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
flat_data = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"].tolist()
vision_embeddings_flat = self.vision_model(flat_data)
# shard image input
if self.use_data_parallel:
vision_embeddings_flat = run_dp_sharded_vision_model(
flat_data, self.vision_model)
else:
vision_embeddings_flat = self.vision_model(flat_data)
vision_embeddings_flat = self.multi_modal_projector(
vision_embeddings_flat)
@@ -796,6 +868,30 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
return get_prefix_weights(), get_other_weights()
def _consolidate_qkv_weights(
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
qkv_idx_mappings = {
".self_attn.q_proj": 0,
".self_attn.k_proj": 1,
".self_attn.v_proj": 2,
}
qkv_weights = {}
for name, loaded_weight in weights:
for weight_name, idx in qkv_idx_mappings.items():
if weight_name not in name:
continue
new_name = name.replace(weight_name, ".self_attn.qkv_proj")
if new_name not in qkv_weights:
qkv_weights[new_name] = [None] * 3
qkv_weights[new_name][idx] = loaded_weight
break
else:
yield name, loaded_weight
for key, weight in qkv_weights.items():
qkv_weight = torch.cat(weight, dim=0)
yield key, qkv_weight
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
@@ -818,9 +914,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
assert loaded_language_model_params is not None
updated_params.update(loaded_language_model_params)
if self.use_data_parallel:
other_weights = self._consolidate_qkv_weights(other_weights)
for name, loaded_weight in other_weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
if weight_name not in name or self.use_data_parallel:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]

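A note on the weight-loading change above: in DP mode the vision attention uses a single ReplicatedLinear qkv_proj, so load_weights skips the stacked_params_mapping sharding path and _consolidate_qkv_weights fuses the checkpoint's separate q/k/v tensors instead. A toy sketch of that fusion (shapes are illustrative, not the real Llama4 vision dimensions):

# Illustrative sketch of the q/k/v -> qkv fusion performed by
# _consolidate_qkv_weights for the replicated projection.
import torch

hidden = 16
q_w = torch.randn(hidden, hidden)  # ...self_attn.q_proj.weight
k_w = torch.randn(hidden, hidden)  # ...self_attn.k_proj.weight
v_w = torch.randn(hidden, hidden)  # ...self_attn.v_proj.weight

# The helper concatenates along dim 0 in q, k, v order and re-keys the
# weight name to ...self_attn.qkv_proj.
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)
assert qkv_w.shape == (3 * hidden, hidden)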

@@ -12,6 +12,9 @@ from PIL import Image
import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from .audio import AudioMediaIO
from .base import MediaIO
@@ -390,3 +393,35 @@ def group_mm_inputs_by_modality(
return [
list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
]
def run_dp_sharded_vision_model(image_input: torch.Tensor,
vision_model: torch.nn.Module) -> torch.Tensor:
"""Run a vision model with data parallelism (DP) sharding. The function
shards the input image tensor on the first dimension, runs the vision
model on each shard, and all-gathers the resulting embeddings.
Args:
image_input (torch.Tensor): Image input tensor.
vision_model (torch.nn.Module): Vision model.
Returns:
torch.Tensor: Output image embeddings.
"""
num_chunks = image_input.shape[0]
mp_world_size = get_tensor_model_parallel_world_size()
num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks
pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
image_input_padded = torch.nn.functional.pad(image_input, pad)
rank = get_tensor_model_parallel_rank()
image_input_per_rank = image_input_padded[rank *
num_chunks_per_rank:(rank + 1) *
num_chunks_per_rank, ...]
vision_embeddings = vision_model(image_input_per_rank)
vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
dim=0)
vision_embeddings = vision_embeddings[:num_chunks, ...]
return vision_embeddings
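run_dp_sharded_vision_model reuses the tensor-parallel group for the data-parallel split: each rank keeps a full replica of the vision encoder, runs it on a contiguous slice of the (padded) chunk batch, and the per-rank embeddings are all-gathered and trimmed back to the original length. A single-process sketch of the chunk arithmetic (no distributed setup; world_size, rank, and the tensor shape are stand-ins):

# Single-process sketch of the padding/sharding math used above.
import torch

image_input = torch.randn(5, 3, 336, 336)  # 5 image chunks, illustrative shape
world_size, rank = 4, 1                    # stand-ins for the TP group values

num_chunks = image_input.shape[0]
num_chunks_per_rank = (num_chunks + world_size - 1) // world_size  # ceil(5/4) = 2
num_padded_chunks = num_chunks_per_rank * world_size - num_chunks  # 8 - 5 = 3

# F.pad pads trailing dimensions first, so this tuple pads only dim 0,
# appending num_padded_chunks zero chunks at the end.
pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
padded = torch.nn.functional.pad(image_input, pad)  # now 8 chunks

shard = padded[rank * num_chunks_per_rank:(rank + 1) * num_chunks_per_rank]
assert shard.shape == (num_chunks_per_rank, 3, 336, 336)
# In the real function each rank runs vision_model(shard), the outputs are
# all-gathered along dim 0, and only the first num_chunks rows are returned.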