From e96a6a6dca930d00902852ea6937a214a584b89b Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Thu, 4 Dec 2025 05:00:16 -0600 Subject: [PATCH] [ROCm][CI][Bugfix] Fixing the `Multi-Modal Models Test (Extended) 1` group (#30013) Signed-off-by: Andreas Karatzas --- .buildkite/test-amd.yaml | 6 ++- .../models/multimodal/generation/conftest.py | 16 +++++++ .../multimodal/generation/test_common.py | 12 ++++- .../generation/test_granite_speech.py | 15 ++++++- .../multimodal/generation/test_pixtral.py | 10 +++++ .../generation/vlm_utils/custom_inputs.py | 2 +- .../generation/vlm_utils/model_utils.py | 45 ++++++++++++++++++- tests/models/multimodal/pooling/conftest.py | 24 ++++++++++ tests/models/registry.py | 4 ++ vllm/v1/attention/backends/flex_attention.py | 14 +++++- 10 files changed, 139 insertions(+), 9 deletions(-) create mode 100644 tests/models/multimodal/pooling/conftest.py diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index ee4fdebae5675..022b6ea236d54 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -987,7 +987,8 @@ steps: commands: - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1 -- label: Multi-Modal Models Test (Extended) 1 +- label: Multi-Modal Models Test (Extended) 1 # 60min + timeout_in_minutes: 120 mirror_hardwares: [amdexperimental] agent_pool: mi325_1 # grade: Blocking @@ -1011,7 +1012,8 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' -- label: Multi-Modal Models Test (Extended) 3 +- label: Multi-Modal Models Test (Extended) 3 # 75min + timeout_in_minutes: 150 mirror_hardwares: [amdexperimental] agent_pool: mi325_1 # grade: Blocking diff --git a/tests/models/multimodal/generation/conftest.py b/tests/models/multimodal/generation/conftest.py index ee3ecdb10fdb8..26f8586742cea 100644 --- a/tests/models/multimodal/generation/conftest.py +++ b/tests/models/multimodal/generation/conftest.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Pytest configuration for vLLM tests.""" +import warnings + import torch from vllm.platforms import current_platform @@ -14,6 +16,20 @@ def pytest_configure(config): if not current_platform.is_rocm(): return + skip_patterns = ["test_granite_speech.py"] + if any(pattern in str(arg) for arg in config.args for pattern in skip_patterns): + # Skip disabling SDP for Granite Speech tests on ROCm + return + + # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers + # accuracy issues + # TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_mem_efficient_sdp(False) torch.backends.cuda.enable_math_sdp(True) + warnings.warn( + "ROCm: Disabled flash_sdp and mem_efficient_sdp, enabled math_sdp " + "to avoid HuggingFace Transformers accuracy issues", + UserWarning, + stacklevel=1, + ) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 0eaf7198f91b7..f896126a49089 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -403,12 +403,13 @@ VLM_TEST_SETTINGS = { # So, we need to reduce the number of tokens for the test to pass. max_tokens=8, num_logprobs=10, + auto_cls=AutoModelForCausalLM, marks=[large_gpu_mark(min_gb=32)], ), "glm4_1v": VLMTestInfo( models=["zai-org/GLM-4.1V-9B-Thinking"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", + prompt_formatter=lambda img_prompt: f"[gMASK]<|user|>\n{img_prompt}<|assistant|>\n", # noqa: E501 img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>", max_model_len=2048, @@ -423,6 +424,7 @@ VLM_TEST_SETTINGS = { models=["zai-org/GLM-4.1V-9B-Thinking"], # GLM4.1V require include video metadata for input test_type=VLMTestType.CUSTOM_INPUTS, + prompt_formatter=lambda vid_prompt: f"[gMASK]<|user|>\n{vid_prompt}<|assistant|>\n", # noqa: E501 max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -737,7 +739,13 @@ VLM_TEST_SETTINGS = { max_model_len=8192, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, - marks=[large_gpu_mark(min_gb=48)], + marks=[ + large_gpu_mark(min_gb=48), + pytest.mark.skipif( + current_platform.is_rocm(), + reason="Model produces a vector of output in HF on ROCm", + ), + ], ), "qwen_vl": VLMTestInfo( models=["Qwen/Qwen-VL"], diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index e39dfc888779e..f528a993f8551 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -8,6 +8,7 @@ from transformers import AutoModelForSpeechSeq2Seq from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform from ....conftest import AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner from ...registry import HF_EXAMPLE_MODELS @@ -34,6 +35,12 @@ audio_lora_path = MODEL_NAME models = [MODEL_NAME] +@pytest.fixture(autouse=True) +def set_attention_backend_for_rocm(monkeypatch): + if current_platform.is_rocm(): + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") + + def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], @@ -111,8 +118,12 @@ def run_test( @pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_model_len", [2048]) +@pytest.mark.parametrize( + "dtype", ["float16"] if current_platform.is_rocm() else ["bfloat16"] +) +@pytest.mark.parametrize( + "max_model_len", [512] if current_platform.is_rocm() else [2048] +) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_models( diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index 3cad2c43d5623..375099f4365ac 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -15,6 +15,7 @@ from transformers import AutoProcessor from vllm import SamplingParams, TextPrompt, TokensPrompt from vllm.logprobs import Logprob, SampleLogprobs from vllm.multimodal import MultiModalDataBuiltins +from vllm.platforms import current_platform from ....utils import VLLM_PATH, large_gpu_test from ...utils import check_logprobs_close @@ -165,6 +166,15 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs: def test_chat( vllm_runner, max_model_len: int, model: str, dtype: str, local_asset_server ) -> None: + if ( + model == MISTRAL_SMALL_3_1_ID + and max_model_len == 65536 + and current_platform.is_rocm() + ): + pytest.skip( + "OOM on ROCm: 24B model with 65536 context length exceeds GPU memory" + ) + EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT[model]) with vllm_runner( model, diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index 8c9c390911bdc..84109233685bb 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -140,7 +140,7 @@ def video_with_metadata_glm4_1v(): metadata = VIDEO_ASSETS[0].metadata question = "Describe the video." video_prompt = "<|begin_of_video|><|video|><|end_of_video|>" - formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n" + formatted_prompt = f"[gMASK]<|user|>\n{video_prompt}{question}<|assistant|>\n" scales = [0.1, 0.2, 0.25] video_input = [ diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 87cd5c3cd3554..b2c62fbd119cc 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -25,6 +25,7 @@ from transformers import ( from transformers.video_utils import VideoMetadata from vllm.logprobs import SampleLogprobs +from vllm.platforms import current_platform from vllm.utils.collection_utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets @@ -366,6 +367,40 @@ def gemma3_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOut def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for GLM4V.""" + if current_platform.is_rocm(): + import types + + config = hf_model.model.config + if hasattr(config, "num_layers") and not hasattr(config, "num_hidden_layers"): + config.num_hidden_layers = config.num_layers + config.output_hidden_states = True + + def patched_prepare_cache( + self, generation_config, model_kwargs, *args, **kwargs + ): + model_kwargs["past_key_values"] = None + model_kwargs["use_cache"] = False + return model_kwargs + + hf_model.model._prepare_cache_for_generation = types.MethodType( + patched_prepare_cache, hf_model.model + ) + original_generate = hf_model.model.generate + + def patched_generate(*args, **kwargs): + kwargs["output_hidden_states"] = True + kwargs["return_dict_in_generate"] = True + return original_generate(*args, **kwargs) + + hf_model.model.generate = patched_generate + original_forward = hf_model.model.forward + + def patched_forward(*args, **kwargs): + kwargs["output_hidden_states"] = True + return original_forward(*args, **kwargs) + + hf_model.model.forward = patched_forward + hf_processor = hf_model.processor def processor(*args, text="", images=None, **kwargs): @@ -406,7 +441,15 @@ def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: if videos is not None and is_list_of(videos, tuple): # If videos is a list of tuples, we assume each tuple contains # (video_array, metadata) as in the case of GLM4.1V. - video_metadata = [[VideoMetadata(**video[1])] for video in videos] + # Filter out 'do_sample_frames' as it's not a valid VideoMetadata arg + video_metadata = [ + [ + VideoMetadata( + **{k: v for k, v in video[1].items() if k != "do_sample_frames"} + ) + ] + for video in videos + ] videos = [[video[0]] for video in videos] else: video_metadata = None diff --git a/tests/models/multimodal/pooling/conftest.py b/tests/models/multimodal/pooling/conftest.py new file mode 100644 index 0000000000000..c5f40cb42ca2a --- /dev/null +++ b/tests/models/multimodal/pooling/conftest.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Pytest configuration for vLLM pooling tests.""" + +import os +import warnings + +from vllm.platforms import current_platform + + +def pytest_collection_modifyitems(config, items): + """Set FLEX_ATTENTION backend for SigLIP tests on ROCm.""" + if not current_platform.is_rocm(): + return + + siglip_tests = [item for item in items if "test_siglip" in item.nodeid] + + if siglip_tests: + os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" + warnings.warn( + "ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests", + UserWarning, + stacklevel=1, + ) diff --git a/tests/models/registry.py b/tests/models/registry.py index 6b1d24b1c99b5..bf88bac209808 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -667,6 +667,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { "moonshotai/Kimi-VL-A3B-Instruct", extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, trust_remote_code=True, + max_transformers_version="4.53.3", + transformers_version_reason="HF model uses deprecated transformers API " + "(PytorchGELUTanh, DynamicCache.seen_tokens, and more). See: " + "https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/discussions/31", ), "LightOnOCRForConditionalGeneration": _HfExamplesInfo( "lightonai/LightOnOCR-1B", diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index fe92f6570501c..a2a6eeeb16b24 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -31,6 +31,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( @@ -927,7 +928,18 @@ def get_kernel_options( if torch.cuda.is_available(): device_props = torch.cuda.get_device_properties() - max_shared_memory = device_props.shared_memory_per_block_optin + # ROCm doesn't expose shared_memory_per_block_optin attribute + # AMD GPUs typically have 64KB LDS (Local Data Share) per workgroup + if hasattr(device_props, "shared_memory_per_block_optin"): + max_shared_memory = device_props.shared_memory_per_block_optin + elif current_platform.is_rocm(): + # ROCm fallback: use 64KB + max_shared_memory = 65536 + else: + raise RuntimeError( + "Unable to determine shared memory size on this hardware." + ) + if max_shared_memory < 144 * 1024: block_m_candidate = ensure_divisible( max(1, block_m_candidate // 2), block_m