diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 2d8cdcc11fa99..b0ea9621d545a 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u Known supported models: +- Kimi-VL () - Llama4 () - MiniCPM-V-2.5 or above (, ) - Qwen2.5-VL () diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index a028c668c8ab7..05e68a961a548 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -636,8 +636,10 @@ def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, # Run the model through the sharded function with torch.inference_mode(): - sharded_output = run_dp_sharded_mrope_vision_model( - vision_model, pixel_values, grid_thw_list) + sharded_output = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") sharded_output = torch.cat(sharded_output, dim=0) # Check that the world size is setup correctly @@ -691,8 +693,10 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker( # Should handle empty input gracefully with torch.inference_mode(): - output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values, - grid_thw_list) + output = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") assert len(output) == 0 @@ -745,8 +749,10 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker( # Should handle uneven distribution without errors with torch.inference_mode(): - output_tuple = run_dp_sharded_mrope_vision_model( - vision_model, pixel_values, grid_thw_list) + output_tuple = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") # Verify output shape is reasonable merge_factor = vision_model.spatial_merge_size**2 diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index a08a9a62a57c5..4f76d4afdb20e 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -56,6 +56,7 @@ from transformers.activations import GELUActivation from vllm.config import VllmConfig from vllm.distributed import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) @@ -76,6 +77,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config @@ -93,8 +95,10 @@ class MaxImageTokenMeta: class KimiVLMultiModalProjector(nn.Module): - def __init__(self, config: KimiVLConfig): + def __init__(self, config: KimiVLConfig, \ + use_data_parallel: bool = False, prefix: str = ""): super().__init__() + self.use_data_parallel = use_data_parallel self.hidden_size = (config.vision_config.hidden_size * config.vision_config.merge_kernel_size[0] * @@ -102,20 +106,24 @@ class KimiVLMultiModalProjector(nn.Module): self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-5) - self.linear_1 = nn.Linear(self.hidden_size, - self.hidden_size, - bias=True) + self.linear_1 = ReplicatedLinear(self.hidden_size, + self.hidden_size, + bias=True, + prefix=maybe_prefix( + prefix, "linear_1")) + self.linear_2 = ReplicatedLinear(self.hidden_size, + config.text_config.hidden_size, + bias=True, + prefix=maybe_prefix( + prefix, "linear_2")) self.act = GELUActivation() - self.linear_2 = nn.Linear(self.hidden_size, - config.text_config.hidden_size, - bias=True) def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.pre_norm(image_features).view( -1, self.hidden_size) - hidden_states = self.linear_1(hidden_states) + hidden_states, _ = self.linear_1(hidden_states) hidden_states = self.act(hidden_states) - hidden_states = self.linear_2(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) return hidden_states @@ -273,6 +281,8 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -292,10 +302,17 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, quant_config = vllm_config.quant_config assert isinstance(config.vision_config, MoonViTConfig) + self.use_data_parallel = model_config.multimodal_config.mm_encoder_tp_mode == "data" + self.hidden_size = config.text_config.hidden_size + self.vision_tower = MoonVitPretrainedModel(config.vision_config, + self.use_data_parallel, + prefix=maybe_prefix( + prefix, "vision_tower")) - self.vision_tower = MoonVitPretrainedModel(config.vision_config) - - self.multi_modal_projector = KimiVLMultiModalProjector(config=config) + self.multi_modal_projector = KimiVLMultiModalProjector( + config=config, + use_data_parallel=self.use_data_parallel, + prefix=maybe_prefix(prefix, "multi_modal_projector")) self.quant_config = quant_config sub_vllm_config = copy.deepcopy(vllm_config) @@ -376,13 +393,19 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, pixel_values = inputs["pixel_values"] image_grid_hws = inputs["image_grid_hws"] - return self.vision_tower(pixel_values, image_grid_hws) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model(self.vision_tower, + pixel_values, + image_grid_hws.tolist(), + rope_type="rope_2d") + else: + return self.vision_tower(pixel_values, image_grid_hws) def _process_image_input(self, image_input: KimiVLImageInputs) -> torch.Tensor: assert image_input["type"] == "pixel_values" image_features = self._process_image_pixels(image_input) - assert isinstance(image_features, list) + assert isinstance(image_features, (list, tuple)) lengths = [x.shape[0] for x in image_features] return self.multi_modal_projector( torch.cat(image_features)).split(lengths) @@ -496,6 +519,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, expert_params_mapping = [] params_dict = dict(self.named_parameters()) + for args in weights: name, loaded_weight = args[:2] kwargs = args[2] if len(args) > 2 else {} diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py index d0fdab13ef0c9..41a2c836b09f3 100644 --- a/vllm/model_executor/models/moonvit.py +++ b/vllm/model_executor/models/moonvit.py @@ -42,7 +42,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import math from collections.abc import Sequence from copy import deepcopy from functools import cached_property @@ -55,6 +54,8 @@ from transformers.activations import ACT2FN, PytorchGELUTanh from transformers.modeling_utils import PreTrainedModel from transformers.utils import is_flash_attn_2_available +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.models.utils import maybe_prefix from vllm.transformers_utils.configs.moonvit import MoonViTConfig if is_flash_attn_2_available(): @@ -383,21 +384,30 @@ class MLP2(nn.Module): bias: whether to use bias in linear layer. """ - def __init__(self, dims: list[int], activation, bias=True): + def __init__(self, + dims: list[int], + activation, + bias=True, + prefix: str = "", + use_data_parallel: bool = False): super().__init__() assert len(dims) == 3 - self.fc0 = nn.Linear(dims[0], dims[1], bias=bias) - self.fc1 = nn.Linear(dims[1], dims[2], bias=bias) + self.use_data_parallel = use_data_parallel + self.fc0 = ReplicatedLinear(dims[0], + dims[1], + bias=bias, + prefix=maybe_prefix(prefix, "fc0")) + self.fc1 = ReplicatedLinear(dims[1], + dims[2], + bias=bias, + prefix=maybe_prefix(prefix, "fc1")) self.activation = activation - for m in [self.fc0, self.fc1]: - nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features)) - if m.bias is not None: - nn.init.zeros_(m.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.fc0(x) + x, _ = self.fc0(x) x = self.activation(x) - return self.fc1(x) + x, _ = self.fc1(x) + return x class MoonVitEncoderLayer(nn.Module): @@ -407,6 +417,8 @@ class MoonVitEncoderLayer(nn.Module): num_heads: int, hidden_dim: int, mlp_dim: int, + prefix: str = "", + use_data_parallel: bool = False, *, attn_implementation: str = "sdpa", activation=F.gelu, @@ -423,9 +435,19 @@ class MoonVitEncoderLayer(nn.Module): self.norm0 = nn.LayerNorm(hidden_dim) self.norm1 = nn.LayerNorm(hidden_dim) - self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation) - self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias) - self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias) + self.use_data_parallel = use_data_parallel + self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], + activation, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) + self.wqkv = ReplicatedLinear(hidden_dim, + hidden_dim * 3, + bias=attn_bias, + prefix=f"{prefix}.wqkv") + self.wo = ReplicatedLinear(hidden_dim, + hidden_dim, + bias=attn_bias, + prefix=f"{prefix}.wo") def attention_qkvpacked( self, @@ -438,7 +460,7 @@ class MoonVitEncoderLayer(nn.Module): x (torch.Tensor): (batch_size, seqlen, hidden_dim) cu_seqlens (torch.Tensor): """ - xqkv = self.wqkv(x) + xqkv, _ = self.wqkv(x) qkv_shape = xqkv.size()[:-1] + ( 3, @@ -457,8 +479,7 @@ class MoonVitEncoderLayer(nn.Module): xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens) - - attn_out = self.wo(attn_out) + attn_out, _ = self.wo(attn_out) return attn_out def forward( @@ -494,13 +515,17 @@ class MoonVitEncoder(nn.Module): hidden_dim: int, num_layers: int, block_cfg: dict, + prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.rope_2d = Rope2DPosEmb( block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512) self.blocks = nn.ModuleList( - [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)]) + [MoonVitEncoderLayer(use_data_parallel=use_data_parallel, \ + prefix=f"{prefix}.blocks.{layer_idx}", \ + **block_cfg) for layer_idx in range(num_layers)]) self.final_layernorm = nn.LayerNorm(hidden_dim) def forward(self, hidden_states: torch.Tensor, @@ -508,10 +533,9 @@ class MoonVitEncoder(nn.Module): rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens( grid_hws=grid_hw) - lengths = torch.cat(( - torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype), - grid_hw[:, 0] * grid_hw[:, 1], - )) + lengths = torch.cat( + (torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype), + (grid_hw[:, 0] * grid_hw[:, 1]).to(hidden_states.device))) cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32) for _, block in enumerate(self.blocks): @@ -587,11 +611,19 @@ class MoonVitPretrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - def __init__(self, config: MoonViTConfig, *inputs, **kwargs): + def __init__(self, + config: MoonViTConfig, + use_data_parallel: bool = False, + prefix: str = "", + *inputs, + **kwargs): super().__init__(config, *inputs, **kwargs) config = deepcopy(config) + self.use_data_parallel = use_data_parallel self.merge_kernel_size = config.merge_kernel_size + self.hidden_size = config.hidden_size self.patch_size = config.patch_size + self.vit_processing_type = "rope_2d" self.patch_embed = MoonVisionPatchEmbed( out_dim=config.hidden_size, patch_size=config.patch_size, @@ -610,6 +642,7 @@ class MoonVitPretrainedModel(PreTrainedModel): "attn_bias": True, "attn_implementation": config._attn_implementation, }, + prefix=f"{prefix}.encoder", ) def forward(self, pixel_values: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index b528083b7c9cc..c8f7fc16b4e83 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1021,8 +1021,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, pixel_values = image_input["pixel_values"] if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values, grid_thw_list) + return run_dp_sharded_mrope_vision_model(self.visual, + pixel_values, + grid_thw_list, + rope_type="rope_3d") else: image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) @@ -1048,8 +1050,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, else: pixel_values_videos = video_input["pixel_values_videos"] if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values_videos, grid_thw_list) + return run_dp_sharded_mrope_vision_model(self.visual, + pixel_values_videos, + grid_thw_list, + rope_type="rope_3d") else: video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 834b2189e4bed..ac967dcc4003e 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -9,7 +9,7 @@ from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from itertools import groupby from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union from urllib.parse import ParseResult, urlparse from urllib.request import url2pathname @@ -444,7 +444,6 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor, Args: image_input (torch.Tensor): Image input tensor. vision_model (torch.nn.Module): Vision model. - Returns: torch.Tensor: Output image embeddings """ @@ -542,6 +541,8 @@ def run_dp_sharded_mrope_vision_model( vision_model: torch.nn.Module, pixel_values: torch.Tensor, grid_thw_list: list[list[int]], + *, + rope_type: Literal["rope_3d", "rope_2d"], ) -> tuple[torch.Tensor, ...]: """Run a vision model with data parallelism (DP) sharding. The function will shard the input image tensor on the @@ -552,6 +553,10 @@ def run_dp_sharded_mrope_vision_model( vision_model (torch.nn.Module): Vision model. pixel_values (torch.Tensor): Image/Video input tensor. grid_thw_list: List of grid dimensions for each image + rope_type: Type of rope used in the vision model. + Different rope types have different dimension to do ViT. + "rope_3d" for 3D rope (e.g., Qwen2.5-VL) + "rope_2d" for 2D rope (e.g., Kimi-VL) Returns: torch.Tensor: Output image embeddings @@ -605,8 +610,12 @@ def run_dp_sharded_mrope_vision_model( device=pixel_values.device, dtype=pixel_values.dtype) # embed_dim_reduction_factor = 2 * 2 - embed_dim_reduction_factor = (vision_model.spatial_merge_size * - vision_model.spatial_merge_size) + if rope_type == "rope_2d": + embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] * + vision_model.merge_kernel_size[1]) + else: + embed_dim_reduction_factor = (vision_model.spatial_merge_size * + vision_model.spatial_merge_size) # Find the max length across all ranks # The output embedding of every DP rank has to be @@ -617,23 +626,42 @@ def run_dp_sharded_mrope_vision_model( local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] # Run the vision model on the local pixel_values_local - if pixel_values_local.shape[0] > 0: - image_embeds_local = vision_model(pixel_values_local, - local_grid_thw_list) + if rope_type == "rope_2d": + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model( + pixel_values_local, torch.tensor(local_grid_thw_list)) + if isinstance(image_embeds_local, list): + image_embeds_local = torch.cat(image_embeds_local, dim=0) + else: + out_dim = getattr(vision_model.config, "hidden_size", None) + image_embeds_local = torch.empty( + (0, embed_dim_reduction_factor, out_dim), + device=pixel_values.device, + dtype=pixel_values.dtype) else: - # Handle empty case - image_embeds_local = torch.empty((0, vision_model.out_hidden_size), - device=pixel_values.device, - dtype=pixel_values.dtype) + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model(pixel_values_local, + local_grid_thw_list) + else: + # Handle empty case + image_embeds_local = torch.empty((0, vision_model.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) # Pad the output based on max_len_per_rank # for tensor_model_parallel_all_gather to work current_len = image_embeds_local.shape[0] if current_len < max_len_per_rank: padding_size = max_len_per_rank - current_len - padding = torch.empty((padding_size, image_embeds_local.shape[1]), - dtype=image_embeds_local.dtype, - device=image_embeds_local.device) + if rope_type == "rope_2d": + padding = torch.empty((padding_size, image_embeds_local.shape[1], + image_embeds_local.shape[2]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device) + else: + padding = torch.empty((padding_size, image_embeds_local.shape[1]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device) image_embeds_local_padded = torch.cat([image_embeds_local, padding], dim=0) else: @@ -674,7 +702,6 @@ def run_dp_sharded_mrope_vision_model( embed_start:embed_start + img_patches] embed_start += img_patches current_idx += count - out_embeddings = tuple(embed for embed in original_order_embeddings if embed is not None) assert len(out_embeddings) == len(