Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 21:35:46 +08:00)
[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (#9982)
Signed-off-by: Sourashis Roy <sroy@roblox.com>
Parent: 7c65527918
Commit: b41fb9d3b1
@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple
 import pytest
 from transformers import AutoModelForSeq2SeqLM
 
-from vllm.attention.selector import (_Backend,
+from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                      global_force_attn_backend_context_manager)
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
@@ -34,6 +34,13 @@ def vllm_to_hf_output(
     return output_ids, hf_output_str, out_logprobs
 
 
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Fixture to clear backend cache before each test."""
+    _cached_get_attn_backend.cache_clear()  # Clear the cache
+    yield  # This allows the test to run
+
+
 @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
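Taken together, the test changes above follow one pattern: vLLM memoizes its attention-backend selection, so a test that forces a backend must first clear that cache and then wrap the run in the global-force context manager. A minimal sketch of the pattern, reusing the internal helpers the diff imports (their module path is internal to this vLLM version and may change):

import pytest

from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                     global_force_attn_backend_context_manager)


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear the cached backend choice so each test re-selects it."""
    _cached_get_attn_backend.cache_clear()
    yield


@pytest.mark.parametrize("attn_backend",
                         [_Backend.XFORMERS, _Backend.FLASH_ATTN])
def test_backend_is_forced(attn_backend: _Backend):
    # Inside this context, backend auto-selection is overridden globally.
    with global_force_attn_backend_context_manager(attn_backend):
        ...  # run the model with the forced backend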
@@ -4,6 +4,8 @@ import pytest
 from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
                           BatchEncoding)
 
+from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
+                                     global_force_attn_backend_context_manager)
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
 
@@ -14,6 +16,8 @@ from ...utils import check_logprobs_close
 
 _LIMIT_IMAGE_PER_PROMPT = 3
 
+LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
+
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
     "<|image|><|begin_of_text|>The meaning of the image is",
@@ -221,6 +225,13 @@ def _run_test(
     )
 
 
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Fixture to clear backend cache before each test."""
+    _cached_get_attn_backend.cache_clear()  # Clear the cache
+    yield  # This allows the test to run
+
+
 @large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
@@ -244,20 +255,26 @@ def _run_test(
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
 def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
                                      model, sizes, dtype, max_tokens,
-                                     num_logprobs) -> None:
-    run_test(
-        hf_runner,
-        vllm_runner,
-        image_assets,
-        model,
-        sizes=sizes,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
-    )
+                                     num_logprobs,
+                                     attn_backend: _Backend) -> None:
+    with global_force_attn_backend_context_manager(attn_backend):
+        if attn_backend == _Backend.FLASH_ATTN:
+            # Flash Attention works only with bfloat16 data-type
+            dtype = 'bfloat16'
+        run_test(
+            hf_runner,
+            vllm_runner,
+            image_assets,
+            model,
+            sizes=sizes,
+            dtype=dtype,
+            max_tokens=max_tokens,
+            num_logprobs=num_logprobs,
+            tensor_parallel_size=1,
+        )
 
 
 @large_gpu_test(min_gb=48)
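The dtype coercion above exists because the FlashAttention kernels accept only half-precision inputs (float16/bfloat16), so a float32 request cannot be honored under the FLASH_ATTN backend. A hypothetical helper capturing the same rule (the name _dtype_for_backend is mine, not part of the diff; it assumes dtype is passed as a string, as in these tests):

def _dtype_for_backend(attn_backend: _Backend, requested: str) -> str:
    # FlashAttention kernels support only half precision; fall back to
    # bfloat16 whenever FLASH_ATTN is forced with an unsupported dtype.
    if attn_backend == _Backend.FLASH_ATTN and requested not in ("bfloat16",
                                                                 "half"):
        return "bfloat16"
    return requested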
@@ -265,9 +282,10 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
 def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
-                                     model, dtype, max_tokens,
-                                     num_logprobs) -> None:
+                                     model, dtype, max_tokens, num_logprobs,
+                                     attn_backend: _Backend) -> None:
 
     stop_sign = image_assets[0].pil_image
     cherry_blossom = image_assets[1].pil_image
@@ -291,17 +309,20 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
             cherry_blossom.resize((512, 1024)),
         ],
     ])]
-    _run_test(
-        hf_runner,
-        vllm_runner,
-        inputs,
-        model,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
-    )
+    with global_force_attn_backend_context_manager(attn_backend):
+        if attn_backend == _Backend.FLASH_ATTN:
+            # Flash Attention works only with bfloat16 data-type
+            dtype = 'bfloat16'
+        _run_test(
+            hf_runner,
+            vllm_runner,
+            inputs,
+            model,
+            dtype=dtype,
+            max_tokens=max_tokens,
+            num_logprobs=num_logprobs,
+            tensor_parallel_size=1,
+        )
 
 
 @large_gpu_test(min_gb=48)
@@ -309,8 +330,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
 def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
-                                   dtype, max_tokens, num_logprobs) -> None:
+                                   dtype, max_tokens, num_logprobs,
+                                   attn_backend: _Backend) -> None:
 
     stop_sign = image_assets[0].pil_image
     cherry_blossom = image_assets[1].pil_image
@@ -325,14 +348,17 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
         [stop_sign],
         [stop_sign, cherry_blossom],
     ])]
-    _run_test(
-        hf_runner,
-        vllm_runner,
-        inputs,
-        model,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
-    )
+    with global_force_attn_backend_context_manager(attn_backend):
+        if attn_backend == _Backend.FLASH_ATTN:
+            # Flash Attention works only with bfloat16 data-type
+            dtype = 'bfloat16'
+        _run_test(
+            hf_runner,
+            vllm_runner,
+            inputs,
+            model,
+            dtype=dtype,
+            max_tokens=max_tokens,
+            num_logprobs=num_logprobs,
+            tensor_parallel_size=1,
+        )
@@ -243,6 +243,8 @@ def test_rope_customization():
     assert longchat_model_config.max_model_len == 4096
 
 
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Encoder Decoder models not supported on ROCm.")
 @pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
     ("facebook/opt-125m", False),
     ("facebook/bart-base", True),
@@ -32,6 +32,8 @@ from transformers.models.mllama.processing_mllama import (
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -799,12 +801,13 @@ class MllamaTextCrossAttention(nn.Module):
         q = self.q_norm(q)
 
         if attention_mask is not None:
-            output = self.attention_with_mask(q, k, v, kv_cache,
-                                              attention_mask,
-                                              kv_range_for_decode,
-                                              attn_metadata)
+            output = self._attention_with_mask(q, k, v, kv_cache,
+                                               attention_mask,
+                                               kv_range_for_decode,
+                                               attn_metadata)
         else:
-            output = self.attn(q,
+            output = self.attn(q.view(-1,
+                                      self.num_local_heads * self.head_dim),
                                k,
                                v,
                                kv_cache,
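The q.view(...) change above matters because vLLM's shared Attention layer expects the query flattened to [num_tokens, num_heads * head_dim], while q was reshaped into a per-head layout for q_norm. A small, self-contained shape check (illustrative sizes, not taken from the diff):

import torch

# Illustrative sizes; the real values come from the Mllama config.
num_tokens, num_local_heads, head_dim = 8, 4, 128
q = torch.randn(num_tokens, num_local_heads, head_dim)  # per-head layout used by q_norm
q_flat = q.view(-1, num_local_heads * head_dim)          # flattened layout passed to self.attn
assert q_flat.shape == (num_tokens, num_local_heads * head_dim)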
@@ -813,7 +816,7 @@ class MllamaTextCrossAttention(nn.Module):
         out, _ = self.o_proj(output)
         return out
 
-    def attention_with_mask(
+    def _attention_with_mask(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
@@ -824,14 +827,35 @@ class MllamaTextCrossAttention(nn.Module):
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
-        if len(kv_cache.shape) == 3:
-            key_cache, value_cache = PagedAttention.split_kv_cache(
-                kv_cache, self.num_local_key_value_heads, self.head_dim)
-            cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
-            cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
-            PagedAttention.write_to_paged_cache(
-                cached_k, cached_v, key_cache, value_cache,
-                attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+        if len(kv_cache.shape) > 1:
+            if isinstance(attn_metadata, FlashAttentionMetadata):
+                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                    cached_k,
+                    cached_v,
+                    kv_cache[0],
+                    kv_cache[1],
+                    attn_metadata.
+                    cross_slot_mapping,  # type: ignore[union-attr]
+                    "auto",
+                    1.0,
+                    1.0,
+                )
+            elif isinstance(attn_metadata, XFormersMetadata):
+                key_cache, value_cache = PagedAttention.split_kv_cache(
+                    kv_cache, self.num_local_key_value_heads, self.head_dim)
+                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+                PagedAttention.write_to_paged_cache(
+                    cached_k, cached_v, key_cache, value_cache,
+                    attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+            else:
+                raise ValueError(
+                    f"Unsupported AttentionMetadata {type(attn_metadata)} "
+                    f"class found. Expected the AttentionMetadata to "
+                    f"be either XFormersMetadata or FlashAttentionMetadata.")
+
         # We have to call torch.sdpa for prefill when using a
         # custom cross-attention mask. Because the mask is not a
         # standard causal mask, neither a block diagonal mask which
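The heart of the change is the backend-specific KV-cache write in _attention_with_mask: under FlashAttention the cross-attention cache is a stacked [2, ...] tensor written through the reshape_and_cache_flash custom op, while under XFormers it remains a paged cache that is split and written through PagedAttention. A condensed sketch of just that dispatch, with the call arguments copied from the diff (the free-standing name write_cross_kv is mine; inside the model this logic lives in the method above):

from typing import List, Tuple

import torch

from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.attention.ops.paged_attn import PagedAttention


def write_cross_kv(k: torch.Tensor, v: torch.Tensor, kv_cache: torch.Tensor,
                   kv_range_for_decode: List[Tuple[int, int]],
                   attn_metadata, num_kv_heads: int, head_dim: int) -> None:
    """Sketch of the backend-specific cross-attention KV-cache write."""
    # Only the tokens covered by the cross-attention ranges are cached.
    cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
    cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
    if isinstance(attn_metadata, FlashAttentionMetadata):
        # Flash backend: kv_cache is stacked as [2, ...]; write via custom op.
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            cached_k, cached_v, kv_cache[0], kv_cache[1],
            attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
    elif isinstance(attn_metadata, XFormersMetadata):
        # XFormers backend: split the paged cache and write block-wise.
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, num_kv_heads, head_dim)
        PagedAttention.write_to_paged_cache(
            cached_k, cached_v, key_cache, value_cache,
            attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
    else:
        raise ValueError(
            f"Unsupported AttentionMetadata {type(attn_metadata)}")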
@@ -9,15 +9,13 @@ from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
-                                     get_global_forced_attn_backend,
-                                     global_force_attn_backend)
-from vllm.config import ModelConfig, VllmConfig
+                                     get_global_forced_attn_backend)
+from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.model_loader.utils import get_architecture_class_name
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                              MultiModalRegistry)
 from vllm.sampling_params import SamplingParams
@@ -35,11 +33,6 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
 
 logger = init_logger(__name__)
 
-# The Mllama model has PagedAttention specific logic because of which it
-# can only be run with the XFORMERS backend
-# TODO Make Mllama model work with Flash Attention backend.
-_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]
-
 
 @dataclasses.dataclass(frozen=True)
 class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
@@ -97,7 +90,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         models) but these arguments are present here for compatibility with
         the base-class constructor.
         '''
-        self._maybe_force_supported_attention_backend(vllm_config.model_config)
+        self._maybe_force_supported_attention_backend()
 
         super().__init__(
             vllm_config=vllm_config,
@@ -108,12 +101,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         # Crash for unsupported encoder/scenarios
         assert_enc_dec_mr_supported_scenario(self)
 
-    def _is_xformers_only_encoder_decoder_model(self,
-                                                model: ModelConfig) -> bool:
-        return get_architecture_class_name(
-            model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS
-
-    def _maybe_force_supported_attention_backend(self, model: ModelConfig):
+    def _maybe_force_supported_attention_backend(self):
         '''
         Force vLLM to use the XFormers attention backend,
         which is currently the only supported option.
@@ -128,23 +116,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         maybe_global_forced_backend = get_global_forced_attn_backend()
         is_forced_by_global = maybe_global_forced_backend is not None
         is_forced_by_env_var = maybe_env_var_forced_backend is not None
-
-        if not (is_forced_by_global or is_forced_by_env_var) \
-            and self._is_xformers_only_encoder_decoder_model(model):
-            # The user has not already specified an attention backend
-            # override
-            logger.info(
-                "Encoder-Decoder Model Architecture %s requires XFormers "
-                "backend; overriding backend auto-selection and "
-                "forcing XFormers.", get_architecture_class_name(model))
-            global_force_attn_backend(_Backend.XFORMERS)
-
-        elif is_forced_by_global:
+        if is_forced_by_global:  # noqa: SIM102
             # Backend override enforced by global variable takes
             # precedence over vLLM backend environment variable.
             if maybe_global_forced_backend not in\
                     [_Backend.XFORMERS, _Backend.FLASH_ATTN]:
                 raise_backend_err()
-        elif is_forced_by_env_var:
+        elif is_forced_by_env_var:  # noqa: SIM102
            # Backend override enforced by vLLM backend
            # environment variable
            if maybe_env_var_forced_backend not in\
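With the XFormers-only allow-list gone, the encoder-decoder model runner no longer forces a backend on the user's behalf; it only validates that an explicitly forced backend (global override first, then the environment variable) is one of the two supported options. A condensed sketch of that check (the helper name, constant, and error message below are mine; the runner's real code raises through raise_backend_err()):

from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
                                     get_global_forced_attn_backend)

_SUPPORTED_ENC_DEC_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]


def _validate_forced_backend() -> None:
    """Sketch of the simplified check: only validate, never force."""
    # The global override takes precedence over the environment variable.
    forced = (get_global_forced_attn_backend()
              or get_env_variable_attn_backend())
    if forced is not None and forced not in _SUPPORTED_ENC_DEC_BACKENDS:
        raise NotImplementedError(
            "Encoder/decoder models currently support only the XFormers "
            "and FlashAttention backends.")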