[ROCm][CI][Bugfix] Multi-Modal Model Support Fixes and Attention Backend Improvements (#30270)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Parent: 6a09612b2e
Commit: 7b43db210c
@@ -964,7 +964,7 @@ steps:
   - pytest -v -s models/multimodal/processing
 
 - label: Multi-Modal Models Test (Standard) # 60min
-  timeout_in_minutes: 80
+  timeout_in_minutes: 100
   mirror_hardwares: [amdexperimental]
   agent_pool: mi325_1
   # grade: Blocking
@@ -973,13 +973,15 @@ steps:
   - vllm/
   - tests/models/multimodal
   commands:
+  - export MIOPEN_DEBUG_CONV_DIRECT=0
+  - export MIOPEN_DEBUG_CONV_GEMM=0
  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pip freeze | grep -E 'torch'
   - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
   - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
 
-- label: Multi-Modal Accuracy Eval (Small Models) # 150min - 180min
-  timeout_in_minutes: 180
+- label: Multi-Modal Accuracy Eval (Small Models) # 5min
+  timeout_in_minutes: 10
   mirror_hardwares: [amdexperimental, amdproduction]
   agent_pool: mi325_1
   # grade: Blocking
@@ -989,7 +991,9 @@ steps:
   - vllm/inputs/
   - vllm/v1/core/
   commands:
-  - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
+  - export MIOPEN_DEBUG_CONV_DIRECT=0
+  - export MIOPEN_DEBUG_CONV_GEMM=0
+  - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt
 
 - label: Multi-Modal Models Test (Extended) 1 # 60min
   timeout_in_minutes: 120
@@ -1001,10 +1005,13 @@ steps:
   - vllm/
   - tests/models/multimodal
   commands:
+  - export MIOPEN_DEBUG_CONV_DIRECT=0
+  - export MIOPEN_DEBUG_CONV_GEMM=0
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing
 
-- label: Multi-Modal Models Test (Extended) 2
+- label: Multi-Modal Models Test (Extended) 2 #60min
   timeout_in_minutes: 120
   mirror_hardwares: [amdexperimental]
   agent_pool: mi325_1
   # grade: Blocking
@@ -1013,6 +1020,8 @@ steps:
   - vllm/
   - tests/models/multimodal
   commands:
+  - export MIOPEN_DEBUG_CONV_DIRECT=0
+  - export MIOPEN_DEBUG_CONV_GEMM=0
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
 
@@ -1026,6 +1035,8 @@ steps:
   - vllm/
   - tests/models/multimodal
   commands:
+  - export MIOPEN_DEBUG_CONV_DIRECT=0
+  - export MIOPEN_DEBUG_CONV_GEMM=0
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
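To run the Standard multi-modal step outside Buildkite, a minimal reproduction sketch (assumptions: a ROCm machine with the vLLM test tree checked out, invoked from the repository root; paths and markers copied from the step above):

import os
import subprocess

env = os.environ.copy()
# Same MIOpen convolution workarounds the CI step exports before running the tests.
env["MIOPEN_DEBUG_CONV_DIRECT"] = "0"
env["MIOPEN_DEBUG_CONV_GEMM"] = "0"

subprocess.run(
    [
        "pytest", "-v", "-s", "models/multimodal", "-m", "core_model",
        "--ignore", "models/multimodal/generation/test_whisper.py",
        "--ignore", "models/multimodal/processing",
    ],
    cwd="tests",  # assumption: the step runs from the tests/ directory
    env=env,
    check=True,
)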
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Pytest configuration for vLLM tests."""
+"""Pytest configuration for vLLM multimodal tests."""
 
 import warnings
 
@@ -9,16 +9,13 @@ import torch
 from vllm.platforms import current_platform
 
 
-def pytest_configure(config):
-    """Disable Flash/MemEfficient SDP on ROCm to avoid HF
-    Transformers accuracy issues.
-    """
+def pytest_collection_modifyitems(config, items):
+    """Configure ROCm-specific settings based on collected tests."""
     if not current_platform.is_rocm():
         return
 
+    skip_patterns = ["test_granite_speech.py"]
+    if any(pattern in str(arg) for arg in config.args for pattern in skip_patterns):
+        # Skip disabling SDP for Granite Speech tests on ROCm
+        return
+
     # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
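The remainder of the hook is collapsed in this view. As a rough sketch only (an assumption about the collapsed body, not the exact code), disabling the fused SDP paths with PyTorch's public toggles typically looks like:

import torch

def _force_math_sdp() -> None:
    # Leave only the math (reference) SDP kernel enabled; flash and
    # mem-efficient SDP are the paths tied to the ROCm accuracy issues above.
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)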
@@ -173,6 +173,13 @@ VLM_TEST_SETTINGS = {
         auto_cls=AutoModelForImageTextToText,
         vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
         patch_hf_runner=model_utils.qwen3_vl_patch_hf_runner,
+        vllm_runner_kwargs={
+            "attention_config": {
+                "backend": "ROCM_AITER_FA",
+            },
+        }
+        if current_platform.is_rocm()
+        else None,
         image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         marks=[
             pytest.mark.core_model,
@@ -253,8 +260,19 @@ VLM_TEST_SETTINGS = {
         image_size_factors=[(0.25, 0.2, 0.15)],
         vllm_runner_kwargs={
             "model_impl": "transformers",
+            # TODO: [ROCm] Revert this once issue #30167 is resolved
+            **(
+                {
+                    "mm_processor_kwargs": {
+                        "min_pixels": 256 * 28 * 28,
+                        "max_pixels": 1280 * 28 * 28,
+                    },
+                }
+                if current_platform.is_rocm()
+                else {}
+            ),
         },
-        marks=[large_gpu_mark(min_gb=32)],
+        marks=[large_gpu_mark(min_gb=80 if current_platform.is_rocm() else 32)],
     ),
     #### Extended model tests
     "aria": VLMTestInfo(
@@ -645,7 +663,17 @@ VLM_TEST_SETTINGS = {
         hf_output_post_proc=model_utils.minimax_vl_01_hf_output,
         patch_hf_runner=model_utils.minimax_vl_01_patch_hf_runner,
         auto_cls=AutoModelForImageTextToText,
-        marks=[large_gpu_mark(min_gb=80)],
+        marks=[
+            large_gpu_mark(min_gb=80),
+            # TODO: [ROCm] Fix pickle issue with ROCm spawn and tp>1
+            pytest.mark.skipif(
+                current_platform.is_rocm(),
+                reason=(
+                    "ROCm: Model too large for single GPU; "
+                    "multi-GPU blocked by HF _LazyConfigMapping pickle issue with spawn"
+                ),
+            ),
+        ],
     ),
     "molmo": VLMTestInfo(
         models=["allenai/Molmo-7B-D-0924"],
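The `**( {...} if current_platform.is_rocm() else {} )` pattern above merges extra processor kwargs only on ROCm. A small self-contained illustration of the idiom (the `IS_ROCM` flag is a stand-in for `current_platform.is_rocm()`, for illustration only):

IS_ROCM = True  # stand-in for current_platform.is_rocm()

runner_kwargs = {
    "model_impl": "transformers",
    # Conditional dict unpacking: the nested keys are added only when the condition holds.
    **(
        {
            "mm_processor_kwargs": {
                "min_pixels": 256 * 28 * 28,
                "max_pixels": 1280 * 28 * 28,
            }
        }
        if IS_ROCM
        else {}
    ),
}

print(sorted(runner_kwargs))  # ['mm_processor_kwargs', 'model_impl'] when IS_ROCM is True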
@@ -39,7 +39,7 @@ models = [MODEL_NAME]
 def granite_speech_attention_config():
     """Return attention config for Granite Speech tests on ROCm."""
     if current_platform.is_rocm():
-        return {"backend": "TRITON_ATTN"}
+        return {"backend": "ROCM_AITER_FA"}
     return None
 
 
@@ -1,18 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Pytest configuration for vLLM pooling tests."""
-
-import pytest
-
-from vllm.platforms import current_platform
-
-
-@pytest.fixture
-def siglip_attention_config():
-    """Return attention config for SigLIP tests on ROCm.
-
-    On ROCm, SigLIP tests require FLEX_ATTENTION backend.
-    """
-    if current_platform.is_rocm():
-        return {"backend": "FLEX_ATTENTION"}
-    return None
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING
 import torch
 
 from vllm.config.utils import getattr_iter
+from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
 from vllm.model_executor.models.utils import WeightsMapper
 from vllm.multimodal import MultiModalKwargsItems
@@ -36,6 +37,7 @@ from vllm.multimodal.inputs import (
 from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
 from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 if TYPE_CHECKING:
@@ -52,6 +54,8 @@ DYNAMIC_ARG_DIMS = {
     "inputs_embeds": 0,
 }
 
+logger = init_logger(__name__)
+
 
 class MultiModalProcessingInfo(BaseProcessingInfo):
     def get_supported_mm_limits(self):
@@ -345,8 +349,29 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
 
         num_image_patches = kwargs.pop("num_image_patches")
         kwargs.pop("token_type_ids", None)  # used only in `forward`
 
         if pixel_values is not None:
-            vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
+            # ROCm: Force math SDP backend for vision encoder to avoid accuracy issues
+            # with flash_sdp and mem_efficient_sdp
+            if current_platform.is_rocm():
+                # TODO: [ROCm] Fix accuracy issues with flash backend
+                logger.debug(
+                    "ROCm platform detected. Forcing math SDP backend "
+                    "for vision encoder. Currently ROCm platform has "
+                    "accuracy issues with `flash_sdp` and "
+                    "`mem_efficient_sdp` backends. See issue: "
+                    "https://github.com/vllm-project/vllm/issues/30167"
+                )
+                with torch.nn.attention.sdpa_kernel(
+                    backends=[torch.nn.attention.SDPBackend.MATH]
+                ):
+                    vision_embeddings = self.model.get_image_features(
+                        pixel_values, **kwargs
+                    )
+            else:
+                vision_embeddings = self.model.get_image_features(
+                    pixel_values, **kwargs
+                )
 
         if isinstance(vision_embeddings, torch.Tensor):
             if vision_embeddings.ndim == 2:
@@ -364,6 +389,11 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
             ]
 
             return vision_embeddings
+        else:
+            logger.debug(
+                "No pixel values or image embeddings provided for multimodal embedding."
+            )
+            return None
 
     def get_mrope_input_positions(
         self,
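The workaround above relies on PyTorch's `sdpa_kernel` context manager to restrict `scaled_dot_product_attention` to the math backend. A standalone sketch of that mechanism, with random tensors and purely for illustration:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)

baseline = F.scaled_dot_product_attention(q, k, v)  # backend chosen automatically

# Inside the context manager only the math (reference) kernel is eligible,
# which is what the ROCm branch above enforces for the vision encoder.
with sdpa_kernel(backends=[SDPBackend.MATH]):
    forced = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(baseline, forced, atol=1e-5))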
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Optional
 import torch
 
 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -204,7 +205,7 @@ class RocmPlatform(Platform):
             assert block_size == 1, (
                 "Sparse MLA backend on ROCm only supports block size 1 for now."
             )
-            logger.info_once("Using Sparse MLA backend on V1 engine.")
+            logger.info_once("Using Sparse MLA backend.")
             return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
 
         if attn_selector_config.use_mla:
@@ -239,16 +240,16 @@ class RocmPlatform(Platform):
             return AttentionBackendEnum.FLEX_ATTENTION.get_path()
 
         if selected_backend == AttentionBackendEnum.TRITON_ATTN:
-            logger.info("Using Triton Attention backend on V1 engine.")
+            logger.info("Using Triton Attention backend.")
             return AttentionBackendEnum.TRITON_ATTN.get_path()
 
         if selected_backend == AttentionBackendEnum.ROCM_ATTN:
-            logger.info("Using Rocm Attention backend on V1 engine.")
+            logger.info("Using Rocm Attention backend.")
             return AttentionBackendEnum.ROCM_ATTN.get_path()
 
         if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
             if on_gfx9():
-                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                logger.info("Using Aiter Flash Attention backend.")
                 return AttentionBackendEnum.ROCM_AITER_FA.get_path()
             else:
                 raise ValueError(
@@ -257,25 +258,25 @@ class RocmPlatform(Platform):
                 )
 
         if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
-            logger.info("Using Aiter Unified Attention backend on V1 engine.")
+            logger.info("Using Aiter Unified Attention backend.")
             return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
 
         # Handle automatic backend selection based on environment variables
         if selected_backend is None:
             # Priority 1: Check for AITER Unified Attention (must check before MHA)
             if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
-                logger.info("Using Aiter Unified Attention backend on V1 engine.")
+                logger.info("Using Aiter Unified Attention backend.")
                 return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
 
             # Priority 2: Check for AITER MHA (Flash Attention)
             # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
             if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
-                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                logger.info("Using Aiter Flash Attention backend.")
                 return AttentionBackendEnum.ROCM_AITER_FA.get_path()
 
             # Priority 3: Check for ROCM_ATTN (prefill-decode split)
             if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
-                logger.info("Using Rocm Attention backend on V1 engine.")
+                logger.info("Using Rocm Attention backend.")
                 return AttentionBackendEnum.ROCM_ATTN.get_path()
 
             # Priority 4: Check for AITER enabled without specific flags
@@ -285,11 +286,19 @@ class RocmPlatform(Platform):
                 and on_gfx9()
                 and envs.VLLM_ROCM_USE_AITER_MHA is not False
             ):
-                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                logger.info("Using Aiter Flash Attention backend.")
                 return AttentionBackendEnum.ROCM_AITER_FA.get_path()
 
+            # Priority 5: If model is Encoder-only self-attention type
+            if (
+                attn_selector_config.attn_type is not None
+                and attn_selector_config.attn_type == AttentionType.ENCODER_ONLY
+            ):
+                logger.info("Using FlexAttention backend.")
+                return AttentionBackendEnum.FLEX_ATTENTION.get_path()
+
             # Default: Triton Unified Attention
-            logger.info("Using Triton Attention backend on V1 engine.")
+            logger.info("Using Triton Attention backend.")
             return AttentionBackendEnum.TRITON_ATTN.get_path()
 
         raise RuntimeError(
@@ -324,14 +333,19 @@ class RocmPlatform(Platform):
 
         from vllm._aiter_ops import rocm_aiter_ops
 
-        if rocm_aiter_ops.is_mha_enabled():
+        # Note: AITER FA is only supported for Qwen-VL models.
+        # TODO: Add support for other VL models in their model class.
+        if rocm_aiter_ops.is_enabled():
             logger.info_once("Using AITER Flash Attention backend for ViT model.")
             return AttentionBackendEnum.ROCM_AITER_FA
 
-        if on_gfx9() and find_spec("flash_attn") is not None:
+        if (
+            on_gfx9()
+            and find_spec("flash_attn") is not None
+            and (dtype == torch.float16 or dtype == torch.bfloat16)
+        ):
             logger.info_once("Using Flash Attention backend for ViT model.")
             return AttentionBackendEnum.FLASH_ATTN
 
         logger.info_once("Using Torch SDPA backend for ViT model.")
         return AttentionBackendEnum.TORCH_SDPA
 
     @classmethod
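For reference, the automatic selection order implemented above can be summarized as a small standalone function. This is an illustrative sketch of the decision order only, driven by the same environment flags; it is not vLLM's actual code path, and the Priority 4 check approximates `VLLM_ROCM_USE_AITER_MHA is not False` as "not explicitly set to 0".

import os

def _flag(name: str) -> bool:
    return os.environ.get(name, "0") == "1"

def rocm_auto_attention_backend(on_gfx9: bool, encoder_only: bool) -> str:
    aiter = _flag("VLLM_ROCM_USE_AITER")
    if aiter and _flag("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION"):
        return "ROCM_AITER_UNIFIED_ATTN"  # Priority 1
    if aiter and _flag("VLLM_ROCM_USE_AITER_MHA") and on_gfx9:
        return "ROCM_AITER_FA"  # Priority 2
    if _flag("VLLM_V1_USE_PREFILL_DECODE_ATTENTION"):
        return "ROCM_ATTN"  # Priority 3
    if aiter and on_gfx9 and os.environ.get("VLLM_ROCM_USE_AITER_MHA") != "0":
        return "ROCM_AITER_FA"  # Priority 4: AITER on, MHA not explicitly disabled
    if encoder_only:
        return "FLEX_ATTENTION"  # Priority 5: encoder-only models (new in this change)
    return "TRITON_ATTN"  # Default: Triton Unified Attention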