Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[ROCm][CI][Bugfix] Fixing the Multi-Modal Models Test (Extended) 1 group (#30013)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Parent: 6366c098d7
Commit: e96a6a6dca
@@ -987,7 +987,8 @@ steps:
   commands:
     - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
 
-- label: Multi-Modal Models Test (Extended) 1
+- label: Multi-Modal Models Test (Extended) 1 # 60min
+  timeout_in_minutes: 120
   mirror_hardwares: [amdexperimental]
   agent_pool: mi325_1
   # grade: Blocking
@@ -1011,7 +1012,8 @@ steps:
     - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
    - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
 
-- label: Multi-Modal Models Test (Extended) 3
+- label: Multi-Modal Models Test (Extended) 3 # 75min
+  timeout_in_minutes: 150
   mirror_hardwares: [amdexperimental]
   agent_pool: mi325_1
   # grade: Blocking
@@ -2,6 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Pytest configuration for vLLM tests."""
 
+import warnings
+
 import torch
 
 from vllm.platforms import current_platform
@@ -14,6 +16,20 @@ def pytest_configure(config):
     if not current_platform.is_rocm():
         return
 
+    skip_patterns = ["test_granite_speech.py"]
+    if any(pattern in str(arg) for arg in config.args for pattern in skip_patterns):
+        # Skip disabling SDP for Granite Speech tests on ROCm
+        return
+
+    # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
+    # accuracy issues
+    # TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
+    torch.backends.cuda.enable_flash_sdp(False)
+    torch.backends.cuda.enable_mem_efficient_sdp(False)
+    torch.backends.cuda.enable_math_sdp(True)
+    warnings.warn(
+        "ROCm: Disabled flash_sdp and mem_efficient_sdp, enabled math_sdp "
+        "to avoid HuggingFace Transformers accuracy issues",
+        UserWarning,
+        stacklevel=1,
+    )
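
For context, the three torch.backends.cuda toggles above control which kernels torch.nn.functional.scaled_dot_product_attention may dispatch to. A minimal standalone sketch of the effect (not from this commit; it assumes a GPU build of PyTorch and the tensor shapes are illustrative only):

import torch
import torch.nn.functional as F

# Mirror the conftest toggles: leave only the reference (math) SDP kernel enabled.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# With the flash and mem-efficient kernels disabled, SDPA falls back to the
# math implementation, which is what the ROCm accuracy workaround relies on.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])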
@@ -403,12 +403,13 @@ VLM_TEST_SETTINGS = {
         # So, we need to reduce the number of tokens for the test to pass.
         max_tokens=8,
         num_logprobs=10,
         auto_cls=AutoModelForCausalLM,
         marks=[large_gpu_mark(min_gb=32)],
     ),
     "glm4_1v": VLMTestInfo(
         models=["zai-org/GLM-4.1V-9B-Thinking"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
-        prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>",
+        prompt_formatter=lambda img_prompt: f"[gMASK]<|user|>\n{img_prompt}<|assistant|>\n",  # noqa: E501
         img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>",
         video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>",
         max_model_len=2048,
@@ -423,6 +424,7 @@ VLM_TEST_SETTINGS = {
         models=["zai-org/GLM-4.1V-9B-Thinking"],
         # GLM4.1V require include video metadata for input
         test_type=VLMTestType.CUSTOM_INPUTS,
+        prompt_formatter=lambda vid_prompt: f"[gMASK]<|user|>\n{vid_prompt}<|assistant|>\n",  # noqa: E501
         max_model_len=4096,
         max_num_seqs=2,
         auto_cls=AutoModelForImageTextToText,
@@ -737,7 +739,13 @@ VLM_TEST_SETTINGS = {
         max_model_len=8192,
         max_num_seqs=2,
         auto_cls=AutoModelForImageTextToText,
-        marks=[large_gpu_mark(min_gb=48)],
+        marks=[
+            large_gpu_mark(min_gb=48),
+            pytest.mark.skipif(
+                current_platform.is_rocm(),
+                reason="Model produces a vector of <UNK> output in HF on ROCm",
+            ),
+        ],
     ),
     "qwen_vl": VLMTestInfo(
         models=["Qwen/Qwen-VL"],
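
As a quick illustration (not from the commit) of what the updated GLM-4.1V prompt_formatter produces, here is a small sketch; the question text is made up, and the image placeholder string comes from img_idx_to_prompt in the same test entry:

def prompt_formatter(img_prompt: str) -> str:
    # Matches the new formatter: [gMASK] prefix and a trailing newline.
    return f"[gMASK]<|user|>\n{img_prompt}<|assistant|>\n"

img_prompt = "<|begin_of_image|><|image|><|end_of_image|>What is shown in this image?"
print(prompt_formatter(img_prompt))
# [gMASK]<|user|>
# <|begin_of_image|><|image|><|end_of_image|>What is shown in this image?<|assistant|>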
@@ -8,6 +8,7 @@ from transformers import AutoModelForSpeechSeq2Seq
 
 from vllm.logprobs import SampleLogprobs
 from vllm.lora.request import LoRARequest
+from vllm.platforms import current_platform
 
 from ....conftest import AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner
 from ...registry import HF_EXAMPLE_MODELS
@@ -34,6 +35,12 @@ audio_lora_path = MODEL_NAME
 models = [MODEL_NAME]
 
 
+@pytest.fixture(autouse=True)
+def set_attention_backend_for_rocm(monkeypatch):
+    if current_platform.is_rocm():
+        monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
+
+
 def run_test(
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
@@ -111,8 +118,12 @@ def run_test(
 
 
 @pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_model_len", [2048])
+@pytest.mark.parametrize(
+    "dtype", ["float16"] if current_platform.is_rocm() else ["bfloat16"]
+)
+@pytest.mark.parametrize(
+    "max_model_len", [512] if current_platform.is_rocm() else [2048]
+)
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
 def test_models(
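
The autouse fixture above applies to every test collected from this module, and monkeypatch restores the environment automatically at teardown. A self-contained sketch of the same pattern (not from the commit, with a hypothetical test name):

import os

import pytest


@pytest.fixture(autouse=True)
def force_backend(monkeypatch):
    # Runs for each test in this module; undone automatically after the test.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")


def test_backend_env_is_set():
    assert os.environ["VLLM_ATTENTION_BACKEND"] == "TRITON_ATTN"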
@@ -15,6 +15,7 @@ from transformers import AutoProcessor
 from vllm import SamplingParams, TextPrompt, TokensPrompt
 from vllm.logprobs import Logprob, SampleLogprobs
 from vllm.multimodal import MultiModalDataBuiltins
+from vllm.platforms import current_platform
 
 from ....utils import VLLM_PATH, large_gpu_test
 from ...utils import check_logprobs_close
@@ -165,6 +166,15 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
 def test_chat(
     vllm_runner, max_model_len: int, model: str, dtype: str, local_asset_server
 ) -> None:
+    if (
+        model == MISTRAL_SMALL_3_1_ID
+        and max_model_len == 65536
+        and current_platform.is_rocm()
+    ):
+        pytest.skip(
+            "OOM on ROCm: 24B model with 65536 context length exceeds GPU memory"
+        )
+
     EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT[model])
     with vllm_runner(
         model,
@@ -140,7 +140,7 @@ def video_with_metadata_glm4_1v():
     metadata = VIDEO_ASSETS[0].metadata
     question = "Describe the video."
     video_prompt = "<|begin_of_video|><|video|><|end_of_video|>"
-    formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n"
+    formatted_prompt = f"[gMASK]<|user|>\n{video_prompt}{question}<|assistant|>\n"
 
     scales = [0.1, 0.2, 0.25]
     video_input = [
@@ -25,6 +25,7 @@ from transformers import (
 from transformers.video_utils import VideoMetadata
 
 from vllm.logprobs import SampleLogprobs
+from vllm.platforms import current_platform
 from vllm.utils.collection_utils import is_list_of
 
 from .....conftest import HfRunner, ImageAsset, ImageTestAssets
@@ -366,6 +367,40 @@ def gemma3_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOut
 
 def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for GLM4V."""
+    if current_platform.is_rocm():
+        import types
+
+        config = hf_model.model.config
+        if hasattr(config, "num_layers") and not hasattr(config, "num_hidden_layers"):
+            config.num_hidden_layers = config.num_layers
+        config.output_hidden_states = True
+
+        def patched_prepare_cache(
+            self, generation_config, model_kwargs, *args, **kwargs
+        ):
+            model_kwargs["past_key_values"] = None
+            model_kwargs["use_cache"] = False
+            return model_kwargs
+
+        hf_model.model._prepare_cache_for_generation = types.MethodType(
+            patched_prepare_cache, hf_model.model
+        )
+        original_generate = hf_model.model.generate
+
+        def patched_generate(*args, **kwargs):
+            kwargs["output_hidden_states"] = True
+            kwargs["return_dict_in_generate"] = True
+            return original_generate(*args, **kwargs)
+
+        hf_model.model.generate = patched_generate
+        original_forward = hf_model.model.forward
+
+        def patched_forward(*args, **kwargs):
+            kwargs["output_hidden_states"] = True
+            return original_forward(*args, **kwargs)
+
+        hf_model.model.forward = patched_forward
+
     hf_processor = hf_model.processor
 
     def processor(*args, text="", images=None, **kwargs):
@@ -406,7 +441,15 @@ def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
         if videos is not None and is_list_of(videos, tuple):
             # If videos is a list of tuples, we assume each tuple contains
             # (video_array, metadata) as in the case of GLM4.1V.
-            video_metadata = [[VideoMetadata(**video[1])] for video in videos]
+            # Filter out 'do_sample_frames' as it's not a valid VideoMetadata arg
+            video_metadata = [
+                [
+                    VideoMetadata(
+                        **{k: v for k, v in video[1].items() if k != "do_sample_frames"}
+                    )
+                ]
+                for video in videos
+            ]
             videos = [[video[0]] for video in videos]
         else:
            video_metadata = None
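
The ROCm branch above leans on two standard patching idioms: rebinding an instance method with types.MethodType and wrapping a callable so extra keyword arguments are always injected. A minimal sketch with a stand-in class (not from the commit, purely illustrative):

import types


class FakeModel:
    def generate(self, **kwargs):
        return kwargs


model = FakeModel()


def patched_prepare_cache(self, generation_config, model_kwargs, *args, **kwargs):
    # Bound to the instance below, so `self` is the model object.
    model_kwargs["use_cache"] = False
    return model_kwargs


model._prepare_cache_for_generation = types.MethodType(patched_prepare_cache, model)

original_generate = model.generate


def patched_generate(*args, **kwargs):
    # Force the extra outputs regardless of what the caller passed.
    kwargs["output_hidden_states"] = True
    kwargs["return_dict_in_generate"] = True
    return original_generate(*args, **kwargs)


model.generate = patched_generate

print(model._prepare_cache_for_generation(None, {}))  # {'use_cache': False}
print(model.generate(max_new_tokens=8))               # includes the injected kwargs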
tests/models/multimodal/pooling/conftest.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Pytest configuration for vLLM pooling tests."""
+
+import os
+import warnings
+
+from vllm.platforms import current_platform
+
+
+def pytest_collection_modifyitems(config, items):
+    """Set FLEX_ATTENTION backend for SigLIP tests on ROCm."""
+    if not current_platform.is_rocm():
+        return
+
+    siglip_tests = [item for item in items if "test_siglip" in item.nodeid]
+
+    if siglip_tests:
+        os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
+        warnings.warn(
+            "ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests",
+            UserWarning,
+            stacklevel=1,
+        )
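
Since pytest_collection_modifyitems runs once after collection, the environment variable is set for the whole test process before any SigLIP test executes. The filter matches on the item nodeid, which contains the test file name; a small sketch of the match (the example id below is hypothetical):

# Shape of a pytest nodeid: "<path>::<test function>[<params>]".
example_nodeid = (
    "tests/models/multimodal/pooling/test_siglip.py::test_models[some-model-id]"
)
assert "test_siglip" in example_nodeid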
@@ -667,6 +667,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
         "moonshotai/Kimi-VL-A3B-Instruct",
         extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},
         trust_remote_code=True,
+        max_transformers_version="4.53.3",
+        transformers_version_reason="HF model uses deprecated transformers API "
+        "(PytorchGELUTanh, DynamicCache.seen_tokens, and more). See: "
+        "https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/discussions/31",
     ),
     "LightOnOCRForConditionalGeneration": _HfExamplesInfo(
         "lightonai/LightOnOCR-1B",
@@ -31,6 +31,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
+from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 from vllm.v1.attention.backends.utils import (
@@ -927,7 +928,18 @@ def get_kernel_options(
 
     if torch.cuda.is_available():
         device_props = torch.cuda.get_device_properties()
-        max_shared_memory = device_props.shared_memory_per_block_optin
+        # ROCm doesn't expose shared_memory_per_block_optin attribute
+        # AMD GPUs typically have 64KB LDS (Local Data Share) per workgroup
+        if hasattr(device_props, "shared_memory_per_block_optin"):
+            max_shared_memory = device_props.shared_memory_per_block_optin
+        elif current_platform.is_rocm():
+            # ROCm fallback: use 64KB
+            max_shared_memory = 65536
+        else:
+            raise RuntimeError(
+                "Unable to determine shared memory size on this hardware."
+            )
 
         if max_shared_memory < 144 * 1024:
             block_m_candidate = ensure_divisible(
                 max(1, block_m_candidate // 2), block_m
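
A standalone probe (not from the commit) that performs the same attribute check and 64 KiB fallback as the change above; the printed value depends on the GPU and PyTorch build:

import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    if hasattr(props, "shared_memory_per_block_optin"):
        # CUDA exposes the opt-in per-block shared memory limit directly.
        max_shared_memory = props.shared_memory_per_block_optin
    elif torch.version.hip is not None:
        # ROCm builds lack the attribute; assume 64 KiB of LDS per workgroup.
        max_shared_memory = 64 * 1024
    else:
        raise RuntimeError("Unable to determine shared memory size on this hardware.")
    print(f"usable shared memory per block: {max_shared_memory} bytes")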