[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (#9982)

Signed-off-by: Sourashis Roy <sroy@roblox.com>
sroy745 2024-11-12 10:53:57 -08:00 committed by GitHub
parent 7c65527918
commit b41fb9d3b1
5 changed files with 117 additions and 80 deletions

View File

@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.attention.selector import (_Backend,
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
@@ -34,6 +34,13 @@ def vllm_to_hf_output(
return output_ids, hf_output_str, out_logprobs
@pytest.fixture(autouse=True)
def clear_cache():
"""Fixture to clear backend cache before each test."""
_cached_get_attn_backend.cache_clear() # Clear the cache
yield # This allows the test to run
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)

View File

@@ -4,6 +4,8 @@ import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@@ -14,6 +16,8 @@ from ...utils import check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT = 3
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|image|><|begin_of_text|>The meaning of the image is",
@@ -221,6 +225,13 @@ def _run_test(
)
@pytest.fixture(autouse=True)
def clear_cache():
"""Fixture to clear backend cache before each test."""
_cached_get_attn_backend.cache_clear() # Clear the cache
yield # This allows the test to run
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
@@ -244,20 +255,26 @@ def _run_test(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
model, sizes, dtype, max_tokens,
num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
num_logprobs,
attn_backend: _Backend) -> None:
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@large_gpu_test(min_gb=48)
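For reference, the context-manager pattern used in the rewritten test forces a backend only for the duration of the with-block and restores the previous selection on exit; the dtype override exists because, per the added comment, Flash Attention works only with the bfloat16 data type here. A hedged, self-contained sketch of that pattern with illustrative names (the real helper is global_force_attn_backend_context_manager in vllm.attention.selector):

import contextlib

_forced_backend = None

@contextlib.contextmanager
def force_backend_ctx(backend: str):
    # Set a process-wide override on entry, restore the old value on exit.
    global _forced_backend
    previous, _forced_backend = _forced_backend, backend
    try:
        yield
    finally:
        _forced_backend = previous

with force_backend_ctx("FLASH_ATTN"):
    assert _forced_backend == "FLASH_ATTN"   # run_test executes in this scope
assert _forced_backend is None               # auto-selection is restored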
@@ -265,9 +282,10 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens,
num_logprobs) -> None:
model, dtype, max_tokens, num_logprobs,
attn_backend: _Backend) -> None:
stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image
@@ -291,17 +309,20 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
cherry_blossom.resize((512, 1024)),
],
])]
_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@large_gpu_test(min_gb=48)
@@ -309,8 +330,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
dtype, max_tokens, num_logprobs) -> None:
dtype, max_tokens, num_logprobs,
attn_backend: _Backend) -> None:
stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image
@@ -325,14 +348,17 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
[stop_sign],
[stop_sign, cherry_blossom],
])]
_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)

View File

@@ -243,6 +243,8 @@ def test_rope_customization():
assert longchat_model_config.max_model_len == 4096
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Encoder Decoder models not supported on ROCm.")
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
("facebook/opt-125m", False),
("facebook/bart-base", True),

View File

@@ -32,6 +32,8 @@ from transformers.models.mllama.processing_mllama import (
import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -799,12 +801,13 @@ class MllamaTextCrossAttention(nn.Module):
q = self.q_norm(q)
if attention_mask is not None:
output = self.attention_with_mask(q, k, v, kv_cache,
attention_mask,
kv_range_for_decode,
attn_metadata)
output = self._attention_with_mask(q, k, v, kv_cache,
attention_mask,
kv_range_for_decode,
attn_metadata)
else:
output = self.attn(q,
output = self.attn(q.view(-1,
self.num_local_heads * self.head_dim),
k,
v,
kv_cache,
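The new q.view(...) call above flattens the per-head dimension before handing the query to the shared Attention layer; the assumption here is that the layer expects queries shaped [num_tokens, num_heads * head_dim], a layout the FlashAttention backend requires, while the XFormers path previously tolerated the unflattened form. A tiny illustrative shape check:

import torch

num_tokens, num_heads, head_dim = 10, 8, 64
q = torch.randn(num_tokens, num_heads, head_dim)

# Mirrors q.view(-1, self.num_local_heads * self.head_dim) from the diff.
q_flat = q.view(-1, num_heads * head_dim)
assert q_flat.shape == (num_tokens, num_heads * head_dim)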
@@ -813,7 +816,7 @@ class MllamaTextCrossAttention(nn.Module):
out, _ = self.o_proj(output)
return out
def attention_with_mask(
def _attention_with_mask(
self,
q: torch.Tensor,
k: torch.Tensor,
@@ -824,14 +827,35 @@
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) == 3:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
if len(kv_cache.shape) > 1:
if isinstance(attn_metadata, FlashAttentionMetadata):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
torch.ops._C_cache_ops.reshape_and_cache_flash(
cached_k,
cached_v,
kv_cache[0],
kv_cache[1],
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
1.0,
1.0,
)
elif isinstance(attn_metadata, XFormersMetadata):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
else:
raise ValueError(
f"Unsupported AttentionMetadata {type(attn_metadata)} "
f"class found. Expected the AttentionMetadata to "
f"be either XFormersMetadata or FlashAttentionMetadata.")
# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
# standard causal mask, neither a block diagonal mask which
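The isinstance dispatch above is needed because the two backends lay out the KV cache differently, and the guard changed from len(kv_cache.shape) == 3 to > 1 because the dummy cache passed during the profiling run is 1-D while the FlashAttention cache is 5-D, which the old 3-D-only check would have skipped. The shapes below are assumptions about those layouts for illustration, not values copied from this diff:

import torch

num_blocks, block_size, num_kv_heads, head_dim = 4, 16, 8, 64

profiling_kv_cache = torch.tensor([])   # dummy 1-D tensor: skip cache writes
flash_kv_cache = torch.zeros(           # assumed FLASH_ATTN layout: K and V
    2, num_blocks, block_size, num_kv_heads, head_dim)  # stacked on dim 0
xformers_kv_cache = torch.zeros(        # assumed paged layout, later viewed by
    2, num_blocks, block_size * num_kv_heads * head_dim)  # split_kv_cache()

assert len(profiling_kv_cache.shape) == 1   # profiling run: both checks skip it
assert len(flash_kv_cache.shape) > 1        # 5-D: missed by the old `== 3` check
assert len(xformers_kv_cache.shape) > 1     # 3-D: matched by both checks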

View File

@@ -9,15 +9,13 @@ from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend,
global_force_attn_backend)
from vllm.config import ModelConfig, VllmConfig
get_global_forced_attn_backend)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.utils import get_architecture_class_name
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
@@ -35,11 +33,6 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger = init_logger(__name__)
# The Mllama model has PagedAttention specific logic because of which it
# can only be run with the XFORMERS backend
# TODO Make Mllama model work with Flash Attention backend.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
@@ -97,7 +90,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
models) but these arguments are present here for compatibility with
the base-class constructor.
'''
self._maybe_force_supported_attention_backend(vllm_config.model_config)
self._maybe_force_supported_attention_backend()
super().__init__(
vllm_config=vllm_config,
@@ -108,12 +101,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario(self)
def _is_xformers_only_encoder_decoder_model(self,
model: ModelConfig) -> bool:
return get_architecture_class_name(
model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS
def _maybe_force_supported_attention_backend(self, model: ModelConfig):
def _maybe_force_supported_attention_backend(self):
'''
Force vLLM to use the XFormers attention backend,
which is currently the only supported option.
@@ -128,23 +116,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
maybe_global_forced_backend = get_global_forced_attn_backend()
is_forced_by_global = maybe_global_forced_backend is not None
is_forced_by_env_var = maybe_env_var_forced_backend is not None
if not (is_forced_by_global or is_forced_by_env_var) \
and self._is_xformers_only_encoder_decoder_model(model):
# The user has not already specified an attention backend
# override
logger.info(
"Encoder-Decoder Model Architecture %s requires XFormers "
"backend; overriding backend auto-selection and "
"forcing XFormers.", get_architecture_class_name(model))
global_force_attn_backend(_Backend.XFORMERS)
elif is_forced_by_global:
if is_forced_by_global: # noqa: SIM102
# Backend override enforced by global variable takes
# precedence over vLLM backend environment variable.
if maybe_global_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err()
elif is_forced_by_env_var:
elif is_forced_by_env_var: # noqa: SIM102
# Backend override enforced by vLLM backend
# environment variable
if maybe_env_var_forced_backend not in\
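With the Mllama-specific XFormers forcing removed, the runner only validates an override that arrives either via the global force used by the tests or via the VLLM_ATTENTION_BACKEND environment variable (read by get_env_variable_attn_backend), rejecting anything other than XFORMERS or FLASH_ATTN; with no override, normal auto-selection proceeds. A self-contained sketch of that check with illustrative names:

from enum import Enum, auto

class Backend(Enum):   # stand-in for vllm.attention.selector._Backend
    XFORMERS = auto()
    FLASH_ATTN = auto()
    FLASHINFER = auto()

SUPPORTED_ENC_DEC_BACKENDS = (Backend.XFORMERS, Backend.FLASH_ATTN)

def validate_forced_backend(forced):
    # None means no override was set, so auto-selection is left to run.
    if forced is not None and forced not in SUPPORTED_ENC_DEC_BACKENDS:
        raise NotImplementedError(
            f"{forced.name} is not supported for encoder/decoder models; "
            "expected XFORMERS or FLASH_ATTN.")

validate_forced_backend(None)                # ok: auto-selection
validate_forced_backend(Backend.FLASH_ATTN)  # ok: now supported for Mllama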