Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-29 05:20:54 +08:00)
[Misc] Move DP for ViT code inside model executor dir (#25459)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent c4a15ee240
commit 215da8510d
@@ -1,10 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math

import pytest
import torch
import torch.multiprocessing as mp

from vllm.model_executor.models.vision import resolve_visual_encoder_outputs
from tests.utils import multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.models.vision import (
get_load_balance_assignment, resolve_visual_encoder_outputs,
run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables

@pytest.mark.parametrize(
|
||||
@@ -33,3 +43,415 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers,
|
||||
post_layer_norm=None,
|
||||
max_possible_layers=max_possible_layers)
|
||||
assert torch.equal(torch.tensor(expected_features), output_tensor)
|
||||
|
||||
|
||||
class SimpleLinearModel(torch.nn.Module):
|
||||
"""A simple linear vision model for testing."""
|
||||
|
||||
def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
|
||||
super().__init__()
|
||||
self.flatten = torch.nn.Flatten()
|
||||
self.linear = torch.nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Flatten the input and apply linear transformation
|
||||
x = self.flatten(x)
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size",
|
||||
[
|
||||
1, # Single image
|
||||
4, # Small batch
|
||||
5, # Odd batch size (for testing padding)
|
||||
],
|
||||
)
|
||||
def test_run_dp_sharded_vision_model(batch_size: int):
|
||||
world_size = 2
|
||||
# Launch processes
|
||||
mp.spawn(
|
||||
run_dp_sharded_vision_model_vs_direct,
|
||||
args=(
|
||||
world_size,
|
||||
batch_size,
|
||||
get_open_port(),
|
||||
),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
|
||||
batch_size: int, master_port: int):
|
||||
"""
|
||||
Test that run_dp_sharded_vision_model produces the same results as
|
||||
calling the model directly.
|
||||
"""
|
||||
|
||||
# Set random seed for reproducibility
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create a test input tensor
|
||||
image_input = torch.randn(batch_size, 3, 224, 224)
|
||||
|
||||
# Create a simple linear model
|
||||
vision_model = SimpleLinearModel()
|
||||
|
||||
# Run the model directly on the full input
|
||||
with torch.inference_mode():
|
||||
direct_output = vision_model(image_input)
|
||||
|
||||
# Run the model through the sharded function
|
||||
with torch.inference_mode():
|
||||
sharded_output = run_dp_sharded_vision_model(image_input, vision_model)
|
||||
|
||||
# Check that the world size is set up correctly
|
||||
assert get_tensor_model_parallel_world_size() == world_size
|
||||
|
||||
# Check that the outputs have the same shape
|
||||
assert direct_output.shape == sharded_output.shape
|
||||
|
||||
# Check that the outputs are close (they should be identical)
|
||||
assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
|
||||
"expected_grouped_sizes_per_gpu,test_description",
|
||||
[
|
||||
# Empty input
|
||||
([], 2, [], [0, 0], [0, 0], "empty input"),
|
||||
|
||||
# Fewer samples than GPUs
|
||||
([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0
|
||||
], "fewer samples than GPUs"),
|
||||
|
||||
# Single GPU
|
||||
([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
|
||||
|
||||
# Balanced assignment
|
||||
([100, 100, 100, 100
|
||||
], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"),
|
||||
|
||||
# Unbalanced sizes - this one is trickier since the algorithm is greedy
|
||||
([1000, 100, 200, 50], 2, [0, 2, 1, 3
|
||||
], [1, 3], [1000, 350], "unbalanced sizes"),
|
||||
],
|
||||
)
|
||||
def test_get_load_balance_assignment_cases(sizes, num_gpus,
|
||||
expected_shuffle_indices,
|
||||
expected_gpu_sample_counts,
|
||||
expected_grouped_sizes_per_gpu,
|
||||
test_description):
|
||||
"""Test get_load_balance_assignment with various input cases."""
|
||||
result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
|
||||
(shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
|
||||
|
||||
# Common assertions for all cases
|
||||
assert len(shuffle_indices) == len(sizes)
|
||||
assert len(gpu_sample_counts) == num_gpus
|
||||
assert len(grouped_sizes_per_gpu) == num_gpus
|
||||
assert sum(gpu_sample_counts) == len(sizes)
|
||||
|
||||
assert shuffle_indices == expected_shuffle_indices
|
||||
|
||||
assert gpu_sample_counts == expected_gpu_sample_counts
|
||||
assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
|
||||
|
||||
|
||||
class SimpleMRopeVisionModel(torch.nn.Module):
|
||||
"""A simple vision model for testing mrope functionality."""
|
||||
|
||||
def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
|
||||
super().__init__()
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.linear = torch.nn.Linear(768, out_hidden_size)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor,
|
||||
grid_thw_list: list[list[int]]):
|
||||
"""Simple forward pass that simulates spatial merging."""
|
||||
# Apply linear transformation
|
||||
embeddings = self.linear(pixel_values)
|
||||
|
||||
# Simulate spatial merging by reducing the number of patches
|
||||
merge_factor = self.spatial_merge_size * self.spatial_merge_size
|
||||
|
||||
# Group patches and merge spatially
|
||||
merged_embeddings = []
|
||||
start_idx = 0
|
||||
|
||||
for grid_thw in grid_thw_list:
|
||||
num_patches = math.prod(grid_thw)
|
||||
end_idx = start_idx + num_patches
|
||||
|
||||
# Get patches for this image
|
||||
image_patches = embeddings[start_idx:end_idx]
|
||||
|
||||
# Simulate spatial merging by averaging groups of patches
|
||||
merged_patches = num_patches // merge_factor
|
||||
if merged_patches > 0:
|
||||
# Reshape and average to simulate merging
|
||||
reshaped = image_patches[:merged_patches * merge_factor].view(
|
||||
merged_patches, merge_factor, -1)
|
||||
merged = reshaped.mean(dim=1)
|
||||
merged_embeddings.append(merged)
|
||||
|
||||
start_idx = end_idx
|
||||
|
||||
if merged_embeddings:
|
||||
return torch.cat(merged_embeddings, dim=0)
|
||||
else:
|
||||
return torch.empty((0, self.out_hidden_size),
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype)
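A quick sketch of how this stub behaves (editor's illustration, not part of the diff): with the default spatial_merge_size=2, an image whose grid is [1, 4, 4] contributes 16 patches that are averaged in groups of 2 * 2.
# Illustrative only; assumes the class definition and torch import above.
model = SimpleMRopeVisionModel(spatial_merge_size=2, out_hidden_size=64)
out = model(torch.randn(16, 768), [[1, 4, 4]])
assert out.shape == (4, 64)  # 16 patches // (2 * 2) = 4 merged embeddings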
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size",
|
||||
[
|
||||
1, # Single image
|
||||
3, # Small batch
|
||||
5, # Odd batch size (for testing padding)
|
||||
],
|
||||
)
|
||||
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
|
||||
world_size = 2
|
||||
# Launch processes
|
||||
mp.spawn(
|
||||
run_dp_sharded_mrope_vision_model_vs_direct,
|
||||
args=(
|
||||
world_size,
|
||||
batch_size,
|
||||
get_open_port(),
|
||||
),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
|
||||
world_size: int,
|
||||
batch_size: int,
|
||||
master_port: int):
|
||||
"""
|
||||
Test that run_dp_sharded_mrope_vision_model produces the same results as
|
||||
calling the model directly.
|
||||
"""
|
||||
# Set random seed for reproducibility
|
||||
current_platform.seed_everything(0)
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create test data
|
||||
grid_thw_list = []
|
||||
pixel_values_list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Varying image sizes for better testing
|
||||
t, h, w = 1, 4 + i, 4 + i
|
||||
grid_thw_list.append([t, h, w])
|
||||
|
||||
num_patches = t * h * w
|
||||
# Create random pixel values for this image
|
||||
image_pixels = torch.randn(num_patches, 768)
|
||||
pixel_values_list.append(image_pixels)
|
||||
|
||||
# Concatenate all pixel values
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
|
||||
# Create a simple mrope vision model
|
||||
vision_model = SimpleMRopeVisionModel()
|
||||
|
||||
# Run the model directly on the full input (only on rank 0)
|
||||
if local_rank == 0:
|
||||
with torch.inference_mode():
|
||||
direct_output = vision_model(pixel_values, grid_thw_list)
|
||||
|
||||
# Run the model through the sharded function
|
||||
with torch.inference_mode():
|
||||
sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
sharded_output = torch.cat(sharded_output, dim=0)
|
||||
|
||||
# Check that the world size is set up correctly
|
||||
assert get_tensor_model_parallel_world_size() == world_size
|
||||
|
||||
# Compare outputs (only on rank 0)
|
||||
if local_rank == 0:
|
||||
# Check that the outputs have the same shape
|
||||
assert direct_output.shape == sharded_output.shape
|
||||
# Check that the outputs are close (they should be identical)
|
||||
assert torch.allclose(direct_output,
|
||||
sharded_output,
|
||||
rtol=1e-5,
|
||||
atol=1e-5)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_run_dp_sharded_mrope_vision_model_empty_input():
|
||||
world_size = 2
|
||||
mp.spawn(
|
||||
run_dp_sharded_mrope_vision_model_empty_input_worker,
|
||||
args=(world_size, get_open_port()),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_empty_input_worker(
|
||||
local_rank: int, world_size: int, master_port: int):
|
||||
"""Test run_dp_sharded_mrope_vision_model with empty input."""
|
||||
# Set up distributed environment
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create empty inputs
|
||||
pixel_values = torch.empty((0, 768))
|
||||
grid_thw_list: list[list[int]] = []
|
||||
|
||||
vision_model = SimpleMRopeVisionModel()
|
||||
|
||||
# Should handle empty input gracefully
|
||||
with torch.inference_mode():
|
||||
output = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
|
||||
assert len(output) == 0
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_run_dp_sharded_mrope_vision_model_uneven_load():
|
||||
world_size = 4
|
||||
mp.spawn(
|
||||
run_dp_sharded_mrope_vision_model_uneven_load_worker,
|
||||
args=(world_size, get_open_port()),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_uneven_load_worker(
|
||||
local_rank: int, world_size: int, master_port: int):
|
||||
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
|
||||
# Set up distributed environment
|
||||
current_platform.seed_everything(123)
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create images with very different sizes
|
||||
grid_thw_list = [
|
||||
[1, 2, 2], # Small: 4 patches
|
||||
[1, 8, 8], # Large: 64 patches
|
||||
[1, 3, 3], # Medium: 9 patches
|
||||
]
|
||||
|
||||
pixel_values_list = []
|
||||
for grid_thw in grid_thw_list:
|
||||
num_patches = math.prod(grid_thw)
|
||||
image_pixels = torch.randn(num_patches, 768)
|
||||
pixel_values_list.append(image_pixels)
|
||||
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
vision_model = SimpleMRopeVisionModel()
|
||||
|
||||
# Should handle uneven distribution without errors
|
||||
with torch.inference_mode():
|
||||
output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
|
||||
# Verify output shape is reasonable
|
||||
merge_factor = vision_model.spatial_merge_size**2
|
||||
expected_output_patches = list(
|
||||
math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
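# (Editor's note, not part of the diff.) For the grids above this evaluates to
# [4 // 4, 64 // 4, 9 // 4] == [1, 16, 2] merged patches per image.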
|
||||
|
||||
for i, output in enumerate(output_tuple):
|
||||
assert output.shape[0] == expected_output_patches[i]
|
||||
assert output.shape[1] == vision_model.out_hidden_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spatial_merge_size", [2, 4])
|
||||
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
|
||||
"""Test SimpleMRopeVisionModel with different spatial merge sizes."""
|
||||
device = current_platform.device_type
|
||||
|
||||
grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images
|
||||
pixel_values_list = []
|
||||
|
||||
for grid_thw in grid_thw_list:
|
||||
num_patches = math.prod(grid_thw)
|
||||
image_pixels = torch.randn(num_patches, 768, device=device)
|
||||
pixel_values_list.append(image_pixels)
|
||||
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
vision_model = SimpleMRopeVisionModel(
|
||||
spatial_merge_size=spatial_merge_size).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
output = vision_model(pixel_values, grid_thw_list)
|
||||
|
||||
# Verify output dimensions based on spatial merging
|
||||
total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
|
||||
merge_factor = spatial_merge_size**2
|
||||
expected_output_patches = total_patches // merge_factor
|
||||
|
||||
assert output.shape[0] == expected_output_patches
|
||||
assert output.shape[1] == vision_model.out_hidden_size
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
import math
|
||||
import mimetypes
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
@@ -10,22 +9,11 @@ from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
|
||||
get_load_balance_assignment,
|
||||
run_dp_sharded_mrope_vision_model,
|
||||
run_dp_sharded_vision_model)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_open_port, update_environment_variables
|
||||
from vllm.multimodal.utils import MediaConnector, argsort_mm_positions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal.inputs import MultiModalPlaceholderDict
|
||||
@@ -404,415 +392,3 @@ def test_argsort_mm_positions():
|
||||
modality_idxs = argsort_mm_positions(mm_positions)
|
||||
|
||||
assert modality_idxs == expected_modality_idxs
|
||||
|
||||
|
||||
class SimpleLinearModel(torch.nn.Module):
|
||||
"""A simple linear vision model for testing."""
|
||||
|
||||
def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
|
||||
super().__init__()
|
||||
self.flatten = torch.nn.Flatten()
|
||||
self.linear = torch.nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Flatten the input and apply linear transformation
|
||||
x = self.flatten(x)
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size",
|
||||
[
|
||||
1, # Single image
|
||||
4, # Small batch
|
||||
5, # Odd batch size (for testing padding)
|
||||
],
|
||||
)
|
||||
def test_run_dp_sharded_vision_model(batch_size: int):
|
||||
world_size = 2
|
||||
# Launch processes
|
||||
mp.spawn(
|
||||
run_dp_sharded_vision_model_vs_direct,
|
||||
args=(
|
||||
world_size,
|
||||
batch_size,
|
||||
get_open_port(),
|
||||
),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
|
||||
batch_size: int, master_port: int):
|
||||
"""
|
||||
Test that run_dp_sharded_vision_model produces the same results as
|
||||
calling the model directly.
|
||||
"""
|
||||
|
||||
# Set random seed for reproducibility
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create a test input tensor
|
||||
image_input = torch.randn(batch_size, 3, 224, 224)
|
||||
|
||||
# Create a simple linear model
|
||||
vision_model = SimpleLinearModel()
|
||||
|
||||
# Run the model directly on the full input
|
||||
with torch.inference_mode():
|
||||
direct_output = vision_model(image_input)
|
||||
|
||||
# Run the model through the sharded function
|
||||
with torch.inference_mode():
|
||||
sharded_output = run_dp_sharded_vision_model(image_input, vision_model)
|
||||
|
||||
# Check that the world size is set up correctly
|
||||
assert get_tensor_model_parallel_world_size() == world_size
|
||||
|
||||
# Check that the outputs have the same shape
|
||||
assert direct_output.shape == sharded_output.shape
|
||||
|
||||
# Check that the outputs are close (they should be identical)
|
||||
assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
|
||||
"expected_grouped_sizes_per_gpu,test_description",
|
||||
[
|
||||
# Empty input
|
||||
([], 2, [], [0, 0], [0, 0], "empty input"),
|
||||
|
||||
# Fewer samples than GPUs
|
||||
([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0
|
||||
], "fewer samples than GPUs"),
|
||||
|
||||
# Single GPU
|
||||
([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
|
||||
|
||||
# Balanced assignment
|
||||
([100, 100, 100, 100
|
||||
], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"),
|
||||
|
||||
# Unbalanced sizes - this one is trickier since the algorithm is greedy
|
||||
([1000, 100, 200, 50], 2, [0, 2, 1, 3
|
||||
], [1, 3], [1000, 350], "unbalanced sizes"),
|
||||
],
|
||||
)
|
||||
def test_get_load_balance_assignment_cases(sizes, num_gpus,
|
||||
expected_shuffle_indices,
|
||||
expected_gpu_sample_counts,
|
||||
expected_grouped_sizes_per_gpu,
|
||||
test_description):
|
||||
"""Test get_load_balance_assignment with various input cases."""
|
||||
result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
|
||||
(shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
|
||||
|
||||
# Common assertions for all cases
|
||||
assert len(shuffle_indices) == len(sizes)
|
||||
assert len(gpu_sample_counts) == num_gpus
|
||||
assert len(grouped_sizes_per_gpu) == num_gpus
|
||||
assert sum(gpu_sample_counts) == len(sizes)
|
||||
|
||||
assert shuffle_indices == expected_shuffle_indices
|
||||
|
||||
assert gpu_sample_counts == expected_gpu_sample_counts
|
||||
assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
|
||||
|
||||
|
||||
class SimpleMRopeVisionModel(torch.nn.Module):
|
||||
"""A simple vision model for testing mrope functionality."""
|
||||
|
||||
def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
|
||||
super().__init__()
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.linear = torch.nn.Linear(768, out_hidden_size)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor,
|
||||
grid_thw_list: list[list[int]]):
|
||||
"""Simple forward pass that simulates spatial merging."""
|
||||
# Apply linear transformation
|
||||
embeddings = self.linear(pixel_values)
|
||||
|
||||
# Simulate spatial merging by reducing the number of patches
|
||||
merge_factor = self.spatial_merge_size * self.spatial_merge_size
|
||||
|
||||
# Group patches and merge spatially
|
||||
merged_embeddings = []
|
||||
start_idx = 0
|
||||
|
||||
for grid_thw in grid_thw_list:
|
||||
num_patches = math.prod(grid_thw)
|
||||
end_idx = start_idx + num_patches
|
||||
|
||||
# Get patches for this image
|
||||
image_patches = embeddings[start_idx:end_idx]
|
||||
|
||||
# Simulate spatial merging by averaging groups of patches
|
||||
merged_patches = num_patches // merge_factor
|
||||
if merged_patches > 0:
|
||||
# Reshape and average to simulate merging
|
||||
reshaped = image_patches[:merged_patches * merge_factor].view(
|
||||
merged_patches, merge_factor, -1)
|
||||
merged = reshaped.mean(dim=1)
|
||||
merged_embeddings.append(merged)
|
||||
|
||||
start_idx = end_idx
|
||||
|
||||
if merged_embeddings:
|
||||
return torch.cat(merged_embeddings, dim=0)
|
||||
else:
|
||||
return torch.empty((0, self.out_hidden_size),
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size",
|
||||
[
|
||||
1, # Single image
|
||||
3, # Small batch
|
||||
5, # Odd batch size (for testing padding)
|
||||
],
|
||||
)
|
||||
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
|
||||
world_size = 2
|
||||
# Launch processes
|
||||
mp.spawn(
|
||||
run_dp_sharded_mrope_vision_model_vs_direct,
|
||||
args=(
|
||||
world_size,
|
||||
batch_size,
|
||||
get_open_port(),
|
||||
),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
|
||||
world_size: int,
|
||||
batch_size: int,
|
||||
master_port: int):
|
||||
"""
|
||||
Test that run_dp_sharded_mrope_vision_model produces the same results as
|
||||
calling the model directly.
|
||||
"""
|
||||
# Set random seed for reproducibility
|
||||
current_platform.seed_everything(0)
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create test data
|
||||
grid_thw_list = []
|
||||
pixel_values_list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Varying image sizes for better testing
|
||||
t, h, w = 1, 4 + i, 4 + i
|
||||
grid_thw_list.append([t, h, w])
|
||||
|
||||
num_patches = t * h * w
|
||||
# Create random pixel values for this image
|
||||
image_pixels = torch.randn(num_patches, 768)
|
||||
pixel_values_list.append(image_pixels)
|
||||
|
||||
# Concatenate all pixel values
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
|
||||
# Create a simple mrope vision model
|
||||
vision_model = SimpleMRopeVisionModel()
|
||||
|
||||
# Run the model directly on the full input (only on rank 0)
|
||||
if local_rank == 0:
|
||||
with torch.inference_mode():
|
||||
direct_output = vision_model(pixel_values, grid_thw_list)
|
||||
|
||||
# Run the model through the sharded function
|
||||
with torch.inference_mode():
|
||||
sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
sharded_output = torch.cat(sharded_output, dim=0)
|
||||
|
||||
# Check that the world size is set up correctly
|
||||
assert get_tensor_model_parallel_world_size() == world_size
|
||||
|
||||
# Compare outputs (only on rank 0)
|
||||
if local_rank == 0:
|
||||
# Check that the outputs have the same shape
|
||||
assert direct_output.shape == sharded_output.shape
|
||||
# Check that the outputs are close (they should be identical)
|
||||
assert torch.allclose(direct_output,
|
||||
sharded_output,
|
||||
rtol=1e-5,
|
||||
atol=1e-5)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_run_dp_sharded_mrope_vision_model_empty_input():
|
||||
world_size = 2
|
||||
mp.spawn(
|
||||
run_dp_sharded_mrope_vision_model_empty_input_worker,
|
||||
args=(world_size, get_open_port()),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_empty_input_worker(
|
||||
local_rank: int, world_size: int, master_port: int):
|
||||
"""Test run_dp_sharded_mrope_vision_model with empty input."""
|
||||
# Set up distributed environment
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create empty inputs
|
||||
pixel_values = torch.empty((0, 768))
|
||||
grid_thw_list: list[list[int]] = []
|
||||
|
||||
vision_model = SimpleMRopeVisionModel()
|
||||
|
||||
# Should handle empty input gracefully
|
||||
with torch.inference_mode():
|
||||
output = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
|
||||
assert len(output) == 0
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_run_dp_sharded_mrope_vision_model_uneven_load():
|
||||
world_size = 4
|
||||
mp.spawn(
|
||||
run_dp_sharded_mrope_vision_model_uneven_load_worker,
|
||||
args=(world_size, get_open_port()),
|
||||
nprocs=world_size,
|
||||
)
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model_uneven_load_worker(
|
||||
local_rank: int, world_size: int, master_port: int):
|
||||
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
|
||||
# Set up distributed environment
|
||||
current_platform.seed_everything(123)
|
||||
device = f"{current_platform.device_name}:{local_rank}"
|
||||
current_platform.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': str(master_port),
|
||||
})
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create images with very different sizes
|
||||
grid_thw_list = [
|
||||
[1, 2, 2], # Small: 4 patches
|
||||
[1, 8, 8], # Large: 64 patches
|
||||
[1, 3, 3], # Medium: 9 patches
|
||||
]
|
||||
|
||||
pixel_values_list = []
|
||||
for grid_thw in grid_thw_list:
|
||||
num_patches = math.prod(grid_thw)
|
||||
image_pixels = torch.randn(num_patches, 768)
|
||||
pixel_values_list.append(image_pixels)
|
||||
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
vision_model = SimpleMRopeVisionModel()
|
||||
|
||||
# Should handle uneven distribution without errors
|
||||
with torch.inference_mode():
|
||||
output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
|
||||
# Verify output shape is reasonable
|
||||
merge_factor = vision_model.spatial_merge_size**2
|
||||
expected_output_patches = list(
|
||||
math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
|
||||
|
||||
for i, output in enumerate(output_tuple):
|
||||
assert output.shape[0] == expected_output_patches[i]
|
||||
assert output.shape[1] == vision_model.out_hidden_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spatial_merge_size", [2, 4])
|
||||
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
|
||||
"""Test SimpleMRopeVisionModel with different spatial merge sizes."""
|
||||
device = current_platform.device_type
|
||||
|
||||
grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images
|
||||
pixel_values_list = []
|
||||
|
||||
for grid_thw in grid_thw_list:
|
||||
num_patches = math.prod(grid_thw)
|
||||
image_pixels = torch.randn(num_patches, 768, device=device)
|
||||
pixel_values_list.append(image_pixels)
|
||||
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
vision_model = SimpleMRopeVisionModel(
|
||||
spatial_merge_size=spatial_merge_size).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
output = vision_model(pixel_values, grid_thw_list)
|
||||
|
||||
# Verify output dimensions based on spatial merging
|
||||
total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
|
||||
merge_factor = spatial_merge_size**2
|
||||
expected_output_patches = total_patches // merge_factor
|
||||
|
||||
assert output.shape[0] == expected_output_patches
|
||||
assert output.shape[1] == vision_model.out_hidden_size
|
||||
|
||||
@@ -69,7 +69,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
@@ -83,7 +82,7 @@ from .qwen2_vl import (_create_qwen2vl_field_factory,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import get_vit_attn_backend
|
||||
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -34,7 +34,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.utils import run_dp_sharded_vision_model
|
||||
|
||||
from .vision import run_dp_sharded_vision_model
|
||||
|
||||
|
||||
class Idefics2VisionEmbeddings(nn.Module):
|
||||
|
||||
@@ -28,7 +28,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.utils import run_dp_sharded_vision_model
|
||||
|
||||
from .vision import run_dp_sharded_vision_model
|
||||
|
||||
NORM2FN = {
|
||||
'rms_norm': RMSNorm,
|
||||
|
||||
@@ -76,13 +76,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
|
||||
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
|
||||
from .vision import run_dp_sharded_mrope_vision_model
|
||||
|
||||
|
||||
# For dummy input only
|
||||
|
||||
@@ -50,7 +50,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.utils import run_dp_sharded_vision_model
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
@@ -58,6 +57,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .llama4 import Llama4ForCausalLM
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import run_dp_sharded_vision_model
|
||||
|
||||
|
||||
class Llama4ImagePatchInputs(TensorSchema):
|
||||
|
||||
@@ -59,7 +59,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalFieldConfig
|
||||
from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
@@ -74,7 +73,7 @@ from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import get_vit_attn_backend
|
||||
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
@@ -78,7 +77,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import get_vit_attn_backend
|
||||
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ from .qwen2_vl import Qwen2VLProcessingInfo
|
||||
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import get_vit_attn_backend
|
||||
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -1214,8 +1214,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
else:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
if self.use_data_parallel:
|
||||
from vllm.multimodal.utils import (
|
||||
run_dp_sharded_mrope_vision_model)
|
||||
return run_dp_sharded_mrope_vision_model(self.visual,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
@@ -1245,8 +1243,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
pixel_values_videos = video_input["pixel_values_videos"].type(
|
||||
self.visual.dtype)
|
||||
if self.use_data_parallel:
|
||||
from vllm.multimodal.utils import (
|
||||
run_dp_sharded_mrope_vision_model)
|
||||
return run_dp_sharded_mrope_vision_model(self.visual,
|
||||
pixel_values_videos,
|
||||
grid_thw_list,
|
||||
|
||||
@@ -31,7 +31,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.utils import run_dp_sharded_vision_model
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
@@ -40,6 +39,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import run_dp_sharded_vision_model
|
||||
|
||||
|
||||
class Step3VLImagePixelInputs(TypedDict):
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Final, Generic, Optional, Protocol, TypeVar, Union
|
||||
from typing import Final, Generic, Literal, Optional, Protocol, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
|
||||
@@ -123,3 +128,277 @@ def resolve_visual_encoder_outputs(
|
||||
if post_layer_norm is not None and uses_last_layer:
|
||||
hs_pool[-1] = post_layer_norm(encoder_outputs)
|
||||
return torch.cat(hs_pool, dim=-1)
|
||||
|
||||
|
||||
def run_dp_sharded_vision_model(image_input: torch.Tensor,
|
||||
vision_model: torch.nn.Module) -> torch.Tensor:
|
||||
"""Run a vision model with data parallelism (DP) sharding. The function
|
||||
will shard the input image tensor on the first dimension and run the vision
|
||||
model.
|
||||
|
||||
Args:
|
||||
image_input (torch.Tensor): Image input tensor.
|
||||
vision_model (torch.nn.Module): Vision model.
|
||||
Returns:
|
||||
torch.Tensor: Output image embeddings
|
||||
"""
|
||||
|
||||
num_chunks = image_input.shape[0]
|
||||
mp_world_size = get_tensor_model_parallel_world_size()
|
||||
num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
|
||||
num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks
|
||||
pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
|
||||
image_input_padded = torch.nn.functional.pad(image_input, pad)
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
image_input_per_rank = image_input_padded[rank *
|
||||
num_chunks_per_rank:(rank + 1) *
|
||||
num_chunks_per_rank, ...]
|
||||
|
||||
vision_embeddings = vision_model(image_input_per_rank)
|
||||
# Ensure tensor is contiguous before all_gather
|
||||
vision_embeddings = vision_embeddings.contiguous()
|
||||
vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
|
||||
dim=0)
|
||||
vision_embeddings = vision_embeddings[:num_chunks, ...]
|
||||
return vision_embeddings
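A short worked example of the chunking arithmetic above (editor's sketch, not part of the diff), using the odd batch size that the tests exercise:
# Illustrative only: 5 images on a DP group of size 2.
num_chunks, mp_world_size = 5, 2
num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size  # 3
num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks     # 1
# Rank 0 runs chunks [0:3] and rank 1 runs chunks [3:6] (the last one padded);
# the all_gather returns 6 embeddings and the final slice keeps the first 5.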
|
||||
|
||||
|
||||
def get_load_balance_assignment(
|
||||
sizes: list[int],
|
||||
num_gpus: int = 2,
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
"""
|
||||
Generate load balancing assignment and metadata
|
||||
for distributing data across GPUs.
|
||||
The load is determined by the total image sizes,
|
||||
not the number of images.
|
||||
|
||||
Args:
|
||||
sizes: The size of each image
|
||||
num_gpus: Number of GPUs to balance across
|
||||
|
||||
Returns:
|
||||
shuffle_indices:
|
||||
Indices to reorder data for balanced loading
|
||||
gpu_sample_counts:
|
||||
Number of samples assigned to each GPU
|
||||
grouped_sizes_per_gpu:
|
||||
Total size assigned to each GPU
|
||||
|
||||
Example:
|
||||
```
|
||||
sizes = [1000, 100, 200, 50]
|
||||
num_gpus=2
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
n_samples = len(sizes)
|
||||
|
||||
# Handle edge cases
|
||||
if n_samples == 0:
|
||||
return [], [0] * num_gpus, [0] * num_gpus
|
||||
|
||||
# Use greedy algorithm - balance by total size, not sample count
|
||||
gpu_assignments = [list[int]() for _ in range(num_gpus)]
|
||||
gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count
|
||||
|
||||
# Sort indices by size (largest first for better load balancing)
|
||||
# sizes = [1000, 100, 200, 50]
|
||||
# large_to_small_indices = [0, 2, 1, 3]
|
||||
large_to_small_indices = sorted(range(n_samples),
|
||||
key=lambda i: sizes[i],
|
||||
reverse=True)
|
||||
|
||||
for idx in large_to_small_indices:
|
||||
# Find GPU with minimum current load (by total size)
|
||||
min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i])
|
||||
gpu_assignments[min_gpu].append(idx)
|
||||
gpu_loads[min_gpu] += sizes[idx]
|
||||
|
||||
# Create shuffle indices and counts
|
||||
shuffle_indices = list[int]()
|
||||
gpu_sample_counts = list[int]()
|
||||
for gpu_id in range(num_gpus):
|
||||
# GPU_0 = [1000] = [0]
|
||||
# GPU_1 = [200, 100, 50] = [2, 1, 3]
|
||||
# shuffle_indices = [0, 2, 1, 3]
|
||||
shuffle_indices.extend(gpu_assignments[gpu_id])
|
||||
# GPU_0 = [1]
|
||||
# GPU_1 = [3]
|
||||
# gpu_sample_counts = [1, 3]
|
||||
gpu_sample_counts.append(len(gpu_assignments[gpu_id]))
|
||||
|
||||
return (shuffle_indices, gpu_sample_counts, gpu_loads)
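The docstring example above lists only the inputs; for reference, this is what the greedy assignment produces for it (editor's sketch, matching the "unbalanced sizes" test case earlier in this diff):
# Illustrative only.
shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu = \
    get_load_balance_assignment([1000, 100, 200, 50], num_gpus=2)
assert shuffle_indices == [0, 2, 1, 3]   # GPU_0 takes image 0; GPU_1 takes 2, 1, 3
assert gpu_sample_counts == [1, 3]
assert grouped_sizes_per_gpu == [1000, 350]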
|
||||
|
||||
|
||||
def run_dp_sharded_mrope_vision_model(
|
||||
vision_model: torch.nn.Module,
|
||||
pixel_values: torch.Tensor,
|
||||
grid_thw_list: list[list[int]],
|
||||
*,
|
||||
rope_type: Literal["rope_3d", "rope_2d"],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Run a vision model with data parallelism (DP) sharding.
|
||||
The function will shard the input image tensor on the
|
||||
first dimension and run the vision model.
|
||||
This function is used to run the vision model with mrope.
|
||||
|
||||
Args:
|
||||
vision_model (torch.nn.Module): Vision model.
|
||||
pixel_values (torch.Tensor): Image/Video input tensor.
|
||||
grid_thw_list: List of grid dimensions for each image
|
||||
rope_type: Type of rope used in the vision model.
|
||||
Different rope types have different dimension to do ViT.
|
||||
"rope_3d" for 3D rope (e.g., Qwen2.5-VL)
|
||||
"rope_2d" for 2D rope (e.g., Kimi-VL)
|
||||
Returns:
|
||||
torch.Tensor: Output image embeddings
|
||||
|
||||
Example:
|
||||
```
|
||||
vision_model.out_hidden_size = 64
|
||||
vision_model.spatial_merge_size = 2
|
||||
pixel_values.shape = (1350, channel)
|
||||
grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 5, 10]]
|
||||
tp_size=2
|
||||
```
|
||||
|
||||
"""
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# GPU_0 tp_rank_local = 0
|
||||
# GPU_1 tp_rank_local = 1
|
||||
tp_rank_local = get_tensor_model_parallel_rank()
|
||||
|
||||
# patches_per_image = [1000, 100, 200, 50]
|
||||
patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list]
|
||||
# cum_patches_per_image = [0, 1000, 1100, 1300, 1350]
|
||||
cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)]
|
||||
|
||||
# Get load balancing assignment with all metadata
|
||||
# image_to_tp_rank = [0, 2, 1, 3]
|
||||
# gpu_sample_counts = [1, 3]
|
||||
# grouped_pixel_values_len = [1000, 350]
|
||||
(image_to_tp_rank, gpu_sample_counts,
|
||||
grouped_pixel_values_len) = get_load_balance_assignment(
|
||||
patches_per_image, tp_size)
|
||||
|
||||
# cum_gpu_sample_counts = [0, 1, 4]
|
||||
cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)]
|
||||
|
||||
# GPU_0 image_idxs_local = [0]
|
||||
# GPU_1 image_idxs_local = [2, 1, 3]
|
||||
image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]:
|
||||
cum_gpu_sample_counts[tp_rank_local +
|
||||
1]]
|
||||
|
||||
# Get the pixel values for the local images based on the image_idxs_local
|
||||
if len(image_idxs_local) > 0:
|
||||
pixel_values_local = torch.cat([
|
||||
pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]]
|
||||
for i in image_idxs_local
|
||||
])
|
||||
else:
|
||||
# Handle case where this rank has no images
|
||||
pixel_values_local = torch.empty((0, pixel_values.shape[1]),
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype)
|
||||
# embed_dim_reduction_factor = 2 * 2
|
||||
if rope_type == "rope_2d":
|
||||
embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] *
|
||||
vision_model.merge_kernel_size[1])
|
||||
else:
|
||||
embed_dim_reduction_factor = (vision_model.spatial_merge_size *
|
||||
vision_model.spatial_merge_size)
|
||||
|
||||
# Find the max length across all ranks
|
||||
# The output embedding of every DP rank has to be
|
||||
# padded to this length for tensor_model_parallel_all_gather
|
||||
# to work
|
||||
max_len_per_rank = max(
|
||||
grouped_pixel_values_len) // embed_dim_reduction_factor
|
||||
local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]
|
||||
|
||||
# Run the vision model on the local pixel_values_local
|
||||
if rope_type == "rope_2d":
|
||||
if pixel_values_local.shape[0] > 0:
|
||||
image_embeds_local = vision_model(
|
||||
pixel_values_local, torch.tensor(local_grid_thw_list))
|
||||
if isinstance(image_embeds_local, list):
|
||||
image_embeds_local = torch.cat(image_embeds_local, dim=0)
|
||||
else:
|
||||
out_dim = getattr(vision_model.config, "hidden_size", None)
|
||||
image_embeds_local = torch.empty(
|
||||
(0, embed_dim_reduction_factor, out_dim),
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype)
|
||||
else:
|
||||
if pixel_values_local.shape[0] > 0:
|
||||
image_embeds_local = vision_model(pixel_values_local,
|
||||
local_grid_thw_list)
|
||||
else:
|
||||
# Handle empty case
|
||||
image_embeds_local = torch.empty((0, vision_model.out_hidden_size),
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype)
|
||||
|
||||
# Pad the output based on max_len_per_rank
|
||||
# for tensor_model_parallel_all_gather to work
|
||||
current_len = image_embeds_local.shape[0]
|
||||
if current_len < max_len_per_rank:
|
||||
padding_size = max_len_per_rank - current_len
|
||||
if rope_type == "rope_2d":
|
||||
padding = torch.empty((padding_size, image_embeds_local.shape[1],
|
||||
image_embeds_local.shape[2]),
|
||||
dtype=image_embeds_local.dtype,
|
||||
device=image_embeds_local.device)
|
||||
else:
|
||||
padding = torch.empty((padding_size, image_embeds_local.shape[1]),
|
||||
dtype=image_embeds_local.dtype,
|
||||
device=image_embeds_local.device)
|
||||
image_embeds_local_padded = torch.cat([image_embeds_local, padding],
|
||||
dim=0)
|
||||
else:
|
||||
image_embeds_local_padded = image_embeds_local
|
||||
|
||||
# Do all_gather to collect embeddings from all ranks
|
||||
gathered_embeds = tensor_model_parallel_all_gather(
|
||||
image_embeds_local_padded, dim=0)
|
||||
|
||||
# Remove padding and reconstruct per-rank embeddings
|
||||
rank_embeddings = list[torch.Tensor]()
|
||||
for rank in range(tp_size):
|
||||
start_idx = rank * max_len_per_rank
|
||||
end_idx = start_idx + (grouped_pixel_values_len[rank] //
|
||||
embed_dim_reduction_factor)
|
||||
rank_embeddings.append(gathered_embeds[start_idx:end_idx])
|
||||
|
||||
patches_per_output_image = [(patch_size // embed_dim_reduction_factor)
|
||||
for patch_size in patches_per_image]
|
||||
|
||||
# Reconstruct embeddings in the original order
|
||||
original_order_embeddings = [None] * len(grid_thw_list)
|
||||
current_idx = 0
|
||||
for rank in range(tp_size):
|
||||
count = gpu_sample_counts[rank]
|
||||
if count > 0:
|
||||
# Get images assigned to this rank in shuffled order
|
||||
# GPU_0 = image_idxs_local [0]
|
||||
# GPU_1 = image_idxs_local [2, 1, 3]
|
||||
rank_images = image_to_tp_rank[current_idx:current_idx + count]
|
||||
|
||||
rank_embed = rank_embeddings[rank]
|
||||
# Split rank embeddings back to individual images
|
||||
embed_start = 0
|
||||
for img_idx in rank_images:
|
||||
img_patches = patches_per_output_image[img_idx]
|
||||
original_order_embeddings[img_idx] = rank_embed[
|
||||
embed_start:embed_start + img_patches]
|
||||
embed_start += img_patches
|
||||
current_idx += count
|
||||
out_embeddings = tuple(embed for embed in original_order_embeddings
|
||||
if embed is not None)
|
||||
assert len(out_embeddings) == len(
|
||||
original_order_embeddings), "Found unassigned embeddings"
|
||||
return out_embeddings
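To make the padding bookkeeping concrete (editor's sketch, not part of the diff), here is how the docstring example flows through the function:
# patches_per_image = [1000, 100, 200, 50], tp_size = 2, spatial_merge_size = 2
# -> grouped_pixel_values_len = [1000, 350], embed_dim_reduction_factor = 4
# -> max_len_per_rank = 1000 // 4 = 250
# Rank 0 emits 250 merged embeddings (no padding); rank 1 emits 87 and pads to 250.
# After the all_gather, rows [0:250] and [250:337] are kept and split back into
# per-image embeddings of lengths [250, 25, 50, 12] in the original image order.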
|
||||
|
||||
@@ -3,13 +3,11 @@
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import itertools
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
from urllib.request import url2pathname
|
||||
|
||||
@@ -21,9 +19,6 @@ from typing_extensions import deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
|
||||
from .audio import AudioMediaIO
|
||||
from .base import MediaIO
|
||||
@@ -33,12 +28,10 @@ from .video import VideoMediaIO
|
||||
_M = TypeVar("_M")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .inputs import (BatchedTensorInputs, MultiModalKwargs,
|
||||
MultiModalKwargsItem, MultiModalKwargsItems,
|
||||
MultiModalPlaceholderDict)
|
||||
from .inputs import (BatchedTensorInputs, MultiModalKwargsItem,
|
||||
MultiModalKwargsItems, MultiModalPlaceholderDict)
|
||||
else:
|
||||
BatchedTensorInputs = Any
|
||||
MultiModalKwargs = Any
|
||||
MultiModalKwargsItem = Any
|
||||
MultiModalKwargsItems = Any
|
||||
MultiModalPlaceholderDict = Any
|
||||
@@ -93,7 +86,7 @@ class MediaConnector:
|
||||
self,
|
||||
url_spec: ParseResult,
|
||||
media_io: MediaIO[_M],
|
||||
) -> _M:
|
||||
) -> _M: # type: ignore[type-var]
|
||||
data_spec, data = url_spec.path.split(",", 1)
|
||||
media_type, data_type = data_spec.split(";", 1)
|
||||
|
||||
@@ -107,7 +100,7 @@ class MediaConnector:
|
||||
self,
|
||||
url_spec: ParseResult,
|
||||
media_io: MediaIO[_M],
|
||||
) -> _M:
|
||||
) -> _M: # type: ignore[type-var]
|
||||
allowed_local_media_path = self.allowed_local_media_path
|
||||
if allowed_local_media_path is None:
|
||||
raise RuntimeError("Cannot load local files without "
|
||||
@@ -127,7 +120,7 @@ class MediaConnector:
|
||||
media_io: MediaIO[_M],
|
||||
*,
|
||||
fetch_timeout: Optional[int] = None,
|
||||
) -> _M:
|
||||
) -> _M: # type: ignore[type-var]
|
||||
url_spec = urlparse(url)
|
||||
|
||||
if url_spec.scheme.startswith("http"):
|
||||
@@ -434,280 +427,6 @@ def group_mm_kwargs_by_modality(
|
||||
yield modality, len(items_lst), mm_kwargs_group
|
||||
|
||||
|
||||
def run_dp_sharded_vision_model(image_input: torch.Tensor,
|
||||
vision_model: torch.nn.Module) -> torch.Tensor:
|
||||
"""Run a vision model with data parallelism (DP) sharding. The function
|
||||
will shard the input image tensor on the first dimension and run the vision
|
||||
model
|
||||
|
||||
Args:
|
||||
image_input (torch.Tensor): Image input tensor.
|
||||
vision_model (torch.nn.Module): Vision model.
|
||||
Returns:
|
||||
torch.Tensor: Output image embeddings
|
||||
"""
|
||||
|
||||
num_chunks = image_input.shape[0]
|
||||
mp_world_size = get_tensor_model_parallel_world_size()
|
||||
num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
|
||||
num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks
|
||||
pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
|
||||
image_input_padded = torch.nn.functional.pad(image_input, pad)
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
image_input_per_rank = image_input_padded[rank *
|
||||
num_chunks_per_rank:(rank + 1) *
|
||||
num_chunks_per_rank, ...]
|
||||
|
||||
vision_embeddings = vision_model(image_input_per_rank)
|
||||
# Ensure tensor is contiguous before all_gather
|
||||
vision_embeddings = vision_embeddings.contiguous()
|
||||
vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
|
||||
dim=0)
|
||||
vision_embeddings = vision_embeddings[:num_chunks, ...]
|
||||
return vision_embeddings
|
||||
|
||||
|
||||


def get_load_balance_assignment(
    sizes: list[int],
    num_gpus: int = 2,
) -> tuple[list[int], list[int], list[int]]:
    """
    Generate load balancing assignment and metadata
    for distributing data across GPUs.
    The load is determined by the total image sizes,
    not the number of images.

    Args:
        sizes: The size of each image
        num_gpus: Number of GPUs to balance across

    Returns:
        shuffle_indices:
            Indices to reorder data for balanced loading
        gpu_sample_counts:
            Number of samples assigned to each GPU
        grouped_sizes_per_gpu:
            Total size assigned to each GPU

    Example:
        ```
        sizes = [1000, 100, 200, 50]
        num_gpus = 2
        ```
    """

    n_samples = len(sizes)

    # Handle edge cases
    if n_samples == 0:
        return [], [0] * num_gpus, [0] * num_gpus

    # Use greedy algorithm - balance by total size, not sample count
    gpu_assignments = [list[int]() for _ in range(num_gpus)]
    gpu_loads = [0] * num_gpus  # This tracks total SIZE, not sample count

    # Sort indices by size (largest first for better load balancing)
    # sizes = [1000, 100, 200, 50]
    # large_to_small_indices = [0, 2, 1, 3]
    large_to_small_indices = sorted(range(n_samples),
                                    key=lambda i: sizes[i],
                                    reverse=True)

    for idx in large_to_small_indices:
        # Find GPU with minimum current load (by total size)
        min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i])
        gpu_assignments[min_gpu].append(idx)
        gpu_loads[min_gpu] += sizes[idx]

    # Create shuffle indices and counts
    shuffle_indices = list[int]()
    gpu_sample_counts = list[int]()
    for gpu_id in range(num_gpus):
        # GPU_0 = [1000] = [0]
        # GPU_1 = [200, 100, 50] = [2, 1, 3]
        # shuffle_indices = [0, 2, 1, 3]
        shuffle_indices.extend(gpu_assignments[gpu_id])
        # GPU_0 = [1]
        # GPU_1 = [3]
        # gpu_sample_counts = [1, 3]
        gpu_sample_counts.append(len(gpu_assignments[gpu_id]))

    return (shuffle_indices, gpu_sample_counts, gpu_loads)
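
For reference, a hedged trace (not part of the diff) of the greedy assignment on the docstring's example sizes, using the function defined directly above:

shuffle_indices, gpu_sample_counts, gpu_loads = get_load_balance_assignment(
    [1000, 100, 200, 50], num_gpus=2)
# GPU 0 receives only the 1000-patch image; GPU 1 receives 200 + 100 + 50 = 350.
# shuffle_indices == [0, 2, 1, 3]
# gpu_sample_counts == [1, 3]
# gpu_loads == [1000, 350]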


def run_dp_sharded_mrope_vision_model(
    vision_model: torch.nn.Module,
    pixel_values: torch.Tensor,
    grid_thw_list: list[list[int]],
    *,
    rope_type: Literal["rope_3d", "rope_2d"],
) -> tuple[torch.Tensor, ...]:
    """Run a vision model with data parallelism (DP) sharding.
    The function will shard the input image tensor on the
    first dimension and run the vision model.
    This function is used to run the vision model with mrope.

    Args:
        vision_model (torch.nn.Module): Vision model.
        pixel_values (torch.Tensor): Image/Video input tensor.
        grid_thw_list: List of grid dimensions for each image
        rope_type: Type of rope used in the vision model.
            Different rope types use different input dimensions for the ViT.
            "rope_3d" for 3D rope (e.g., Qwen2.5-VL)
            "rope_2d" for 2D rope (e.g., Kimi-VL)

    Returns:
        tuple[torch.Tensor, ...]: Output image embeddings, one tensor per
            input image.

    Example:
        ```
        vision_model.out_hidden_size = 64
        vision_model.spatial_merge_size = 2
        pixel_values.shape = (1350, channel)
        grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]]
        tp_size = 2
        ```
    """
    tp_size = get_tensor_model_parallel_world_size()

    # GPU_0 tp_rank_local = 0
    # GPU_1 tp_rank_local = 1
    tp_rank_local = get_tensor_model_parallel_rank()

    # patches_per_image = [1000, 100, 200, 50]
    patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list]
    # cum_patches_per_image = [0, 1000, 1100, 1300, 1350]
    cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)]

    # Get load balancing assignment with all metadata
    # image_to_tp_rank = [0, 2, 1, 3]
    # gpu_sample_counts = [1, 3]
    # grouped_pixel_values_len = [1000, 350]
    (image_to_tp_rank, gpu_sample_counts,
     grouped_pixel_values_len) = get_load_balance_assignment(
         patches_per_image, tp_size)

    # cum_gpu_sample_counts = [0, 1, 4]
    cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)]

    # GPU_0 image_idxs_local = [0]
    # GPU_1 image_idxs_local = [2, 1, 3]
    image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]:
                                        cum_gpu_sample_counts[tp_rank_local +
                                                              1]]

    # Get the pixel values for the local images based on the image_idxs_local
    if len(image_idxs_local) > 0:
        pixel_values_local = torch.cat([
            pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]]
            for i in image_idxs_local
        ])
    else:
        # Handle case where this rank has no images
        pixel_values_local = torch.empty((0, pixel_values.shape[1]),
                                         device=pixel_values.device,
                                         dtype=pixel_values.dtype)

    # embed_dim_reduction_factor = 2 * 2
    if rope_type == "rope_2d":
        embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] *
                                      vision_model.merge_kernel_size[1])
    else:
        embed_dim_reduction_factor = (vision_model.spatial_merge_size *
                                      vision_model.spatial_merge_size)

    # Find the max length across all ranks
    # The output embedding of every DP rank has to be
    # padded to this length for tensor_model_parallel_all_gather
    # to work
    max_len_per_rank = max(
        grouped_pixel_values_len) // embed_dim_reduction_factor
    local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]

    # Run the vision model on the local pixel_values_local
    if rope_type == "rope_2d":
        if pixel_values_local.shape[0] > 0:
            image_embeds_local = vision_model(
                pixel_values_local, torch.tensor(local_grid_thw_list))
            if isinstance(image_embeds_local, list):
                image_embeds_local = torch.cat(image_embeds_local, dim=0)
        else:
            out_dim = getattr(vision_model.config, "hidden_size", None)
            image_embeds_local = torch.empty(
                (0, embed_dim_reduction_factor, out_dim),
                device=pixel_values.device,
                dtype=pixel_values.dtype)
    else:
        if pixel_values_local.shape[0] > 0:
            image_embeds_local = vision_model(pixel_values_local,
                                              local_grid_thw_list)
        else:
            # Handle empty case
            image_embeds_local = torch.empty((0, vision_model.out_hidden_size),
                                             device=pixel_values.device,
                                             dtype=pixel_values.dtype)

    # Pad the output based on max_len_per_rank
    # for tensor_model_parallel_all_gather to work
    current_len = image_embeds_local.shape[0]
    if current_len < max_len_per_rank:
        padding_size = max_len_per_rank - current_len
        if rope_type == "rope_2d":
            padding = torch.empty((padding_size, image_embeds_local.shape[1],
                                   image_embeds_local.shape[2]),
                                  dtype=image_embeds_local.dtype,
                                  device=image_embeds_local.device)
        else:
            padding = torch.empty((padding_size, image_embeds_local.shape[1]),
                                  dtype=image_embeds_local.dtype,
                                  device=image_embeds_local.device)
        image_embeds_local_padded = torch.cat([image_embeds_local, padding],
                                              dim=0)
    else:
        image_embeds_local_padded = image_embeds_local

    # Do all_gather to collect embeddings from all ranks
    gathered_embeds = tensor_model_parallel_all_gather(
        image_embeds_local_padded, dim=0)

    # Remove padding and reconstruct per-rank embeddings
    rank_embeddings = list[torch.Tensor]()
    for rank in range(tp_size):
        start_idx = rank * max_len_per_rank
        end_idx = start_idx + (grouped_pixel_values_len[rank] //
                               embed_dim_reduction_factor)
        rank_embeddings.append(gathered_embeds[start_idx:end_idx])

    patches_per_output_image = [(patch_size // embed_dim_reduction_factor)
                                for patch_size in patches_per_image]

    # Reconstruct embeddings in the original order
    original_order_embeddings = [None] * len(grid_thw_list)
    current_idx = 0
    for rank in range(tp_size):
        count = gpu_sample_counts[rank]
        if count > 0:
            # Get images assigned to this rank in shuffled order
            # GPU_0 = image_idxs_local [0]
            # GPU_1 = image_idxs_local [2, 1, 3]
            rank_images = image_to_tp_rank[current_idx:current_idx + count]

            rank_embed = rank_embeddings[rank]
            # Split rank embeddings back to individual images
            embed_start = 0
            for img_idx in rank_images:
                img_patches = patches_per_output_image[img_idx]
                original_order_embeddings[img_idx] = rank_embed[
                    embed_start:embed_start + img_patches]
                embed_start += img_patches
            current_idx += count
    out_embeddings = tuple(embed for embed in original_order_embeddings
                           if embed is not None)
    assert len(out_embeddings) == len(
        original_order_embeddings), "Found unassigned embeddings"
    return out_embeddings
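
A hedged sketch (not part of the diff) tracing the index bookkeeping above on the docstring's example, assuming tp_size = 2 and spatial_merge_size = 2:

import itertools
import math

grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]]
patches_per_image = [math.prod(g) for g in grid_thw_list]  # [1000, 100, 200, 50]
cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)]
# cum_patches_per_image == [0, 1000, 1100, 1300, 1350]

# Values produced by get_load_balance_assignment(patches_per_image, 2):
image_to_tp_rank = [0, 2, 1, 3]
gpu_sample_counts = [1, 3]
grouped_pixel_values_len = [1000, 350]

embed_dim_reduction_factor = 2 * 2  # spatial_merge_size ** 2
max_len_per_rank = max(grouped_pixel_values_len) // embed_dim_reduction_factor  # 250
# Rank 0 emits 1000 // 4 = 250 output rows (no padding needed); rank 1 emits
# 50 + 25 + 12 = 87 rows and is padded to 250 before the all_gather, after
# which each image's embedding is sliced back out in the original order.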


def fetch_audio(
    audio_url: str,
    audio_io_kwargs: Optional[dict[str, Any]] = None,