[CI/Build] Add TP test for vision models (#5892)

Cyrus Leung 2024-06-29 23:45:54 +08:00 committed by GitHub
parent 8dbfcd35bf
commit 99397da534
9 changed files with 131 additions and 27 deletions

View File

@@ -44,6 +44,7 @@ steps:
working_dir: "/vllm-workspace/tests"
num_gpus: 2
commands:
- bash ../.buildkite/download-images.sh
# FIXIT: find out which code initializes CUDA before running the test;
# until that is fixed, we need to use spawn to run these tests
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
@@ -52,10 +53,14 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
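
The spawn workaround called out in the FIXIT comment above also applies when running these tests by hand. Below is a minimal Python sketch of an equivalent local invocation; it assumes it is launched from the tests directory on a machine with two visible GPUs and with the image assets already downloaded (CI does this via download-images.sh), and it only mirrors the Buildkite commands shown above.

```python
# Hedged sketch of a local run; the environment variables are the ones used in
# the CI commands above and must be set before pytest collects the test module,
# since TEST_DIST_MODEL is read at import time.
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"   # avoid CUDA re-init in forked workers
os.environ["TEST_DIST_MODEL"] = "llava-hf/llava-1.5-7b-hf"
os.environ["DISTRIBUTED_EXECUTOR_BACKEND"] = "mp"      # or "ray"

import pytest

raise SystemExit(pytest.main(["-v", "-s", "distributed/test_multimodal_broadcast.py"]))
```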

View File

@@ -0,0 +1,51 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
The second test will hang if more than one test is run per command, so we need
to run the tests one by one. The solution is to pass arguments (model name) by
environment variables.
Run:
```sh
TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf \
    pytest -v -s distributed/test_multimodal_broadcast.py
TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct \
    pytest -v -s distributed/test_multimodal_broadcast.py
```
"""
import os
import pytest
from vllm.utils import cuda_device_count_stateless
model = os.environ["TEST_DIST_MODEL"]
if model.startswith("llava-hf/llava"):
from ..models.test_llava import model_and_vl_config, run_test
elif model.startswith("microsoft/Phi-3-vision"):
from ..models.test_phi3v import model_and_vl_config, run_test
else:
raise NotImplementedError(f"Unsupported model: {model}")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, image_assets,
tensor_parallel_size: int, dtype: str,
max_tokens: int) -> None:
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(
f"Need at least {tensor_parallel_size} GPUs to run the test.")
distributed_executor_backend = os.getenv("DISTRIBUTED_EXECUTOR_BACKEND")
run_test(
hf_runner,
vllm_runner,
image_assets,
model_and_config=model_and_vl_config[0],
dtype=dtype,
max_tokens=max_tokens,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
)

View File

@@ -1,11 +1,11 @@
from typing import List, Tuple
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoTokenizer
from vllm.config import VisionLanguageConfig
from ..conftest import IMAGE_ASSETS
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
pytestmark = pytest.mark.vlm
@@ -65,12 +65,17 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
return hf_output_ids, hf_output_str
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model_and_config: Tuple[str, VisionLanguageConfig],
*,
dtype: str,
max_tokens: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test are under tests/images.
@@ -96,6 +101,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
with vllm_runner(model_id,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
**vlm_config.as_cli_args_dict()) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
@@ -110,3 +117,19 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model_and_config,
dtype=dtype,
max_tokens=max_tokens,
tensor_parallel_size=1,
)

View File

@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoTokenizer
@@ -6,7 +6,7 @@ from transformers import AutoTokenizer
from vllm.config import VisionLanguageConfig
from vllm.utils import is_cpu
from ..conftest import IMAGE_ASSETS
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
pytestmark = pytest.mark.vlm
@@ -73,17 +73,17 @@ if is_cpu():
target_dtype = "bfloat16"
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
# Since we use _attn_implementation="eager" for hf_runner, there is a
# numeric difference for longer contexts and the test can't pass
@pytest.mark.xfail(
reason="Inconsistent image processor being used due to lack "
"of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model_and_config: Tuple[str, VisionLanguageConfig],
*,
dtype: str,
max_tokens: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test are under tests/images.
@@ -116,7 +116,9 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
with vllm_runner(model_id,
max_model_len=2048,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=True,
distributed_executor_backend=distributed_executor_backend,
**vlm_config.as_cli_args_dict()) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
@@ -130,3 +132,24 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
# Since we use _attn_implementation="eager" for hf_runner, there is a
# numeric difference for longer contexts and the test can't pass
@pytest.mark.xfail(
reason="Inconsistent image processor being used due to lack "
"of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model_and_config,
dtype=dtype,
max_tokens=max_tokens,
tensor_parallel_size=1,
)

View File

@@ -268,6 +268,7 @@ class ShmRingBufferIO:
else:
return self.dequeue()
@staticmethod
def create_from_process_group(pg: ProcessGroup,
max_chunk_bytes,
max_chunks,
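
With `create_from_process_group` now a `@staticmethod`, the factory is called directly on the class, as the `GroupCoordinator` call site in the next file does. A minimal sketch of such a call follows; the import path and the pre-existing CPU process group are assumptions for illustration, not part of this diff.

```python
from torch.distributed import ProcessGroup

# Assumed import path for ShmRingBufferIO at this point in the code base.
from vllm.distributed.device_communicators.shm_broadcast import ShmRingBufferIO


def make_broadcaster(cpu_group: ProcessGroup) -> ShmRingBufferIO:
    # Same arguments as the GroupCoordinator call site below: 4 MiB chunks, 6 chunks.
    return ShmRingBufferIO.create_from_process_group(cpu_group, 1 << 22, 6)
```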

View File

@@ -194,7 +194,7 @@ class GroupCoordinator:
self.shm_broadcaster: Optional[ShmRingBufferIO] = None
if self.world_size > 1 and is_in_the_same_node(self.cpu_group):
self.shm_broadcaster = ShmRingBufferIO.create_from_process_group(
self.cpu_group, 1 << 20, 6)
self.cpu_group, 1 << 22, 6)
@property
def first_rank(self):
@@ -690,6 +690,8 @@ class GroupCoordinator:
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.shm_broadcaster is not None:
self.shm_broadcaster = None
_WORLD: Optional[GroupCoordinator] = None
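
For context on the constants above: each chunk of the shared-memory ring buffer grows from 1 MiB to 4 MiB, presumably to leave headroom for the larger multimodal inputs broadcast between tensor-parallel workers. A quick sizing check, under the assumption that total capacity is `max_chunk_bytes * max_chunks`:

```python
# Rough sizing of the shm broadcaster ring buffer (assumption: total capacity
# is max_chunk_bytes * max_chunks, as suggested by the constructor arguments).
max_chunks = 6
for max_chunk_bytes in (1 << 20, 1 << 22):
    total = max_chunk_bytes * max_chunks
    print(f"chunk: {max_chunk_bytes / 2**20:g} MiB -> total: {total / 2**20:g} MiB")
# chunk: 1 MiB -> total: 6 MiB    (before this change)
# chunk: 4 MiB -> total: 24 MiB   (after this change)
```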

View File

@@ -219,7 +219,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
image_features = vision_tower(pixel_values,
self.config.vision_feature_layer)
return self._select_image_features(

View File

@@ -301,7 +301,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
image_features = vision_tower(pixel_values,
self.config.vision_feature_layer)
return self._select_image_features(

View File

@@ -157,7 +157,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
select = False
target_device = self.img_projection[0].bias.device
target_dtype = self.img_projection[0].bias.dtype
if len(positions.tolist()) > 0:
@@ -231,7 +230,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
img_set_tensor = []
for _output_img in output_imgs:
img_feature_proj = self.img_projection(
_output_img.to(target_device, target_dtype))
_output_img.to(target_dtype))
img_set_tensor.append(img_feature_proj)
select = True
@@ -245,7 +244,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
hidden_states[positions[idx, 0],
positions[idx, 1]:positions[idx, 1] +
cnt] = (img_set_tensor[i].to(
hidden_states.device, hidden_states.dtype))
hidden_states.dtype))
idx += cnt
return hidden_states.squeeze(0)
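
The phi3v edits above, like the LLaVA ones, drop explicit device arguments and keep only dtype casts. As a small, self-contained illustration of the PyTorch behavior this relies on, `Tensor.to(dtype)` never moves a tensor to a different device:

```python
import torch

x = torch.randn(2, 3)          # stays on whatever device created it
y = x.to(torch.float16)        # dtype-only .to(): device unchanged, dtype converted
assert y.device == x.device
assert y.dtype == torch.float16
```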