mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:55:51 +08:00
[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (#9982)
Signed-off-by: Sourashis Roy <sroy@roblox.com>
This commit is contained in:
parent
7c65527918
commit
b41fb9d3b1
@ -7,7 +7,7 @@ from typing import List, Optional, Tuple
|
||||
import pytest
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
from vllm.attention.selector import (_Backend,
|
||||
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import SampleLogprobs
|
||||
@ -34,6 +34,13 @@ def vllm_to_hf_output(
|
||||
return output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Fixture to clear backend cache before each test."""
|
||||
_cached_get_attn_backend.cache_clear() # Clear the cache
|
||||
yield # This allows the test to run
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
|
||||
@ -4,6 +4,8 @@ import pytest
|
||||
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
|
||||
BatchEncoding)
|
||||
|
||||
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
@ -14,6 +16,8 @@ from ...utils import check_logprobs_close
|
||||
|
||||
_LIMIT_IMAGE_PER_PROMPT = 3
|
||||
|
||||
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
|
||||
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
"<|image|><|begin_of_text|>The meaning of the image is",
|
||||
@ -221,6 +225,13 @@ def _run_test(
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Fixture to clear backend cache before each test."""
|
||||
_cached_get_attn_backend.cache_clear() # Clear the cache
|
||||
yield # This allows the test to run
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
@ -244,20 +255,26 @@ def _run_test(
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
|
||||
model, sizes, dtype, max_tokens,
|
||||
num_logprobs) -> None:
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model,
|
||||
sizes=sizes,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
num_logprobs,
|
||||
attn_backend: _Backend) -> None:
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
if attn_backend == _Backend.FLASH_ATTN:
|
||||
# Flash Attention works only with bfloat16 data-type
|
||||
dtype = 'bfloat16'
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model,
|
||||
sizes=sizes,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@ -265,9 +282,10 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
|
||||
model, dtype, max_tokens,
|
||||
num_logprobs) -> None:
|
||||
model, dtype, max_tokens, num_logprobs,
|
||||
attn_backend: _Backend) -> None:
|
||||
|
||||
stop_sign = image_assets[0].pil_image
|
||||
cherry_blossom = image_assets[1].pil_image
|
||||
@ -291,17 +309,20 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
|
||||
cherry_blossom.resize((512, 1024)),
|
||||
],
|
||||
])]
|
||||
|
||||
_run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
if attn_backend == _Backend.FLASH_ATTN:
|
||||
# Flash Attention works only with bfloat16 data-type
|
||||
dtype = 'bfloat16'
|
||||
_run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@ -309,8 +330,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
|
||||
dtype, max_tokens, num_logprobs) -> None:
|
||||
dtype, max_tokens, num_logprobs,
|
||||
attn_backend: _Backend) -> None:
|
||||
|
||||
stop_sign = image_assets[0].pil_image
|
||||
cherry_blossom = image_assets[1].pil_image
|
||||
@ -325,14 +348,17 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
|
||||
[stop_sign],
|
||||
[stop_sign, cherry_blossom],
|
||||
])]
|
||||
|
||||
_run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
if attn_backend == _Backend.FLASH_ATTN:
|
||||
# Flash Attention works only with bfloat16 data-type
|
||||
dtype = 'bfloat16'
|
||||
_run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
@ -243,6 +243,8 @@ def test_rope_customization():
|
||||
assert longchat_model_config.max_model_len == 4096
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Encoder Decoder models not supported on ROCm.")
|
||||
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
|
||||
("facebook/opt-125m", False),
|
||||
("facebook/bart-base", True),
|
||||
|
||||
@ -32,6 +32,8 @@ from transformers.models.mllama.processing_mllama import (
|
||||
|
||||
import vllm.distributed.parallel_state as ps
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.attention.backends.xformers import XFormersMetadata
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@ -799,12 +801,13 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
|
||||
if attention_mask is not None:
|
||||
output = self.attention_with_mask(q, k, v, kv_cache,
|
||||
attention_mask,
|
||||
kv_range_for_decode,
|
||||
attn_metadata)
|
||||
output = self._attention_with_mask(q, k, v, kv_cache,
|
||||
attention_mask,
|
||||
kv_range_for_decode,
|
||||
attn_metadata)
|
||||
else:
|
||||
output = self.attn(q,
|
||||
output = self.attn(q.view(-1,
|
||||
self.num_local_heads * self.head_dim),
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
@ -813,7 +816,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
out, _ = self.o_proj(output)
|
||||
return out
|
||||
|
||||
def attention_with_mask(
|
||||
def _attention_with_mask(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
@ -824,14 +827,35 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Skip writing kv-cache for the initial profiling run.
|
||||
if len(kv_cache.shape) == 3:
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_local_key_value_heads, self.head_dim)
|
||||
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
|
||||
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
|
||||
PagedAttention.write_to_paged_cache(
|
||||
cached_k, cached_v, key_cache, value_cache,
|
||||
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
|
||||
if len(kv_cache.shape) > 1:
|
||||
if isinstance(attn_metadata, FlashAttentionMetadata):
|
||||
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
|
||||
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
cached_k,
|
||||
cached_v,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
attn_metadata.
|
||||
cross_slot_mapping, # type: ignore[union-attr]
|
||||
"auto",
|
||||
1.0,
|
||||
1.0,
|
||||
)
|
||||
elif isinstance(attn_metadata, XFormersMetadata):
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_local_key_value_heads, self.head_dim)
|
||||
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
|
||||
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
|
||||
PagedAttention.write_to_paged_cache(
|
||||
cached_k, cached_v, key_cache, value_cache,
|
||||
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported AttentionMetadata {type(attn_metadata)} "
|
||||
f"class found. Expected the AttentionMetadata to "
|
||||
f"be either XFormersMetadata or FlashAttentionMetadata.")
|
||||
|
||||
# We have to call torch.sdpa for prefill when using a
|
||||
# custom cross-attention mask. Because the mask is not a
|
||||
# standard causal mask, neither a block diagonal mask which
|
||||
|
||||
@ -9,15 +9,13 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
|
||||
get_global_forced_attn_backend,
|
||||
global_force_attn_backend)
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
get_global_forced_attn_backend)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.utils import get_architecture_class_name
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||
MultiModalRegistry)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@ -35,11 +33,6 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# The Mllama model has PagedAttention specific logic because of which it
|
||||
# can only be run with the XFORMERS backend
|
||||
# TODO Make Mllama model work with Flash Attention backend.
|
||||
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
|
||||
@ -97,7 +90,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
models) but these arguments are present here for compatibility with
|
||||
the base-class constructor.
|
||||
'''
|
||||
self._maybe_force_supported_attention_backend(vllm_config.model_config)
|
||||
self._maybe_force_supported_attention_backend()
|
||||
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
@ -108,12 +101,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
# Crash for unsupported encoder/scenarios
|
||||
assert_enc_dec_mr_supported_scenario(self)
|
||||
|
||||
def _is_xformers_only_encoder_decoder_model(self,
|
||||
model: ModelConfig) -> bool:
|
||||
return get_architecture_class_name(
|
||||
model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS
|
||||
|
||||
def _maybe_force_supported_attention_backend(self, model: ModelConfig):
|
||||
def _maybe_force_supported_attention_backend(self):
|
||||
'''
|
||||
Force vLLM to use the XFormers attention backend,
|
||||
which is currently the only supported option.
|
||||
@ -128,23 +116,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
maybe_global_forced_backend = get_global_forced_attn_backend()
|
||||
is_forced_by_global = maybe_global_forced_backend is not None
|
||||
is_forced_by_env_var = maybe_env_var_forced_backend is not None
|
||||
|
||||
if not (is_forced_by_global or is_forced_by_env_var) \
|
||||
and self._is_xformers_only_encoder_decoder_model(model):
|
||||
# The user has not already specified an attention backend
|
||||
# override
|
||||
logger.info(
|
||||
"Encoder-Decoder Model Architecture %s requires XFormers "
|
||||
"backend; overriding backend auto-selection and "
|
||||
"forcing XFormers.", get_architecture_class_name(model))
|
||||
global_force_attn_backend(_Backend.XFORMERS)
|
||||
elif is_forced_by_global:
|
||||
if is_forced_by_global: # noqa: SIM102
|
||||
# Backend override enforced by global variable takes
|
||||
# precedence over vLLM backend environment variable.
|
||||
if maybe_global_forced_backend not in\
|
||||
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
|
||||
raise_backend_err()
|
||||
elif is_forced_by_env_var:
|
||||
elif is_forced_by_env_var: # noqa: SIM102
|
||||
# Backend override enforced by vLLM backend
|
||||
# environment variable
|
||||
if maybe_env_var_forced_backend not in\
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user