mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-31 11:03:04 +08:00
[CI/Build] Add TP test for vision models (#5892)
This commit is contained in:
parent
8dbfcd35bf
commit
99397da534
@ -44,6 +44,7 @@ steps:
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
commands:
|
||||
- bash ../.buildkite/download-images.sh
|
||||
# FIXIT: find out which code initialize cuda before running the test
|
||||
# before the fix, we need to use spawn to test it
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
@ -52,10 +53,14 @@ steps:
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
|
||||
|
||||
51
tests/distributed/test_multimodal_broadcast.py
Normal file
51
tests/distributed/test_multimodal_broadcast.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
|
||||
The second test will hang if more than one test is run per command, so we need
|
||||
to run the tests one by one. The solution is to pass arguments (model name) by
|
||||
environment variables.
|
||||
|
||||
Run:
|
||||
```sh
|
||||
TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf \
|
||||
test_multimodal_broadcast.py
|
||||
TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct \
|
||||
test_multimodal_broadcast.py
|
||||
```
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
|
||||
model = os.environ["TEST_DIST_MODEL"]
|
||||
|
||||
if model.startswith("llava-hf/llava"):
|
||||
from ..models.test_llava import model_and_vl_config, run_test
|
||||
elif model.startswith("microsoft/Phi-3-vision"):
|
||||
from ..models.test_phi3v import model_and_vl_config, run_test
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported model: {model}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [2])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(hf_runner, vllm_runner, image_assets,
|
||||
tensor_parallel_size: int, dtype: str,
|
||||
max_tokens: int) -> None:
|
||||
if cuda_device_count_stateless() < tensor_parallel_size:
|
||||
pytest.skip(
|
||||
f"Need at least {tensor_parallel_size} GPUs to run the test.")
|
||||
|
||||
distributed_executor_backend = os.getenv("DISTRIBUTED_EXECUTOR_BACKEND")
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model_and_config=model_and_vl_config[0],
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
)
|
||||
@ -1,11 +1,11 @@
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import VisionLanguageConfig
|
||||
|
||||
from ..conftest import IMAGE_ASSETS
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
@ -65,12 +65,17 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
|
||||
return hf_output_ids, hf_output_str
|
||||
|
||||
|
||||
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
dtype: str, max_tokens: int) -> None:
|
||||
def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
model_and_config: Tuple[str, VisionLanguageConfig],
|
||||
*,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
@ -96,6 +101,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
|
||||
with vllm_runner(model_id,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
**vlm_config.as_cli_args_dict()) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
|
||||
@ -110,3 +117,19 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
dtype: str, max_tokens: int) -> None:
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model_and_config,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
@ -6,7 +6,7 @@ from transformers import AutoTokenizer
|
||||
from vllm.config import VisionLanguageConfig
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
from ..conftest import IMAGE_ASSETS
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
@ -73,17 +73,17 @@ if is_cpu():
|
||||
target_dtype = "bfloat16"
|
||||
|
||||
|
||||
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
|
||||
# Since we use _attn_implementation="eager" for hf_runner, here is
|
||||
# numeric difference for longer context and test can't pass
|
||||
@pytest.mark.xfail(
|
||||
reason="Inconsistent image processor being used due to lack "
|
||||
"of support for dynamic image token replacement")
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
dtype: str, max_tokens: int) -> None:
|
||||
def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
model_and_config: Tuple[str, VisionLanguageConfig],
|
||||
*,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
@ -116,7 +116,9 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
with vllm_runner(model_id,
|
||||
max_model_len=2048,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=True,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
**vlm_config.as_cli_args_dict()) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
|
||||
max_tokens,
|
||||
@ -130,3 +132,24 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
|
||||
|
||||
# Since we use _attn_implementation="eager" for hf_runner, here is
|
||||
# numeric difference for longer context and test can't pass
|
||||
@pytest.mark.xfail(
|
||||
reason="Inconsistent image processor being used due to lack "
|
||||
"of support for dynamic image token replacement")
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
dtype: str, max_tokens: int) -> None:
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model_and_config,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
@ -268,6 +268,7 @@ class ShmRingBufferIO:
|
||||
else:
|
||||
return self.dequeue()
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group(pg: ProcessGroup,
|
||||
max_chunk_bytes,
|
||||
max_chunks,
|
||||
|
||||
@ -194,7 +194,7 @@ class GroupCoordinator:
|
||||
self.shm_broadcaster: Optional[ShmRingBufferIO] = None
|
||||
if self.world_size > 1 and is_in_the_same_node(self.cpu_group):
|
||||
self.shm_broadcaster = ShmRingBufferIO.create_from_process_group(
|
||||
self.cpu_group, 1 << 20, 6)
|
||||
self.cpu_group, 1 << 22, 6)
|
||||
|
||||
@property
|
||||
def first_rank(self):
|
||||
@ -690,6 +690,8 @@ class GroupCoordinator:
|
||||
self.pynccl_comm = None
|
||||
if self.ca_comm is not None:
|
||||
self.ca_comm = None
|
||||
if self.shm_broadcaster is not None:
|
||||
self.shm_broadcaster = None
|
||||
|
||||
|
||||
_WORLD: Optional[GroupCoordinator] = None
|
||||
|
||||
@ -219,7 +219,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
image_features = vision_tower(pixel_values.to(vision_tower.device),
|
||||
image_features = vision_tower(pixel_values,
|
||||
self.config.vision_feature_layer)
|
||||
|
||||
return self._select_image_features(
|
||||
|
||||
@ -301,7 +301,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
image_features = vision_tower(pixel_values.to(vision_tower.device),
|
||||
image_features = vision_tower(pixel_values,
|
||||
self.config.vision_feature_layer)
|
||||
|
||||
return self._select_image_features(
|
||||
|
||||
@ -157,7 +157,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
|
||||
select = False
|
||||
|
||||
target_device = self.img_projection[0].bias.device
|
||||
target_dtype = self.img_projection[0].bias.dtype
|
||||
|
||||
if len(positions.tolist()) > 0:
|
||||
@ -231,7 +230,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
img_set_tensor = []
|
||||
for _output_img in output_imgs:
|
||||
img_feature_proj = self.img_projection(
|
||||
_output_img.to(target_device, target_dtype))
|
||||
_output_img.to(target_dtype))
|
||||
img_set_tensor.append(img_feature_proj)
|
||||
select = True
|
||||
|
||||
@ -245,7 +244,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
hidden_states[positions[idx, 0],
|
||||
positions[idx, 1]:positions[idx, 1] +
|
||||
cnt] = (img_set_tensor[i].to(
|
||||
hidden_states.device, hidden_states.dtype))
|
||||
hidden_states.dtype))
|
||||
idx += cnt
|
||||
|
||||
return hidden_states.squeeze(0)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user