From b41fb9d3b10dcf187ac0501ca80ede96d387617f Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:53:57 -0800 Subject: [PATCH] [Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (#9982) Signed-off-by: Sourashis Roy --- tests/encoder_decoder/test_e2e_correctness.py | 9 +- .../vision_language/test_mllama.py | 100 +++++++++++------- tests/test_config.py | 2 + vllm/model_executor/models/mllama.py | 52 ++++++--- vllm/worker/enc_dec_model_runner.py | 34 ++---- 5 files changed, 117 insertions(+), 80 deletions(-) diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index f2d7e9fd78cf..fa5d6a69a9bc 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -7,7 +7,7 @@ from typing import List, Optional, Tuple import pytest from transformers import AutoModelForSeq2SeqLM -from vllm.attention.selector import (_Backend, +from vllm.attention.selector import (_Backend, _cached_get_attn_backend, global_force_attn_backend_context_manager) from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs @@ -34,6 +34,13 @@ def vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs +@pytest.fixture(autouse=True) +def clear_cache(): + """Fixture to clear backend cache before each test.""" + _cached_get_attn_backend.cache_clear() # Clear the cache + yield # This allows the test to run + + @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index 7f82347841cd..a3b1c0950d9a 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -4,6 +4,8 @@ import pytest from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, BatchEncoding) +from vllm.attention.selector import (_Backend, _cached_get_attn_backend, + global_force_attn_backend_context_manager) from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -14,6 +16,8 @@ from ...utils import check_logprobs_close _LIMIT_IMAGE_PER_PROMPT = 3 +LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] + HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": "<|image|><|begin_of_text|>The meaning of the image is", @@ -221,6 +225,13 @@ def _run_test( ) +@pytest.fixture(autouse=True) +def clear_cache(): + """Fixture to clear backend cache before each test.""" + _cached_get_attn_backend.cache_clear() # Clear the cache + yield # This allows the test to run + + @large_gpu_test(min_gb=48) @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( @@ -244,20 +255,26 @@ def _run_test( @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, model, sizes, dtype, max_tokens, - num_logprobs) -> None: - run_test( - hf_runner, - vllm_runner, - image_assets, - model, - sizes=sizes, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) + num_logprobs, + attn_backend: 
_Backend) -> None: + with global_force_attn_backend_context_manager(attn_backend): + if attn_backend == _Backend.FLASH_ATTN: + # Flash Attention works only with bfloat16 data-type + dtype = 'bfloat16' + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + sizes=sizes, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) @large_gpu_test(min_gb=48) @@ -265,9 +282,10 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, - model, dtype, max_tokens, - num_logprobs) -> None: + model, dtype, max_tokens, num_logprobs, + attn_backend: _Backend) -> None: stop_sign = image_assets[0].pil_image cherry_blossom = image_assets[1].pil_image @@ -291,17 +309,20 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, cherry_blossom.resize((512, 1024)), ], ])] - - _run_test( - hf_runner, - vllm_runner, - inputs, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) + with global_force_attn_backend_context_manager(attn_backend): + if attn_backend == _Backend.FLASH_ATTN: + # Flash Attention works only with bfloat16 data-type + dtype = 'bfloat16' + _run_test( + hf_runner, + vllm_runner, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) @large_gpu_test(min_gb=48) @@ -309,8 +330,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, - dtype, max_tokens, num_logprobs) -> None: + dtype, max_tokens, num_logprobs, + attn_backend: _Backend) -> None: stop_sign = image_assets[0].pil_image cherry_blossom = image_assets[1].pil_image @@ -325,14 +348,17 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, [stop_sign], [stop_sign, cherry_blossom], ])] - - _run_test( - hf_runner, - vllm_runner, - inputs, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) + with global_force_attn_backend_context_manager(attn_backend): + if attn_backend == _Backend.FLASH_ATTN: + # Flash Attention works only with bfloat16 data-type + dtype = 'bfloat16' + _run_test( + hf_runner, + vllm_runner, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/tests/test_config.py b/tests/test_config.py index 36c426d6c51f..df382d22d83e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -243,6 +243,8 @@ def test_rope_customization(): assert longchat_model_config.max_model_len == 4096 +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Encoder Decoder models not supported on ROCm.") @pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [ ("facebook/opt-125m", False), ("facebook/bart-base", True), diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index e5c1d28e6e7e..db7ee7b2d853 100644 --- 
a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -32,6 +32,8 @@ from transformers.models.mllama.processing_mllama import ( import vllm.distributed.parallel_state as ps from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.attention.backends.xformers import XFormersMetadata from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -799,12 +801,13 @@ class MllamaTextCrossAttention(nn.Module): q = self.q_norm(q) if attention_mask is not None: - output = self.attention_with_mask(q, k, v, kv_cache, - attention_mask, - kv_range_for_decode, - attn_metadata) + output = self._attention_with_mask(q, k, v, kv_cache, + attention_mask, + kv_range_for_decode, + attn_metadata) else: - output = self.attn(q, + output = self.attn(q.view(-1, + self.num_local_heads * self.head_dim), k, v, kv_cache, @@ -813,7 +816,7 @@ class MllamaTextCrossAttention(nn.Module): out, _ = self.o_proj(output) return out - def attention_with_mask( + def _attention_with_mask( self, q: torch.Tensor, k: torch.Tensor, @@ -824,14 +827,35 @@ class MllamaTextCrossAttention(nn.Module): attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Skip writing kv-cache for the initial profiling run. - if len(kv_cache.shape) == 3: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_local_key_value_heads, self.head_dim) - cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) - cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) - PagedAttention.write_to_paged_cache( - cached_k, cached_v, key_cache, value_cache, - attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) + if len(kv_cache.shape) > 1: + if isinstance(attn_metadata, FlashAttentionMetadata): + cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) + cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) + torch.ops._C_cache_ops.reshape_and_cache_flash( + cached_k, + cached_v, + kv_cache[0], + kv_cache[1], + attn_metadata. + cross_slot_mapping, # type: ignore[union-attr] + "auto", + 1.0, + 1.0, + ) + elif isinstance(attn_metadata, XFormersMetadata): + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_local_key_value_heads, self.head_dim) + cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) + cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) + PagedAttention.write_to_paged_cache( + cached_k, cached_v, key_cache, value_cache, + attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) + else: + raise ValueError( + f"Unsupported AttentionMetadata {type(attn_metadata)} " + f"class found. Expected the AttentionMetadata to " + f"be either XFormersMetadata or FlashAttentionMetadata.") + # We have to call torch.sdpa for prefill when using a # custom cross-attention mask. 
Because the mask is not a # standard causal mask, neither a block diagonal mask which diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 008e0c974599..82824faa6629 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -9,15 +9,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata) from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, - get_global_forced_attn_backend, - global_force_attn_backend) -from vllm.config import ModelConfig, VllmConfig + get_global_forced_attn_backend) +from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.utils import get_architecture_class_name from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, MultiModalRegistry) from vllm.sampling_params import SamplingParams @@ -35,11 +33,6 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario logger = init_logger(__name__) -# The Mllama model has PagedAttention specific logic because of which it -# can only be run with the XFORMERS backend -# TODO Make Mllama model work with Flash Attention backend. -_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"] - @dataclasses.dataclass(frozen=True) class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): @@ -97,7 +90,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): models) but these arguments are present here for compatibility with the base-class constructor. ''' - self._maybe_force_supported_attention_backend(vllm_config.model_config) + self._maybe_force_supported_attention_backend() super().__init__( vllm_config=vllm_config, @@ -108,12 +101,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): # Crash for unsupported encoder/scenarios assert_enc_dec_mr_supported_scenario(self) - def _is_xformers_only_encoder_decoder_model(self, - model: ModelConfig) -> bool: - return get_architecture_class_name( - model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS - - def _maybe_force_supported_attention_backend(self, model: ModelConfig): + def _maybe_force_supported_attention_backend(self): ''' Force vLLM to use the XFormers attention backend, which is currently the only supported option. @@ -128,23 +116,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): maybe_global_forced_backend = get_global_forced_attn_backend() is_forced_by_global = maybe_global_forced_backend is not None is_forced_by_env_var = maybe_env_var_forced_backend is not None - - if not (is_forced_by_global or is_forced_by_env_var) \ - and self._is_xformers_only_encoder_decoder_model(model): - # The user has not already specified an attention backend - # override - logger.info( - "Encoder-Decoder Model Architecture %s requires XFormers " - "backend; overriding backend auto-selection and " - "forcing XFormers.", get_architecture_class_name(model)) - global_force_attn_backend(_Backend.XFORMERS) - elif is_forced_by_global: + if is_forced_by_global: # noqa: SIM102 # Backend override enforced by global variable takes # precedence over vLLM backend environment variable. 
             if maybe_global_forced_backend not in\
                     [_Backend.XFORMERS, _Backend.FLASH_ATTN]:
                 raise_backend_err()
-        elif is_forced_by_env_var:
+        elif is_forced_by_env_var:  # noqa: SIM102
             # Backend override enforced by vLLM backend
             # environment variable
             if maybe_env_var_forced_backend not in\