From f168b85725202915b5719c62b46d310a608b13dd Mon Sep 17 00:00:00 2001 From: Siqi Yan Date: Fri, 6 Jun 2025 01:24:02 -0700 Subject: [PATCH] Unit Test for run_dp_sharded_vision_model (#19103) Signed-off-by: Siqi Yan Co-authored-by: Siqi Yan --- tests/multimodal/test_utils.py | 98 +++++++++++++++++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index c8a54482214d4..5ac0a90f50473 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -9,12 +9,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional import numpy as np import pytest +import torch +import torch.multiprocessing as mp from PIL import Image, ImageChops +from tests.utils import multi_gpu_test +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.utils import (MediaConnector, - merge_and_sort_multimodal_metadata) + merge_and_sort_multimodal_metadata, + run_dp_sharded_vision_model) +from vllm.platforms import current_platform +from vllm.utils import get_open_port, update_environment_variables if TYPE_CHECKING: from vllm.multimodal.hasher import MultiModalHashDict @@ -413,3 +422,90 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving(): assert modalities == expected_modalities assert ranges == expected_ranges assert hashes == expected_hashes + + +class SimpleLinearModel(torch.nn.Module): + """A simple linear vision model for testing.""" + + def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32): + super().__init__() + self.flatten = torch.nn.Flatten() + self.linear = torch.nn.Linear(input_dim, output_dim) + + def forward(self, x: torch.Tensor): + # Flatten the input and apply linear transformation + x = self.flatten(x) + return self.linear(x) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "batch_size", + [ + 1, # Single image + 4, # Small batch + 5, # Odd batch size (for testing padding) + ], +) +def test_run_dp_sharded_vision_model(batch_size: int): + world_size = 2 + # Launch processes + mp.spawn( + run_dp_sharded_vision_model_vs_direct, + args=( + world_size, + batch_size, + get_open_port(), + ), + nprocs=world_size, + ) + + +def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, + batch_size: int, master_port: int): + """ + Test that run_dp_sharded_vision_model produces the same results as + calling the model directly. + """ + + # Set random seed for reproducibility + current_platform.seed_everything(0) + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create a test input tensor + image_input = torch.randn(batch_size, 3, 224, 224) + + # Create a simple linear model + vision_model = SimpleLinearModel() + + # Run the model directly on the full input + with torch.inference_mode(): + direct_output = vision_model(image_input) + + # Run the model through the sharded function + with torch.inference_mode(): + sharded_output = run_dp_sharded_vision_model(image_input, vision_model) + + # Check that the world size is setup correctly + assert get_tensor_model_parallel_world_size() == world_size + + # Check that the outputs have the same shape + assert direct_output.shape == sharded_output.shape + + # Check that the outputs are close (they should be identical) + assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)