[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (#9982)

Signed-off-by: Sourashis Roy <sroy@roblox.com>
This commit is contained in:
sroy745 2024-11-12 10:53:57 -08:00 committed by GitHub
parent 7c65527918
commit b41fb9d3b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 117 additions and 80 deletions

View File

@ -7,7 +7,7 @@ from typing import List, Optional, Tuple
import pytest import pytest
from transformers import AutoModelForSeq2SeqLM from transformers import AutoModelForSeq2SeqLM
from vllm.attention.selector import (_Backend, from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
@ -34,6 +34,13 @@ def vllm_to_hf_output(
return output_ids, hf_output_str, out_logprobs return output_ids, hf_output_str, out_logprobs
@pytest.fixture(autouse=True)
def clear_cache():
"""Fixture to clear backend cache before each test."""
_cached_get_attn_backend.cache_clear() # Clear the cache
yield # This allows the test to run
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)

View File

@ -4,6 +4,8 @@ import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding) BatchEncoding)
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
@ -14,6 +16,8 @@ from ...utils import check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT = 3 _LIMIT_IMAGE_PER_PROMPT = 3
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign": "stop_sign":
"<|image|><|begin_of_text|>The meaning of the image is", "<|image|><|begin_of_text|>The meaning of the image is",
@ -221,6 +225,13 @@ def _run_test(
) )
@pytest.fixture(autouse=True)
def clear_cache():
"""Fixture to clear backend cache before each test."""
_cached_get_attn_backend.cache_clear() # Clear the cache
yield # This allows the test to run
@large_gpu_test(min_gb=48) @large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -244,20 +255,26 @@ def _run_test(
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
model, sizes, dtype, max_tokens, model, sizes, dtype, max_tokens,
num_logprobs) -> None: num_logprobs,
run_test( attn_backend: _Backend) -> None:
hf_runner, with global_force_attn_backend_context_manager(attn_backend):
vllm_runner, if attn_backend == _Backend.FLASH_ATTN:
image_assets, # Flash Attention works only with bfloat16 data-type
model, dtype = 'bfloat16'
sizes=sizes, run_test(
dtype=dtype, hf_runner,
max_tokens=max_tokens, vllm_runner,
num_logprobs=num_logprobs, image_assets,
tensor_parallel_size=1, model,
) sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@large_gpu_test(min_gb=48) @large_gpu_test(min_gb=48)
@ -265,9 +282,10 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens, model, dtype, max_tokens, num_logprobs,
num_logprobs) -> None: attn_backend: _Backend) -> None:
stop_sign = image_assets[0].pil_image stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image cherry_blossom = image_assets[1].pil_image
@ -291,17 +309,20 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
cherry_blossom.resize((512, 1024)), cherry_blossom.resize((512, 1024)),
], ],
])] ])]
with global_force_attn_backend_context_manager(attn_backend):
_run_test( if attn_backend == _Backend.FLASH_ATTN:
hf_runner, # Flash Attention works only with bfloat16 data-type
vllm_runner, dtype = 'bfloat16'
inputs, _run_test(
model, hf_runner,
dtype=dtype, vllm_runner,
max_tokens=max_tokens, inputs,
num_logprobs=num_logprobs, model,
tensor_parallel_size=1, dtype=dtype,
) max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@large_gpu_test(min_gb=48) @large_gpu_test(min_gb=48)
@ -309,8 +330,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
dtype, max_tokens, num_logprobs) -> None: dtype, max_tokens, num_logprobs,
attn_backend: _Backend) -> None:
stop_sign = image_assets[0].pil_image stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image cherry_blossom = image_assets[1].pil_image
@ -325,14 +348,17 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
[stop_sign], [stop_sign],
[stop_sign, cherry_blossom], [stop_sign, cherry_blossom],
])] ])]
with global_force_attn_backend_context_manager(attn_backend):
_run_test( if attn_backend == _Backend.FLASH_ATTN:
hf_runner, # Flash Attention works only with bfloat16 data-type
vllm_runner, dtype = 'bfloat16'
inputs, _run_test(
model, hf_runner,
dtype=dtype, vllm_runner,
max_tokens=max_tokens, inputs,
num_logprobs=num_logprobs, model,
tensor_parallel_size=1, dtype=dtype,
) max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)

View File

@ -243,6 +243,8 @@ def test_rope_customization():
assert longchat_model_config.max_model_len == 4096 assert longchat_model_config.max_model_len == 4096
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Encoder Decoder models not supported on ROCm.")
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [ @pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
("facebook/opt-125m", False), ("facebook/opt-125m", False),
("facebook/bart-base", True), ("facebook/bart-base", True),

View File

@ -32,6 +32,8 @@ from transformers.models.mllama.processing_mllama import (
import vllm.distributed.parallel_state as ps import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
@ -799,12 +801,13 @@ class MllamaTextCrossAttention(nn.Module):
q = self.q_norm(q) q = self.q_norm(q)
if attention_mask is not None: if attention_mask is not None:
output = self.attention_with_mask(q, k, v, kv_cache, output = self._attention_with_mask(q, k, v, kv_cache,
attention_mask, attention_mask,
kv_range_for_decode, kv_range_for_decode,
attn_metadata) attn_metadata)
else: else:
output = self.attn(q, output = self.attn(q.view(-1,
self.num_local_heads * self.head_dim),
k, k,
v, v,
kv_cache, kv_cache,
@ -813,7 +816,7 @@ class MllamaTextCrossAttention(nn.Module):
out, _ = self.o_proj(output) out, _ = self.o_proj(output)
return out return out
def attention_with_mask( def _attention_with_mask(
self, self,
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
@ -824,14 +827,35 @@ class MllamaTextCrossAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run. # Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) == 3: if len(kv_cache.shape) > 1:
key_cache, value_cache = PagedAttention.split_kv_cache( if isinstance(attn_metadata, FlashAttentionMetadata):
kv_cache, self.num_local_key_value_heads, self.head_dim) cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) torch.ops._C_cache_ops.reshape_and_cache_flash(
PagedAttention.write_to_paged_cache( cached_k,
cached_k, cached_v, key_cache, value_cache, cached_v,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) kv_cache[0],
kv_cache[1],
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
1.0,
1.0,
)
elif isinstance(attn_metadata, XFormersMetadata):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
else:
raise ValueError(
f"Unsupported AttentionMetadata {type(attn_metadata)} "
f"class found. Expected the AttentionMetadata to "
f"be either XFormersMetadata or FlashAttentionMetadata.")
# We have to call torch.sdpa for prefill when using a # We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a # custom cross-attention mask. Because the mask is not a
# standard causal mask, neither a block diagonal mask which # standard causal mask, neither a block diagonal mask which

View File

@ -9,15 +9,13 @@ from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata) AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend, get_global_forced_attn_backend)
global_force_attn_backend) from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.utils import get_architecture_class_name
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry) MultiModalRegistry)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
@ -35,11 +33,6 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger = init_logger(__name__) logger = init_logger(__name__)
# The Mllama model has PagedAttention specific logic because of which it
# can only be run with the XFORMERS backend
# TODO Make Mllama model work with Flash Attention backend.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
@ -97,7 +90,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
models) but these arguments are present here for compatibility with models) but these arguments are present here for compatibility with
the base-class constructor. the base-class constructor.
''' '''
self._maybe_force_supported_attention_backend(vllm_config.model_config) self._maybe_force_supported_attention_backend()
super().__init__( super().__init__(
vllm_config=vllm_config, vllm_config=vllm_config,
@ -108,12 +101,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# Crash for unsupported encoder/scenarios # Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario(self) assert_enc_dec_mr_supported_scenario(self)
def _is_xformers_only_encoder_decoder_model(self, def _maybe_force_supported_attention_backend(self):
model: ModelConfig) -> bool:
return get_architecture_class_name(
model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS
def _maybe_force_supported_attention_backend(self, model: ModelConfig):
''' '''
Force vLLM to use the XFormers attention backend, Force vLLM to use the XFormers attention backend,
which is currently the only supported option. which is currently the only supported option.
@ -128,23 +116,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
maybe_global_forced_backend = get_global_forced_attn_backend() maybe_global_forced_backend = get_global_forced_attn_backend()
is_forced_by_global = maybe_global_forced_backend is not None is_forced_by_global = maybe_global_forced_backend is not None
is_forced_by_env_var = maybe_env_var_forced_backend is not None is_forced_by_env_var = maybe_env_var_forced_backend is not None
if is_forced_by_global: # noqa: SIM102
if not (is_forced_by_global or is_forced_by_env_var) \
and self._is_xformers_only_encoder_decoder_model(model):
# The user has not already specified an attention backend
# override
logger.info(
"Encoder-Decoder Model Architecture %s requires XFormers "
"backend; overriding backend auto-selection and "
"forcing XFormers.", get_architecture_class_name(model))
global_force_attn_backend(_Backend.XFORMERS)
elif is_forced_by_global:
# Backend override enforced by global variable takes # Backend override enforced by global variable takes
# precedence over vLLM backend environment variable. # precedence over vLLM backend environment variable.
if maybe_global_forced_backend not in\ if maybe_global_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]: [_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err() raise_backend_err()
elif is_forced_by_env_var: elif is_forced_by_env_var: # noqa: SIM102
# Backend override enforced by vLLM backend # Backend override enforced by vLLM backend
# environment variable # environment variable
if maybe_env_var_forced_backend not in\ if maybe_env_var_forced_backend not in\