From 7b43db210c34023cb1e8bcec1d3c49537556856f Mon Sep 17 00:00:00 2001
From: Andreas Karatzas
Date: Thu, 18 Dec 2025 20:17:27 -0600
Subject: [PATCH] [ROCm][CI][Bugfix] Multi-Modal Model Support Fixes and Attention Backend Improvements (#30270)

Signed-off-by: Andreas Karatzas
---
 .buildkite/test-amd.yaml                      | 21 +++++++---
 .../multimodal/{generation => }/conftest.py   |  9 ++--
 .../multimodal/generation/test_common.py      | 32 +++++++++++++-
 .../generation/test_granite_speech.py         |  2 +-
 tests/models/multimodal/pooling/conftest.py   | 18 --------
 .../models/transformers/multimodal.py         | 32 +++++++++++++-
 vllm/platforms/rocm.py                        | 42 ++++++++++++-------
 7 files changed, 109 insertions(+), 47 deletions(-)
 rename tests/models/multimodal/{generation => }/conftest.py (79%)
 delete mode 100644 tests/models/multimodal/pooling/conftest.py

diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index 6e20ff3bf38d9..9a770869b1d17 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -964,7 +964,7 @@ steps:
   - pytest -v -s models/multimodal/processing
 
 - label: Multi-Modal Models Test (Standard) # 60min
-  timeout_in_minutes: 80
+  timeout_in_minutes: 100
   mirror_hardwares: [amdexperimental]
   agent_pool: mi325_1
   # grade: Blocking
@@ -973,13 +973,15 @@ steps:
   - vllm/
   - tests/models/multimodal
   commands:
+  - export MIOPEN_DEBUG_CONV_DIRECT=0
+  - export MIOPEN_DEBUG_CONV_GEMM=0
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pip freeze | grep -E 'torch'
   - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
   - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
 
-- label: Multi-Modal Accuracy Eval (Small Models) # 150min - 180min
-  timeout_in_minutes: 180
+- label: Multi-Modal Accuracy Eval (Small Models) # 5min
+  timeout_in_minutes: 10
   mirror_hardwares: [amdexperimental, amdproduction]
   agent_pool: mi325_1
   # grade: Blocking
@@ -989,7 +991,9 @@ steps:
   - vllm/inputs/
   - vllm/v1/core/
   commands:
-  - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
+  - export MIOPEN_DEBUG_CONV_DIRECT=0
+  - export MIOPEN_DEBUG_CONV_GEMM=0
+  - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt
 
 - label: Multi-Modal Models Test (Extended) 1 # 60min
   timeout_in_minutes: 120
@@ -1001,10 +1005,13 @@ steps:
   - vllm/
   - tests/models/multimodal
   commands:
+  - export MIOPEN_DEBUG_CONV_DIRECT=0
+  - export MIOPEN_DEBUG_CONV_GEMM=0
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing
 
-- label: Multi-Modal Models Test (Extended) 2
+- label: Multi-Modal Models Test (Extended) 2 #60min
+  timeout_in_minutes: 120
   mirror_hardwares: [amdexperimental]
   agent_pool: mi325_1
   # grade: Blocking
@@ -1013,6 +1020,8 @@ steps:
   - vllm/
   - tests/models/multimodal
   commands:
+  - export MIOPEN_DEBUG_CONV_DIRECT=0
+  - export MIOPEN_DEBUG_CONV_GEMM=0
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
 
@@ -1026,6 +1035,8 @@ steps:
   - vllm/
   - tests/models/multimodal
   commands:
+  - export MIOPEN_DEBUG_CONV_DIRECT=0
+  - export MIOPEN_DEBUG_CONV_GEMM=0
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
 
diff --git a/tests/models/multimodal/generation/conftest.py b/tests/models/multimodal/conftest.py
similarity index 79%
rename from tests/models/multimodal/generation/conftest.py
rename to tests/models/multimodal/conftest.py
index 26f8586742cea..4243298cdc896 100644
--- a/tests/models/multimodal/generation/conftest.py
+++ b/tests/models/multimodal/conftest.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Pytest configuration for vLLM tests."""
+"""Pytest configuration for vLLM multimodal tests."""
 
 import warnings
 
@@ -9,16 +9,13 @@ import torch
 from vllm.platforms import current_platform
 
 
-def pytest_configure(config):
-    """Disable Flash/MemEfficient SDP on ROCm to avoid HF
-    Transformers accuracy issues.
-    """
+def pytest_collection_modifyitems(config, items):
+    """Configure ROCm-specific settings based on collected tests."""
     if not current_platform.is_rocm():
         return
 
     skip_patterns = ["test_granite_speech.py"]
     if any(pattern in str(arg) for arg in config.args for pattern in skip_patterns):
-        # Skip disabling SDP for Granite Speech tests on ROCm
         return
 
     # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index c5a0b6748f797..6640e1ff9474d 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -173,6 +173,13 @@ VLM_TEST_SETTINGS = {
         auto_cls=AutoModelForImageTextToText,
         vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
         patch_hf_runner=model_utils.qwen3_vl_patch_hf_runner,
+        vllm_runner_kwargs={
+            "attention_config": {
+                "backend": "ROCM_AITER_FA",
+            },
+        }
+        if current_platform.is_rocm()
+        else None,
         image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         marks=[
             pytest.mark.core_model,
@@ -253,8 +260,19 @@ VLM_TEST_SETTINGS = {
         image_size_factors=[(0.25, 0.2, 0.15)],
         vllm_runner_kwargs={
             "model_impl": "transformers",
+            # TODO: [ROCm] Revert this once issue #30167 is resolved
+            **(
+                {
+                    "mm_processor_kwargs": {
+                        "min_pixels": 256 * 28 * 28,
+                        "max_pixels": 1280 * 28 * 28,
+                    },
+                }
+                if current_platform.is_rocm()
+                else {}
+            ),
         },
-        marks=[large_gpu_mark(min_gb=32)],
+        marks=[large_gpu_mark(min_gb=80 if current_platform.is_rocm() else 32)],
     ),
     #### Extended model tests
     "aria": VLMTestInfo(
@@ -645,7 +663,17 @@ VLM_TEST_SETTINGS = {
         hf_output_post_proc=model_utils.minimax_vl_01_hf_output,
         patch_hf_runner=model_utils.minimax_vl_01_patch_hf_runner,
         auto_cls=AutoModelForImageTextToText,
-        marks=[large_gpu_mark(min_gb=80)],
+        marks=[
+            large_gpu_mark(min_gb=80),
+            # TODO: [ROCm] Fix pickle issue with ROCm spawn and tp>1
+            pytest.mark.skipif(
+                current_platform.is_rocm(),
+                reason=(
+                    "ROCm: Model too large for single GPU; "
+                    "multi-GPU blocked by HF _LazyConfigMapping pickle issue with spawn"
+                ),
+            ),
+        ],
     ),
     "molmo": VLMTestInfo(
         models=["allenai/Molmo-7B-D-0924"],
diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py
index 489743c5a29b3..1519a50c1a0c3 100644
--- a/tests/models/multimodal/generation/test_granite_speech.py
+++ b/tests/models/multimodal/generation/test_granite_speech.py
@@ -39,7 +39,7 @@ models = [MODEL_NAME]
 def granite_speech_attention_config():
     """Return attention config for Granite Speech tests on ROCm."""
     if current_platform.is_rocm():
-        return {"backend": "TRITON_ATTN"}
+        return {"backend": "ROCM_AITER_FA"}
     return None
 
 
diff --git a/tests/models/multimodal/pooling/conftest.py b/tests/models/multimodal/pooling/conftest.py
deleted file mode 100644
index 401bc39b4b109..0000000000000
--- a/tests/models/multimodal/pooling/conftest.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Pytest configuration for vLLM pooling tests."""
-
-import pytest
-
-from vllm.platforms import current_platform
-
-
-@pytest.fixture
-def siglip_attention_config():
-    """Return attention config for SigLIP tests on ROCm.
-
-    On ROCm, SigLIP tests require FLEX_ATTENTION backend.
-    """
-    if current_platform.is_rocm():
-        return {"backend": "FLEX_ATTENTION"}
-    return None
diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py
index 9d77dee2810c3..fcf9a0d077abe 100644
--- a/vllm/model_executor/models/transformers/multimodal.py
+++ b/vllm/model_executor/models/transformers/multimodal.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING
 import torch
 
 from vllm.config.utils import getattr_iter
+from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
 from vllm.model_executor.models.utils import WeightsMapper
 from vllm.multimodal import MultiModalKwargsItems
@@ -36,6 +37,7 @@ from vllm.multimodal.inputs import (
 from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
 from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 if TYPE_CHECKING:
@@ -52,6 +54,8 @@ DYNAMIC_ARG_DIMS = {
     "inputs_embeds": 0,
 }
 
+logger = init_logger(__name__)
+
 
 class MultiModalProcessingInfo(BaseProcessingInfo):
     def get_supported_mm_limits(self):
@@ -345,8 +349,29 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
 
         num_image_patches = kwargs.pop("num_image_patches")
         kwargs.pop("token_type_ids", None)  # used only in `forward`
+
         if pixel_values is not None:
-            vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
+            # ROCm: Force math SDP backend for vision encoder to avoid accuracy issues
+            # with flash_sdp and mem_efficient_sdp
+            if current_platform.is_rocm():
+                # TODO: [ROCm] Fix accuracy issues with flash backend
+                logger.debug(
+                    "ROCm platform detected. Forcing math SDP backend "
+                    "for vision encoder. Currently ROCm platform has "
+                    "accuracy issues with `flash_sdp` and"
+                    "`mem_efficient_sdp` backends. See issue: "
+                    "https://github.com/vllm-project/vllm/issues/30167"
+                )
+                with torch.nn.attention.sdpa_kernel(
+                    backends=[torch.nn.attention.SDPBackend.MATH]
+                ):
+                    vision_embeddings = self.model.get_image_features(
+                        pixel_values, **kwargs
+                    )
+            else:
+                vision_embeddings = self.model.get_image_features(
+                    pixel_values, **kwargs
+                )
 
             if isinstance(vision_embeddings, torch.Tensor):
                 if vision_embeddings.ndim == 2:
@@ -364,6 +389,11 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
             ]
 
             return vision_embeddings
+        else:
+            logger.debug(
+                "No pixel values or image embeddings provided for multimodal embedding."
+            )
+            return None
 
     def get_mrope_input_positions(
         self,
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index c237f7cf887c1..5892639eba406 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Optional
 import torch
 
 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -204,7 +205,7 @@ class RocmPlatform(Platform):
             assert block_size == 1, (
                 "Sparse MLA backend on ROCm only supports block size 1 for now."
             )
-            logger.info_once("Using Sparse MLA backend on V1 engine.")
+            logger.info_once("Using Sparse MLA backend.")
             return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
 
         if attn_selector_config.use_mla:
@@ -239,16 +240,16 @@ class RocmPlatform(Platform):
             return AttentionBackendEnum.FLEX_ATTENTION.get_path()
 
         if selected_backend == AttentionBackendEnum.TRITON_ATTN:
-            logger.info("Using Triton Attention backend on V1 engine.")
+            logger.info("Using Triton Attention backend.")
             return AttentionBackendEnum.TRITON_ATTN.get_path()
 
         if selected_backend == AttentionBackendEnum.ROCM_ATTN:
-            logger.info("Using Rocm Attention backend on V1 engine.")
+            logger.info("Using Rocm Attention backend.")
             return AttentionBackendEnum.ROCM_ATTN.get_path()
 
         if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
             if on_gfx9():
-                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                logger.info("Using Aiter Flash Attention backend.")
                 return AttentionBackendEnum.ROCM_AITER_FA.get_path()
             else:
                 raise ValueError(
@@ -257,25 +258,25 @@ class RocmPlatform(Platform):
                 )
 
         if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
-            logger.info("Using Aiter Unified Attention backend on V1 engine.")
+            logger.info("Using Aiter Unified Attention backend.")
             return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
 
         # Handle automatic backend selection based on environment variables
         if selected_backend is None:
             # Priority 1: Check for AITER Unified Attention (must check before MHA)
             if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
-                logger.info("Using Aiter Unified Attention backend on V1 engine.")
+                logger.info("Using Aiter Unified Attention backend.")
                 return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
 
             # Priority 2: Check for AITER MHA (Flash Attention)
             # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
             if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
-                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                logger.info("Using Aiter Flash Attention backend.")
                 return AttentionBackendEnum.ROCM_AITER_FA.get_path()
 
             # Priority 3: Check for ROCM_ATTN (prefill-decode split)
             if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
-                logger.info("Using Rocm Attention backend on V1 engine.")
+                logger.info("Using Rocm Attention backend.")
                 return AttentionBackendEnum.ROCM_ATTN.get_path()
 
             # Priority 4: Check for AITER enabled without specific flags
@@ -285,11 +286,19 @@ class RocmPlatform(Platform):
                 and on_gfx9()
                 and envs.VLLM_ROCM_USE_AITER_MHA is not False
             ):
-                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                logger.info("Using Aiter Flash Attention backend.")
                 return AttentionBackendEnum.ROCM_AITER_FA.get_path()
 
+            # Priority 5: If model is Encoder-only self-attention type
+            if (
+                attn_selector_config.attn_type is not None
+                and attn_selector_config.attn_type == AttentionType.ENCODER_ONLY
+            ):
+                logger.info("Using FlexAttention backend.")
+                return AttentionBackendEnum.FLEX_ATTENTION.get_path()
+
             # Default: Triton Unified Attention
-            logger.info("Using Triton Attention backend on V1 engine.")
+            logger.info("Using Triton Attention backend.")
             return AttentionBackendEnum.TRITON_ATTN.get_path()
 
         raise RuntimeError(
@@ -324,14 +333,19 @@ class RocmPlatform(Platform):
 
         from vllm._aiter_ops import rocm_aiter_ops
 
-        if rocm_aiter_ops.is_mha_enabled():
-            # Note: AITER FA is only supported for Qwen-VL models.
-            # TODO: Add support for other VL models in their model class.
+        if rocm_aiter_ops.is_enabled():
+            logger.info_once("Using AITER Flash Attention backend for ViT model.")
             return AttentionBackendEnum.ROCM_AITER_FA
 
-        if on_gfx9() and find_spec("flash_attn") is not None:
+        if (
+            on_gfx9()
+            and find_spec("flash_attn") is not None
+            and (dtype == torch.float16 or dtype == torch.bfloat16)
+        ):
+            logger.info_once("Using Flash Attention backend for ViT model.")
             return AttentionBackendEnum.FLASH_ATTN
 
+        logger.info_once("Using Torch SDPA backend for ViT model.")
         return AttentionBackendEnum.TORCH_SDPA
 
     @classmethod