[Model] Support DP for ViT on Kimi-VL-A3B-Thinking-2506 (#23817)

Signed-off-by: Junhong <liujunhong11@huawei.com>
Signed-off-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Junhong <liujunhong11@huawei.com>
Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
WeiQing Chen 2025-09-02 00:56:56 +08:00 committed by GitHub
parent cf91a89dd2
commit a0e0efd6bd
6 changed files with 156 additions and 61 deletions

View File

@@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u
 Known supported models:

+- Kimi-VL (<gh-pr:23817>)
 - Llama4 (<gh-pr:18368>)
 - MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)
 - Qwen2.5-VL (<gh-pr:22742>)
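For reference, a minimal sketch of how encoder DP is enabled from the offline API once this support is in place. The model id, parallel size, and `trust_remote_code` flag below are illustrative assumptions, not part of this change; `mm_encoder_tp_mode="data"` is the engine argument named in the documentation above:

    from vllm import LLM

    # Keep tensor parallelism for the language model, but shard the ViT
    # encoder across ranks by data instead of by tensor.
    llm = LLM(
        model="moonshotai/Kimi-VL-A3B-Thinking-2506",  # assumed HF repo id
        tensor_parallel_size=2,                        # illustrative TP degree
        mm_encoder_tp_mode="data",                     # engine argument from the docs above
        trust_remote_code=True,
    )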

View File

@@ -636,8 +636,10 @@ def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
     # Run the model through the sharded function
     with torch.inference_mode():
-        sharded_output = run_dp_sharded_mrope_vision_model(
-            vision_model, pixel_values, grid_thw_list)
+        sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
+                                                           pixel_values,
+                                                           grid_thw_list,
+                                                           rope_type="rope_3d")
         sharded_output = torch.cat(sharded_output, dim=0)

     # Check that the world size is setup correctly

@@ -691,8 +693,10 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker(
     # Should handle empty input gracefully
     with torch.inference_mode():
-        output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values,
-                                                   grid_thw_list)
+        output = run_dp_sharded_mrope_vision_model(vision_model,
+                                                   pixel_values,
+                                                   grid_thw_list,
+                                                   rope_type="rope_3d")

     assert len(output) == 0

@@ -745,8 +749,10 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker(
     # Should handle uneven distribution without errors
     with torch.inference_mode():
-        output_tuple = run_dp_sharded_mrope_vision_model(
-            vision_model, pixel_values, grid_thw_list)
+        output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
+                                                         pixel_values,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")

     # Verify output shape is reasonable
     merge_factor = vision_model.spatial_merge_size**2

View File

@@ -56,6 +56,7 @@ from transformers.activations import GELUActivation
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)

@@ -76,6 +77,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
 from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config

@@ -93,8 +95,10 @@ class MaxImageTokenMeta:
 class KimiVLMultiModalProjector(nn.Module):

-    def __init__(self, config: KimiVLConfig):
+    def __init__(self, config: KimiVLConfig, \
+                 use_data_parallel: bool = False, prefix: str = ""):
         super().__init__()
+        self.use_data_parallel = use_data_parallel

         self.hidden_size = (config.vision_config.hidden_size *
                             config.vision_config.merge_kernel_size[0] *

@@ -102,20 +106,24 @@ class KimiVLMultiModalProjector(nn.Module):
         self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size,
                                            eps=1e-5)
-        self.linear_1 = nn.Linear(self.hidden_size,
-                                  self.hidden_size,
-                                  bias=True)
+        self.linear_1 = ReplicatedLinear(self.hidden_size,
+                                         self.hidden_size,
+                                         bias=True,
+                                         prefix=maybe_prefix(
+                                             prefix, "linear_1"))
+        self.linear_2 = ReplicatedLinear(self.hidden_size,
+                                         config.text_config.hidden_size,
+                                         bias=True,
+                                         prefix=maybe_prefix(
+                                             prefix, "linear_2"))
         self.act = GELUActivation()
-        self.linear_2 = nn.Linear(self.hidden_size,
-                                  config.text_config.hidden_size,
-                                  bias=True)

     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         hidden_states = self.pre_norm(image_features).view(
             -1, self.hidden_size)
-        hidden_states = self.linear_1(hidden_states)
+        hidden_states, _ = self.linear_1(hidden_states)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
+        hidden_states, _ = self.linear_2(hidden_states)
         return hidden_states

@@ -273,6 +281,8 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
 class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):

+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):

@@ -292,10 +302,17 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         quant_config = vllm_config.quant_config

         assert isinstance(config.vision_config, MoonViTConfig)
-        self.vision_tower = MoonVitPretrainedModel(config.vision_config)
-
-        self.multi_modal_projector = KimiVLMultiModalProjector(config=config)
+        self.use_data_parallel = model_config.multimodal_config.mm_encoder_tp_mode == "data"
+        self.hidden_size = config.text_config.hidden_size
+        self.vision_tower = MoonVitPretrainedModel(config.vision_config,
+                                                   self.use_data_parallel,
+                                                   prefix=maybe_prefix(
+                                                       prefix, "vision_tower"))
+        self.multi_modal_projector = KimiVLMultiModalProjector(
+            config=config,
+            use_data_parallel=self.use_data_parallel,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))

         self.quant_config = quant_config
         sub_vllm_config = copy.deepcopy(vllm_config)

@@ -376,13 +393,19 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         pixel_values = inputs["pixel_values"]
         image_grid_hws = inputs["image_grid_hws"]
-        return self.vision_tower(pixel_values, image_grid_hws)
+        if self.use_data_parallel:
+            return run_dp_sharded_mrope_vision_model(self.vision_tower,
+                                                     pixel_values,
+                                                     image_grid_hws.tolist(),
+                                                     rope_type="rope_2d")
+        else:
+            return self.vision_tower(pixel_values, image_grid_hws)

     def _process_image_input(self,
                              image_input: KimiVLImageInputs) -> torch.Tensor:
         assert image_input["type"] == "pixel_values"
         image_features = self._process_image_pixels(image_input)
-        assert isinstance(image_features, list)
+        assert isinstance(image_features, (list, tuple))
         lengths = [x.shape[0] for x in image_features]
         return self.multi_modal_projector(
             torch.cat(image_features)).split(lengths)

@@ -496,6 +519,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         expert_params_mapping = []
         params_dict = dict(self.named_parameters())
+
         for args in weights:
             name, loaded_weight = args[:2]
             kwargs = args[2] if len(args) > 2 else {}

View File

@@ -42,7 +42,6 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-import math
 from collections.abc import Sequence
 from copy import deepcopy
 from functools import cached_property

@@ -55,6 +54,8 @@ from transformers.activations import ACT2FN, PytorchGELUTanh
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import is_flash_attn_2_available

+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.models.utils import maybe_prefix
 from vllm.transformers_utils.configs.moonvit import MoonViTConfig

 if is_flash_attn_2_available():

@@ -383,21 +384,30 @@ class MLP2(nn.Module):
         bias: whether to use bias in linear layer.
     """

-    def __init__(self, dims: list[int], activation, bias=True):
+    def __init__(self,
+                 dims: list[int],
+                 activation,
+                 bias=True,
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         assert len(dims) == 3
-        self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
-        self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
+        self.use_data_parallel = use_data_parallel
+        self.fc0 = ReplicatedLinear(dims[0],
+                                    dims[1],
+                                    bias=bias,
+                                    prefix=maybe_prefix(prefix, "fc0"))
+        self.fc1 = ReplicatedLinear(dims[1],
+                                    dims[2],
+                                    bias=bias,
+                                    prefix=maybe_prefix(prefix, "fc1"))
         self.activation = activation
-        for m in [self.fc0, self.fc1]:
-            nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
-            if m.bias is not None:
-                nn.init.zeros_(m.bias)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.fc0(x)
+        x, _ = self.fc0(x)
         x = self.activation(x)
-        return self.fc1(x)
+        x, _ = self.fc1(x)
+        return x


 class MoonVitEncoderLayer(nn.Module):

@@ -407,6 +417,8 @@ class MoonVitEncoderLayer(nn.Module):
         num_heads: int,
         hidden_dim: int,
         mlp_dim: int,
+        prefix: str = "",
+        use_data_parallel: bool = False,
         *,
         attn_implementation: str = "sdpa",
         activation=F.gelu,

@@ -423,9 +435,19 @@ class MoonVitEncoderLayer(nn.Module):
         self.norm0 = nn.LayerNorm(hidden_dim)
         self.norm1 = nn.LayerNorm(hidden_dim)
-        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
-        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
-        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)
+        self.use_data_parallel = use_data_parallel
+        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim],
+                        activation,
+                        prefix=f"{prefix}.mlp",
+                        use_data_parallel=use_data_parallel)
+        self.wqkv = ReplicatedLinear(hidden_dim,
+                                     hidden_dim * 3,
+                                     bias=attn_bias,
+                                     prefix=f"{prefix}.wqkv")
+        self.wo = ReplicatedLinear(hidden_dim,
+                                   hidden_dim,
+                                   bias=attn_bias,
+                                   prefix=f"{prefix}.wo")

     def attention_qkvpacked(
         self,

@@ -438,7 +460,7 @@ class MoonVitEncoderLayer(nn.Module):
             x (torch.Tensor): (batch_size, seqlen, hidden_dim)
             cu_seqlens (torch.Tensor):
         """
-        xqkv = self.wqkv(x)
+        xqkv, _ = self.wqkv(x)

         qkv_shape = xqkv.size()[:-1] + (
             3,

@@ -457,8 +479,7 @@ class MoonVitEncoderLayer(nn.Module):
             xv,
             q_cu_seqlens=cu_seqlens,
             k_cu_seqlens=cu_seqlens)
-
-        attn_out = self.wo(attn_out)
+        attn_out, _ = self.wo(attn_out)
         return attn_out

     def forward(

@@ -494,13 +515,17 @@ class MoonVitEncoder(nn.Module):
         hidden_dim: int,
         num_layers: int,
         block_cfg: dict,
+        prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()

         self.rope_2d = Rope2DPosEmb(
             block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512)
         self.blocks = nn.ModuleList(
-            [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)])
+            [MoonVitEncoderLayer(use_data_parallel=use_data_parallel, \
+                                 prefix=f"{prefix}.blocks.{layer_idx}", \
+                                 **block_cfg) for layer_idx in range(num_layers)])
         self.final_layernorm = nn.LayerNorm(hidden_dim)

     def forward(self, hidden_states: torch.Tensor,

@@ -508,10 +533,9 @@ class MoonVitEncoder(nn.Module):
         rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(
             grid_hws=grid_hw)

-        lengths = torch.cat((
-            torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
-            grid_hw[:, 0] * grid_hw[:, 1],
-        ))
+        lengths = torch.cat(
+            (torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
+             (grid_hw[:, 0] * grid_hw[:, 1]).to(hidden_states.device)))
         cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)

         for _, block in enumerate(self.blocks):

@@ -587,11 +611,19 @@ class MoonVitPretrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True

-    def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
+    def __init__(self,
+                 config: MoonViTConfig,
+                 use_data_parallel: bool = False,
+                 prefix: str = "",
+                 *inputs,
+                 **kwargs):
         super().__init__(config, *inputs, **kwargs)
         config = deepcopy(config)
+        self.use_data_parallel = use_data_parallel
         self.merge_kernel_size = config.merge_kernel_size
+        self.hidden_size = config.hidden_size
         self.patch_size = config.patch_size
+        self.vit_processing_type = "rope_2d"
         self.patch_embed = MoonVisionPatchEmbed(
             out_dim=config.hidden_size,
             patch_size=config.patch_size,

@@ -610,6 +642,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
                 "attn_bias": True,
                 "attn_implementation": config._attn_implementation,
             },
+            prefix=f"{prefix}.encoder",
         )

     def forward(self, pixel_values: torch.Tensor,

View File

@@ -1021,8 +1021,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             pixel_values = image_input["pixel_values"]

             if self.use_data_parallel:
-                return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list)
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
             else:
                 image_embeds = self.visual(pixel_values,
                                            grid_thw=grid_thw_list)

@@ -1048,8 +1050,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
             if self.use_data_parallel:
-                return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values_videos, grid_thw_list)
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values_videos,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
             else:
                 video_embeds = self.visual(pixel_values_videos,
                                            grid_thw=grid_thw_list)

View File

@@ -9,7 +9,7 @@ from collections.abc import Iterable
 from concurrent.futures import ThreadPoolExecutor
 from itertools import groupby
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union
 from urllib.parse import ParseResult, urlparse
 from urllib.request import url2pathname

@@ -444,7 +444,6 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
     Args:
         image_input (torch.Tensor): Image input tensor.
         vision_model (torch.nn.Module): Vision model.
-
     Returns:
         torch.Tensor: Output image embeddings
     """
@@ -542,6 +541,8 @@ def run_dp_sharded_mrope_vision_model(
     vision_model: torch.nn.Module,
     pixel_values: torch.Tensor,
     grid_thw_list: list[list[int]],
+    *,
+    rope_type: Literal["rope_3d", "rope_2d"],
 ) -> tuple[torch.Tensor, ...]:
     """Run a vision model with data parallelism (DP) sharding.
     The function will shard the input image tensor on the
@@ -552,6 +553,10 @@ def run_dp_sharded_mrope_vision_model(
         vision_model (torch.nn.Module): Vision model.
         pixel_values (torch.Tensor): Image/Video input tensor.
         grid_thw_list: List of grid dimensions for each image
+        rope_type: Type of rope used in the vision model.
+            Different rope types produce ViT embeddings of different dimensions.
+            "rope_3d" for 3D rope (e.g., Qwen2.5-VL)
+            "rope_2d" for 2D rope (e.g., Kimi-VL)

     Returns:
         torch.Tensor: Output image embeddings
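For orientation, the two call shapes behind the new keyword-only `rope_type` argument, taken from the Qwen2.5-VL and Kimi-VL call sites shown earlier in this commit (variable names are those of the respective callers):

    # 3D rope (Qwen2.5-VL): grid_thw_list entries are roughly [t, h, w] grids
    image_embeds = run_dp_sharded_mrope_vision_model(self.visual,
                                                     pixel_values,
                                                     grid_thw_list,
                                                     rope_type="rope_3d")

    # 2D rope (Kimi-VL): image_grid_hws holds per-image [h, w] grids
    image_embeds = run_dp_sharded_mrope_vision_model(self.vision_tower,
                                                     pixel_values,
                                                     image_grid_hws.tolist(),
                                                     rope_type="rope_2d")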
@@ -605,8 +610,12 @@ def run_dp_sharded_mrope_vision_model(
                                      device=pixel_values.device,
                                      dtype=pixel_values.dtype)

     # embed_dim_reduction_factor = 2 * 2
-    embed_dim_reduction_factor = (vision_model.spatial_merge_size *
-                                  vision_model.spatial_merge_size)
+    if rope_type == "rope_2d":
+        embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] *
+                                      vision_model.merge_kernel_size[1])
+    else:
+        embed_dim_reduction_factor = (vision_model.spatial_merge_size *
+                                      vision_model.spatial_merge_size)

     # Find the max length across all ranks
     # The output embedding of every DP rank has to be
@@ -617,23 +626,42 @@ def run_dp_sharded_mrope_vision_model(
     local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]

     # Run the vision model on the local pixel_values_local
-    if pixel_values_local.shape[0] > 0:
-        image_embeds_local = vision_model(pixel_values_local,
-                                          local_grid_thw_list)
+    if rope_type == "rope_2d":
+        if pixel_values_local.shape[0] > 0:
+            image_embeds_local = vision_model(
+                pixel_values_local, torch.tensor(local_grid_thw_list))
+            if isinstance(image_embeds_local, list):
+                image_embeds_local = torch.cat(image_embeds_local, dim=0)
+        else:
+            out_dim = getattr(vision_model.config, "hidden_size", None)
+            image_embeds_local = torch.empty(
+                (0, embed_dim_reduction_factor, out_dim),
+                device=pixel_values.device,
+                dtype=pixel_values.dtype)
     else:
-        # Handle empty case
-        image_embeds_local = torch.empty((0, vision_model.out_hidden_size),
-                                         device=pixel_values.device,
-                                         dtype=pixel_values.dtype)
+        if pixel_values_local.shape[0] > 0:
+            image_embeds_local = vision_model(pixel_values_local,
+                                              local_grid_thw_list)
+        else:
+            # Handle empty case
+            image_embeds_local = torch.empty((0, vision_model.out_hidden_size),
+                                             device=pixel_values.device,
+                                             dtype=pixel_values.dtype)

     # Pad the output based on max_len_per_rank
     # for tensor_model_parallel_all_gather to work
     current_len = image_embeds_local.shape[0]
     if current_len < max_len_per_rank:
         padding_size = max_len_per_rank - current_len
-        padding = torch.empty((padding_size, image_embeds_local.shape[1]),
-                              dtype=image_embeds_local.dtype,
-                              device=image_embeds_local.device)
+        if rope_type == "rope_2d":
+            padding = torch.empty((padding_size, image_embeds_local.shape[1],
+                                   image_embeds_local.shape[2]),
+                                  dtype=image_embeds_local.dtype,
+                                  device=image_embeds_local.device)
+        else:
+            padding = torch.empty((padding_size, image_embeds_local.shape[1]),
+                                  dtype=image_embeds_local.dtype,
+                                  device=image_embeds_local.device)
         image_embeds_local_padded = torch.cat([image_embeds_local, padding],
                                               dim=0)
     else:

@@ -674,7 +702,6 @@ def run_dp_sharded_mrope_vision_model(
                                          embed_start:embed_start + img_patches]
             embed_start += img_patches
         current_idx += count
-
     out_embeddings = tuple(embed for embed in original_order_embeddings
                            if embed is not None)
     assert len(out_embeddings) == len(