[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)

This commit is contained in:
sroy745 2024-11-01 23:22:49 -07:00 committed by GitHub
parent d522034c85
commit a78dd3303e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 715 additions and 316 deletions

View File

@ -7,12 +7,18 @@ from typing import List, Optional, Tuple
import pytest import pytest
from transformers import AutoModelForSeq2SeqLM from transformers import AutoModelForSeq2SeqLM
from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from ..conftest import DecoderPromptType from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close from ..models.utils import check_logprobs_close
LIST_ENC_DEC_SUPPORTED_BACKENDS = [
_Backend.XFORMERS, _Backend.FLASH_ATTN, None
]
def vllm_to_hf_output( def vllm_to_hf_output(
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
@ -29,7 +35,8 @@ def vllm_to_hf_output(
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
num_logprobs: int, num_logprobs: int,
decoder_prompt_type: DecoderPromptType, decoder_prompt_type: DecoderPromptType,
enforce_eager: bool, enforce_eager: bool,
attn_backend: _Backend,
) -> None: ) -> None:
''' '''
End-to-End (E2E) test for the encoder-decoder framework. End-to-End (E2E) test for the encoder-decoder framework.
@ -56,43 +64,49 @@ def test_encoder_decoder_e2e(
implementations to ensure that both implementations produce consistent implementations to ensure that both implementations produce consistent
and correct results. and correct results.
''' '''
test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type] with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
test_case_prompts = example_encoder_decoder_prompts[
decoder_prompt_type]
# Configuration settings for HF baseline # Configuration settings for HF baseline
hf_kwargs = { hf_kwargs = {
"top_k": None, "top_k": None,
"num_beams": 1, "num_beams": 1,
"repetition_penalty": 1.0, "repetition_penalty": 1.0,
"top_p": 1.0, "top_p": 1.0,
"length_penalty": 1.0, "length_penalty": 1.0,
"early_stopping": False, "early_stopping": False,
"no_repeat_ngram_size": None, "no_repeat_ngram_size": None,
"min_length": 0 "min_length": 0
} }
with hf_runner(model, dtype=dtype, with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model: auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( hf_outputs = (
test_case_prompts, hf_model.generate_encoder_decoder_greedy_logprobs_limit(
max_tokens, test_case_prompts,
num_logprobs, max_tokens,
**hf_kwargs, num_logprobs,
)) **hf_kwargs,
with vllm_runner(model, dtype=dtype, ))
enforce_eager=enforce_eager) as vllm_model: with vllm_runner(model, dtype=dtype,
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( enforce_eager=enforce_eager) as vllm_model:
test_case_prompts, max_tokens, num_logprobs) vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
test_case_prompts, max_tokens, num_logprobs)
hf_skip_tokens = (1 hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
if decoder_prompt_type == DecoderPromptType.NONE else 0) else 0)
check_logprobs_close( check_logprobs_close(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
outputs_1_lst=[ outputs_1_lst=[
vllm_to_hf_output(vllm_output, decoder_prompt_type) vllm_to_hf_output(vllm_output, decoder_prompt_type)
for vllm_output in vllm_outputs for vllm_output in vllm_outputs
], ],
name_0="hf", name_0="hf",
name_1="vllm", name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens, num_outputs_0_skip_tokens=hf_skip_tokens,
) )

View File

@ -16,13 +16,13 @@ from tests.kernels.utils import *
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
AttentionType) AttentionType)
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, from vllm.attention.selector import (_Backend, get_attn_backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform from vllm.platforms import current_platform
# List of support backends for encoder/decoder models # List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
HEAD_SIZES = [64, 256] HEAD_SIZES = [64, 256]
NUM_HEADS = [1, 16] NUM_HEADS = [1, 16]
@ -145,7 +145,8 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
test_pt.num_heads, test_pt.num_heads,
test_pt.head_size, test_pt.head_size,
test_pt.block_size, test_pt.block_size,
device=CUDA_DEVICE) device=CUDA_DEVICE,
backend=test_pt.backend_name)
return TestResources(scale, attn_backend, attn, kv_cache) return TestResources(scale, attn_backend, attn, kv_cache)
@ -592,6 +593,7 @@ def _run_encoder_attention_test(
attn: Attention, attn: Attention,
encoder_test_params: PhaseTestParameters, encoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run encoder attention. Run encoder attention.
@ -610,6 +612,8 @@ def _run_encoder_attention_test(
(number_of_tokens x num_heads x head_size) (number_of_tokens x num_heads x head_size)
query/key/value fields query/key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention * attn_metadata: attention metadata for encoder/decoder-self attention
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns: Returns:
* Attention.forward() applied to packed {query,key,value} and * Attention.forward() applied to packed {query,key,value} and
@ -619,20 +623,31 @@ def _run_encoder_attention_test(
attn_type = AttentionType.ENCODER attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None assert packed_qkv is not None
return attn.forward(packed_qkv.query, with set_forward_context(attn_metadata):
packed_qkv.key, # In the test setup the shape of the query is
packed_qkv.value, # [batch_size, seq_len, num_heads, head_size]. However
torch.tensor([], # the attention backend expect the shape to be
dtype=torch.float32, # [num_tokens, hidden_size]. Hence reshape the query before
device=packed_qkv.query.device), # invoking the forward method.
attn_metadata, # TODO - Update the way we construct the query so that it
attn_type=attn_type) # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)
def _run_decoder_self_attention_test( def _run_decoder_self_attention_test(
test_rsrcs: TestResources, test_rsrcs: TestResources,
decoder_test_params: PhaseTestParameters, decoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run decoder self-attention test. Run decoder self-attention test.
@ -650,6 +665,8 @@ def _run_decoder_self_attention_test(
query/key/value fields query/key/value fields
* attn_metadata: attention metadata for decoder-self attention * attn_metadata: attention metadata for decoder-self attention
(contains KV cache memory-mapping) (contains KV cache memory-mapping)
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns: Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache * Attention.forward() applied to packed_{query,key,value}, kv_cache
@ -660,12 +677,22 @@ def _run_decoder_self_attention_test(
kv_cache = test_rsrcs.kv_cache kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None assert packed_qkv is not None
return attn.forward(packed_qkv.query, with set_forward_context(attn_metadata):
packed_qkv.key, # In the test setup the shape of the query is
packed_qkv.value, # [batch_size, seq_len, num_heads, head_size]. However
kv_cache, # the attention backend expect the shape to be
attn_metadata, # [num_tokens, hidden_size]. Hence reshape the query before
attn_type=attn_type) # invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)
def _run_encoder_decoder_cross_attention_test( def _run_encoder_decoder_cross_attention_test(
@ -673,6 +700,7 @@ def _run_encoder_decoder_cross_attention_test(
decoder_test_params: PhaseTestParameters, decoder_test_params: PhaseTestParameters,
cross_test_params: Optional[PhaseTestParameters], cross_test_params: Optional[PhaseTestParameters],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run encoder/decoder cross-attention test. Run encoder/decoder cross-attention test.
@ -701,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test(
(number_of_tokens x num_heads x head_size) (number_of_tokens x num_heads x head_size)
key/value fields key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention * attn_metadata: attention metadata for encoder/decoder-self attention
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns: Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache * Attention.forward() applied to packed_{query,key,value}, kv_cache
@ -718,12 +748,37 @@ def _run_encoder_decoder_cross_attention_test(
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, with set_forward_context(attn_metadata):
key, # In the test setup the shape of the query is
value, # [batch_size, seq_len, num_heads, head_size]. However
kv_cache, # the attention backend expect the shape to be
attn_metadata, # [num_tokens, hidden_size]. Hence reshape the query before
attn_type=attn_type) # invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
key,
value,
kv_cache,
attn_metadata,
attn_type=attn_type)
@pytest.fixture(autouse=True)
def set_reset_environment(attn_backend):
# Set the default torch datatype to bfloat16 to enable
# testing of the Flash Attention backend. Also clear the
# cached value of the backend.
default_dtype = torch.get_default_dtype()
if attn_backend.name == 'FLASH_ATTN':
torch.set_default_dtype(torch.bfloat16)
get_attn_backend.cache_clear()
yield
# Reset the torch datatype to what it was before the test
# so as not to impact the remaining tests.
torch.set_default_dtype(default_dtype)
@pytest.mark.skipif(current_platform.is_rocm(), @pytest.mark.skipif(current_platform.is_rocm(),
@ -773,10 +828,8 @@ def test_encoder_only(
* max_dec_seq_len: max length of decoder input sequences * max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences * max_enc_seq_len: max length of encoder input sequences
''' '''
# Force Attention wrapper backend # Force Attention wrapper backend
with global_force_attn_backend_context_manager(attn_backend): with global_force_attn_backend_context_manager(attn_backend):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally # Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size # to be more than necessary, since exceeding the kv cache size
# is not part of this test # is not part of this test
@ -807,10 +860,14 @@ def test_encoder_only(
# PREFILL: encoder attention # PREFILL: encoder attention
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
test_rsrcs.attn, enc_test_params, prephase_attn_metadata)) test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata,
test_pt=test_pt))
# - Is encoder attention result correct? # - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
attn_backend.name)
@pytest.mark.skipif(current_platform.is_rocm(), @pytest.mark.skipif(current_platform.is_rocm(),
@ -892,10 +949,8 @@ def test_e2e_enc_dec_attn(
* max_dec_seq_len: max length of decoder input sequences * max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences * max_enc_seq_len: max length of encoder input sequences
''' '''
# Force Attention wrapper backend # Force Attention wrapper backend
with global_force_attn_backend_context_manager(attn_backend): with global_force_attn_backend_context_manager(attn_backend):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally # Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size # to be more than necessary, since exceeding the kv cache size
# is not part of this test # is not part of this test
@ -955,29 +1010,39 @@ def test_e2e_enc_dec_attn(
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_test_params, enc_test_params,
prephase_attn_metadata) prephase_attn_metadata,
test_pt=test_pt)
# - Is encoder attention result correct? # - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
attn_backend.name)
# PREFILL: decoder self-attention test # PREFILL: decoder self-attention test
prephase_dec_pckd_act_out = _run_decoder_self_attention_test( prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata) test_rsrcs,
prephase_dec_test_params,
prephase_attn_metadata,
test_pt=test_pt)
# - Is prefill decoder self-attention correct? # - Is prefill decoder self-attention correct?
assert_actual_matches_ideal(prephase_dec_test_params, assert_actual_matches_ideal(prephase_dec_test_params,
prephase_dec_pckd_act_out) prephase_dec_pckd_act_out,
attn_backend.name)
# PREFILL: encoder/decoder cross-attention test # PREFILL: encoder/decoder cross-attention test
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, test_rsrcs,
prephase_attn_metadata) prephase_dec_test_params,
prephase_cross_test_params,
prephase_attn_metadata,
test_pt=test_pt)
# - Is prefill encoder/decoder cross-attention correct? # - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal(prephase_cross_test_params, assert_actual_matches_ideal(prephase_cross_test_params,
prephase_cross_pckd_act_out) prephase_cross_pckd_act_out,
attn_backend.name)
# DECODE: build decode-phase attention metadata # DECODE: build decode-phase attention metadata
@ -993,17 +1058,26 @@ def test_e2e_enc_dec_attn(
# DECODE: decoder self-attention test # DECODE: decoder self-attention test
decphase_dec_pckd_act_out = _run_decoder_self_attention_test( decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) test_rsrcs,
decphase_dec_test_params,
decphase_attn_metadata,
test_pt=test_pt)
# - Is decode-phase decoder self-attention correct? # - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal(decphase_dec_test_params, assert_actual_matches_ideal(decphase_dec_test_params,
decphase_dec_pckd_act_out) decphase_dec_pckd_act_out,
attn_backend.name)
# DECODE: encoder/decoder cross-attention test # DECODE: encoder/decoder cross-attention test
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) test_rsrcs,
decphase_dec_test_params,
None,
decphase_attn_metadata,
test_pt=test_pt)
# - Is decode-phase encoder/decoder cross-attention correct? # - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal(decphase_cross_test_params, assert_actual_matches_ideal(decphase_cross_test_params,
decphase_cross_pckd_act_out) decphase_cross_pckd_act_out,
attn_backend.name)

View File

@ -13,8 +13,8 @@ from torch._prims_common import TensorLikeType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
make_tensor_with_pad) STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
# For now, disable "test_aot_dispatch_dynamic" since there are some # For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4. # bugs related to this test in PyTorch 2.4.
@ -525,17 +525,22 @@ def make_backend(backend_name: str) -> AttentionBackend:
if backend_name == STR_XFORMERS_ATTN_VAL: if backend_name == STR_XFORMERS_ATTN_VAL:
# NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
from vllm.attention.backends.xformers import XFormersBackend from vllm.attention.backends.xformers import XFormersBackend
return XFormersBackend() return XFormersBackend()
elif backend_name == STR_FLASH_ATTN_VAL:
from vllm.attention.backends.flash_attn import FlashAttentionBackend
return FlashAttentionBackend()
raise AssertionError( raise AssertionError(
f"Unrecognized backend_name {backend_name} for unit test") f"Unrecognized backend_name {backend_name} for unit test")
def _make_metadata_tensors( def _make_metadata_tensors(
seq_lens: Optional[List[int]], context_lens: Optional[List[int]], seq_lens: Optional[List[int]],
encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str] context_lens: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]], encoder_seq_lens: Optional[List[int]],
torch.Tensor, Optional[int]]: device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
torch.Tensor, torch.Tensor, Optional[int]]:
''' '''
Build scalar & tensor values required to build attention metadata structure. Build scalar & tensor values required to build attention metadata structure.
@ -553,6 +558,8 @@ def _make_metadata_tensors(
* max_context_len: max(context_lens) * max_context_len: max(context_lens)
* max_seq_len: max(seq_lens) * max_seq_len: max(seq_lens)
* seq_start_loc: start idx of each sequence * seq_start_loc: start idx of each sequence
* encoder_seq_lens_tensor: encoder seq_lens list, as tensor
* encoder_seq_start_loc: start idx of each encoder sequence
* max_encoder_seq_len: encoder seq_lens list, as tensor * max_encoder_seq_len: encoder seq_lens list, as tensor
''' '''
seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
@ -566,8 +573,26 @@ def _make_metadata_tensors(
seq_start_loc = None seq_start_loc = None
if seq_lens_tensor is not None:
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=seq_lens_tensor.device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=encoder_seq_lens_tensor.device)
torch.cumsum(encoder_seq_lens_tensor,
dim=0,
dtype=encoder_seq_start_loc.dtype,
out=encoder_seq_start_loc[1:])
return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len) seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc,
max_encoder_seq_len)
def make_kv_cache(num_blocks: int, def make_kv_cache(num_blocks: int,
@ -575,6 +600,7 @@ def make_kv_cache(num_blocks: int,
head_size: int, head_size: int,
block_size: int, block_size: int,
device: Union[torch.device, str], device: Union[torch.device, str],
backend: str,
default_val: float = 0.0) -> torch.Tensor: default_val: float = 0.0) -> torch.Tensor:
''' '''
Create a fake KV cache. Create a fake KV cache.
@ -591,10 +617,20 @@ def make_kv_cache(num_blocks: int,
Returns: Returns:
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
* for backend 'XFORMERS'
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
* for backend 'FLASH_ATTN'
''' '''
if backend == 'XFORMERS':
kv_cache = torch.rand( kv_cache = torch.rand(
(2, num_blocks, block_size * num_heads * head_size)).to(device) (2, num_blocks, block_size * num_heads * head_size)).to(device)
elif backend == 'FLASH_ATTN':
kv_cache = torch.rand(
(2, num_blocks, block_size, num_heads, head_size)).to(device)
else:
raise ValueError(
f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
f"'FLASH_ATTN'.")
if default_val is not None: if default_val is not None:
kv_cache[:, :, :] = default_val kv_cache[:, :, :] = default_val
return kv_cache return kv_cache
@ -858,8 +894,9 @@ def make_test_metadata(
context_lens_tensor, context_lens_tensor,
_, _,
_, _,
_, seq_start_loc,
encoder_seq_lens_tensor, encoder_seq_lens_tensor,
encoder_seq_start_loc,
max_encoder_seq_len, max_encoder_seq_len,
) = _make_metadata_tensors(seq_lens, ) = _make_metadata_tensors(seq_lens,
context_lens, context_lens,
@ -874,6 +911,7 @@ def make_test_metadata(
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
seq_start_loc=seq_start_loc,
max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
max_decode_seq_len=0, max_decode_seq_len=0,
context_lens_tensor=context_lens_tensor, context_lens_tensor=context_lens_tensor,
@ -882,6 +920,7 @@ def make_test_metadata(
num_encoder_tokens=num_encoder_tokens, num_encoder_tokens=num_encoder_tokens,
encoder_seq_lens=encoder_seq_lens, encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_tensor=encoder_seq_lens_tensor, encoder_seq_lens_tensor=encoder_seq_lens_tensor,
encoder_seq_start_loc=encoder_seq_start_loc,
max_encoder_seq_len=max_encoder_seq_len, max_encoder_seq_len=max_encoder_seq_len,
cross_slot_mapping=(None if cross_kv_mmap is None else cross_slot_mapping=(None if cross_kv_mmap is None else
cross_kv_mmap.slot_mapping), cross_kv_mmap.slot_mapping),
@ -904,8 +943,9 @@ def make_test_metadata(
context_lens_tensor, context_lens_tensor,
_, _,
_, _,
_, seq_start_loc,
encoder_seq_lens_tensor, encoder_seq_lens_tensor,
encoder_seq_start_loc,
max_encoder_seq_len, max_encoder_seq_len,
) = _make_metadata_tensors(seq_lens, ) = _make_metadata_tensors(seq_lens,
context_lens, context_lens,
@ -920,14 +960,17 @@ def make_test_metadata(
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
seq_start_loc=seq_start_loc,
max_prefill_seq_len=0, max_prefill_seq_len=0,
max_decode_seq_len=max(seq_lens), max_decode_seq_len=max(seq_lens),
max_decode_query_len=1,
context_lens_tensor=context_lens_tensor, context_lens_tensor=context_lens_tensor,
block_tables=kv_mmap.block_tables, block_tables=kv_mmap.block_tables,
use_cuda_graph=False, use_cuda_graph=False,
num_encoder_tokens=num_encoder_tokens, num_encoder_tokens=num_encoder_tokens,
encoder_seq_lens=encoder_seq_lens, encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_tensor=encoder_seq_lens_tensor, encoder_seq_lens_tensor=encoder_seq_lens_tensor,
encoder_seq_start_loc=encoder_seq_start_loc,
max_encoder_seq_len=max_encoder_seq_len, max_encoder_seq_len=max_encoder_seq_len,
cross_slot_mapping=(None if cross_kv_mmap is None else cross_slot_mapping=(None if cross_kv_mmap is None else
cross_kv_mmap.slot_mapping), cross_kv_mmap.slot_mapping),
@ -936,7 +979,8 @@ def make_test_metadata(
def assert_actual_matches_ideal(test_params: PhaseTestParameters, def assert_actual_matches_ideal(test_params: PhaseTestParameters,
output_under_test: torch.Tensor) -> None: output_under_test: torch.Tensor,
backend: str) -> None:
''' '''
Assert that observed output matches the ideal output Assert that observed output matches the ideal output
contained in the test parameters data structure. contained in the test parameters data structure.
@ -947,8 +991,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
* output_under_test: actually observed output value * output_under_test: actually observed output value
''' '''
ideal_output = test_params.packed_qkvo.ideal_output ideal_output = test_params.packed_qkvo.ideal_output
torch.testing.assert_close(ideal_output, if backend == 'XFORMERS':
output_under_test.view_as(ideal_output)) torch.testing.assert_close(ideal_output,
output_under_test.view_as(ideal_output))
elif backend == 'FLASH_ATTN':
# For FlashAttention override the accuracy thresholds to non default
# values since we notice a higher difference between the ideal and
# actual output.
torch.testing.assert_close(ideal_output,
output_under_test.view_as(ideal_output),
atol=0.01,
rtol=0.016)
else:
raise ValueError(
f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
f"'FLASH_ATTN'.")
# Copied/modified from torch._refs.__init__.py # Copied/modified from torch._refs.__init__.py

View File

@ -85,7 +85,7 @@ def run_test(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens, def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,

View File

@ -10,10 +10,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType) AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, from vllm.attention.backends.utils import (
compute_slot_mapping, PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
is_block_tables_empty) get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import (async_tensor_h2d, direct_register_custom_op, from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
@ -73,7 +74,6 @@ class FlashAttentionBackend(AttentionBackend):
src_key_cache = src_kv_cache[0] src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0] dst_key_cache = dst_kv_cache[0]
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1] src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1] dst_value_cache = dst_kv_cache[1]
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@ -85,6 +85,7 @@ class FlashAttentionBackend(AttentionBackend):
) -> None: ) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches] key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists) ops.copy_blocks(key_caches, value_caches, src_to_dists)
@ -111,26 +112,12 @@ class FlashAttentionMetadata(AttentionMetadata):
# |-------------------- seq_len ---------------------| # |-------------------- seq_len ---------------------|
# |-- query_len ---| # |-- query_len ---|
# Maximum query length in the batch.
max_query_len: Optional[int]
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding # Maximum sequence length among prefill batch. 0 if there are decoding
# requests only. # requests only.
max_prefill_seq_len: int max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill # Maximum sequence length among decode batch. 0 if there are prefill
# requests only. # requests only.
max_decode_seq_len: int max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed # (batch_size,) A tensor of context lengths (tokens that are computed
# so far). # so far).
context_lens_tensor: Optional[torch.Tensor] context_lens_tensor: Optional[torch.Tensor]
@ -146,11 +133,62 @@ class FlashAttentionMetadata(AttentionMetadata):
# Whether or not if cuda graph is enabled. # Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only. # Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool use_cuda_graph: bool
# Maximum query length in the batch.
max_query_len: Optional[int] = None
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return is_all_encoder_attn_metadata_set(self)
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return is_all_cross_attn_metadata_set(self)
@property @property
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
if self.num_prefills == 0: if self.num_prefills == 0:
@ -159,32 +197,52 @@ class FlashAttentionMetadata(AttentionMetadata):
if self._cached_prefill_metadata is not None: if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata return self._cached_prefill_metadata
assert self.seq_lens is not None assert ((self.seq_lens is not None)
assert self.seq_lens_tensor is not None or (self.encoder_seq_lens is not None))
assert self.query_start_loc is not None assert ((self.seq_lens_tensor is not None)
assert self.context_lens_tensor is not None or (self.encoder_seq_lens_tensor is not None))
assert self.block_tables is not None
assert self.seq_start_loc is not None # Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
self._cached_prefill_metadata = FlashAttentionMetadata( self._cached_prefill_metadata = FlashAttentionMetadata(
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens], slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps, multi_modal_placeholder_index_maps,
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=seq_lens,
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len, max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_query_len=0, max_decode_query_len=0,
max_decode_seq_len=0, max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1], query_start_loc=query_start_loc,
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], seq_start_loc=seq_start_loc,
context_lens_tensor=self.context_lens_tensor[:self.num_prefills], context_lens_tensor=context_lens_tensor,
block_tables=self.block_tables[:self.num_prefills], block_tables=block_tables,
use_cuda_graph=False, use_cuda_graph=False,
) # Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
encoder_seq_start_loc=self.encoder_seq_start_loc,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata return self._cached_prefill_metadata
@property @property
@ -194,17 +252,25 @@ class FlashAttentionMetadata(AttentionMetadata):
if self._cached_decode_metadata is not None: if self._cached_decode_metadata is not None:
return self._cached_decode_metadata return self._cached_decode_metadata
assert self.block_tables is not None assert ((self.seq_lens_tensor is not None)
assert self.seq_lens_tensor is not None or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
self._cached_decode_metadata = FlashAttentionMetadata( self._cached_decode_metadata = FlashAttentionMetadata(
num_prefills=0, num_prefills=0,
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:], slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len, max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
max_prefill_seq_len=0, max_prefill_seq_len=0,
@ -214,9 +280,15 @@ class FlashAttentionMetadata(AttentionMetadata):
seq_start_loc=self.seq_start_loc[self.num_prefills:] seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None, if self.seq_start_loc is not None else None,
context_lens_tensor=None, context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:], block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph, use_cuda_graph=self.use_cuda_graph,
) # Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
encoder_seq_start_loc=self.encoder_seq_start_loc,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_decode_metadata return self._cached_decode_metadata
def advance_step(self, def advance_step(self,
@ -586,16 +658,20 @@ class FlashAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
# NOTE(woosuk): FlashAttention does not support FP8 KV cache. # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, ( assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.") "key/v_scale is not supported in FlashAttention.")
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
output = torch.ops.vllm.unified_flash_attention( output = torch.ops.vllm.unified_flash_attention(
query, query,
key, key,
@ -608,6 +684,7 @@ class FlashAttentionImpl(AttentionImpl):
k_scale, k_scale,
v_scale, v_scale,
self.scale, self.scale,
attn_type.value,
self.sliding_window, self.sliding_window,
self.alibi_slopes, self.alibi_slopes,
self.logits_soft_cap, self.logits_soft_cap,
@ -616,6 +693,89 @@ class FlashAttentionImpl(AttentionImpl):
return output return output
def _get_query_key_seq_metadata(
attn_metadata,
is_prompt: bool,
attn_type: AttentionType,
) -> tuple:
"""
Returns sequence metadata for key and query based on the specified
attention type and whether input is a prompt.
This function computes the starting locations and maximum sequence lengths
for key and query sequences for different attention types.
Args:
attn_metadata: The attention metadata object
is_prompt (bool): A flag indicating if the input is a prompt
attn_type (AttentionType): The type of attention being used.
Returns:
tuple: A tuple containing four integers:
- Starting location for the query sequence.
- Maximum sequence length for the query sequence.
- Starting location for the key sequence.
- Maximum sequence length for the key sequence.
Raises:
AttributeError: If an invalid attention type is provided.
"""
if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_start_loc, max_seq_len,
attn_metadata.seq_start_loc, max_seq_len)
elif attn_type == AttentionType.ENCODER_DECODER:
# This is cross attention between the where the key
# is the precomputed encoder attention and query
# is the input sequence.
# Choose query max length based on whether it is prompt
# or not.
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_start_loc, max_seq_len,
attn_metadata.encoder_seq_start_loc,
attn_metadata.max_encoder_seq_len)
elif attn_type == AttentionType.ENCODER:
# For encoder attention both the query and the key are same i.e the
# encoder sequence.
return (attn_metadata.encoder_seq_start_loc,
attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_start_loc,
attn_metadata.max_encoder_seq_len)
elif attn_type == AttentionType.ENCODER_ONLY:
assert is_prompt, "Should not have decode for encoder only model."
return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len,
attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_causal_option(attn_type: AttentionType) -> bool:
"""
Determine whether the given attention type is suitable for causal
attention mechanisms.
Args:
attn_type (AttentionType): The type of attention being evaluated
Returns:
bool: Returns `True` if the attention type is suitable for causal
attention (i.e., not encoder, encoder-only, or encoder-decoder),
otherwise returns `False`.
"""
return not (attn_type == AttentionType.ENCODER
or attn_type == AttentionType.ENCODER_ONLY
or attn_type == AttentionType.ENCODER_DECODER)
def unified_flash_attention( def unified_flash_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
@ -628,60 +788,76 @@ def unified_flash_attention(
k_scale: float, k_scale: float,
v_scale: float, v_scale: float,
softmax_scale: float, softmax_scale: float,
attn_type_int_val: int,
window_size: Optional[List[int]] = None, window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# Convert integer attn_type to enum
try:
attn_type = AttentionType(attn_type_int_val)
except ValueError as err:
raise AttributeError(
f"Invalid attention type {str(attn_type_int_val)}") from err
current_metadata = get_forward_context() current_metadata = get_forward_context()
assert current_metadata is not None assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata) assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata attn_metadata: FlashAttentionMetadata = current_metadata
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size) query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size) if (key is not None) and (value is not None):
value = value.view(-1, num_kv_heads, head_size) key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if kv_cache.numel() > 0: if kv_cache.numel() > 0:
key_cache = kv_cache[0] key_cache = kv_cache[0]
value_cache = kv_cache[1] value_cache = kv_cache[1]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the KV
# cache is already populated with the cross-attention tensor.
# Thus, we skip cache updates during this time.
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are # If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run. # not cached. This happens during the initial memory profiling run.
torch.ops._C_cache_ops.reshape_and_cache_flash( torch.ops._C_cache_ops.reshape_and_cache_flash(
key, key,
value, value,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
attn_metadata.slot_mapping.flatten(), updated_slot_mapping.flatten(), # type: ignore[union-attr]
kv_cache_dtype, kv_cache_dtype,
k_scale, k_scale,
v_scale, v_scale,
) )
num_prefill_tokens = attn_metadata.num_prefill_tokens (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_query_tokens) = \
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa decode_query = query[num_prefill_query_tokens:]
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill. # QKV for prefill.
query = query[:num_prefill_tokens] query = query[:num_prefill_query_tokens]
key = key[:num_prefill_tokens] assert query.shape[0] == num_prefill_query_tokens
value = value[:num_prefill_tokens] assert decode_query.shape[0] == num_decode_query_tokens
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
prefill_output: Optional[torch.Tensor] = None prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
@ -689,22 +865,30 @@ def unified_flash_attention(
# normal attention # normal attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]
prefill_output = flash_attn_varlen_func( prefill_output = flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_q=q_seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc, cu_seqlens_k=k_seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len, max_seqlen_q=q_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len, max_seqlen_k=k_seq_len,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=True, causal=_get_causal_option(attn_type),
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
softcap=logits_soft_cap, softcap=logits_soft_cap,
) )
else: else:
# prefix-enabled attention # prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens) max_seq_len = max(prefill_meta.seq_lens)
prefill_output = flash_attn_varlen_func( # noqa prefill_output = flash_attn_varlen_func( # noqa
@ -729,6 +913,8 @@ def unified_flash_attention(
# because different queries might have different lengths. # because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None assert decode_meta.max_decode_query_len is not None
if decode_meta.max_decode_query_len > 1: if decode_meta.max_decode_query_len > 1:
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1")
decode_output = flash_attn_varlen_func( decode_output = flash_attn_varlen_func(
q=decode_query, q=decode_query,
k=key_cache, k=key_cache,
@ -746,12 +932,17 @@ def unified_flash_attention(
) )
else: else:
# Use flash_attn_with_kvcache for normal decoding. # Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg,
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
decode_output = flash_attn_with_kvcache( decode_output = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1), q=decode_query.unsqueeze(1),
k_cache=key_cache, k_cache=key_cache,
v_cache=value_cache, v_cache=value_cache,
block_table=decode_meta.block_tables, block_table=block_tables_arg,
cache_seqlens=decode_meta.seq_lens_tensor, cache_seqlens=seq_lens_arg,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=True, causal=True,
window_size=window_size, window_size=window_size,
@ -761,10 +952,10 @@ def unified_flash_attention(
if prefill_output is None: if prefill_output is None:
assert decode_output is not None assert decode_output is not None
return decode_output.view(num_decode_tokens, hidden_size) return decode_output.view(num_decode_query_tokens, hidden_size)
if decode_output is None: if decode_output is None:
assert prefill_output is not None assert prefill_output is not None
return prefill_output.view(num_prefill_tokens, hidden_size) return prefill_output.view(num_prefill_query_tokens, hidden_size)
# Chunked prefill does not work with speculative decoding. # Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill. # Therefore, the query length for decode should be 1 in chunked prefill.
@ -786,6 +977,7 @@ def unified_flash_attention_fake(
k_scale: float, k_scale: float,
v_scale: float, v_scale: float,
softmax_scale: float, softmax_scale: float,
attn_type_int_val: int,
window_size: Optional[List[int]] = None, window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,

View File

@ -1,13 +1,14 @@
"""Attention backend utils""" """Attention backend utils"""
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
import numpy as np import numpy as np
import torch import torch
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState) AttentionState)
from vllm.attention.backends.abstract import AttentionType
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@ -336,11 +337,13 @@ class CommonAttentionState(AttentionState):
use_cuda_graph=True, use_cuda_graph=True,
) )
if is_encoder_decoder_model: if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend. # The encoder decoder model works only with XFormers and
# Assert the same. # Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() == "XFORMERS", \ assert self.runner.attn_backend.get_name() in\
f"Expected attn_backend name to be 'XFORMERS', but "\ ["XFORMERS", "FLASH_ATTN"], \
f" got '{self.runner.attn_backend.get_name()}'" f"Expected attn_backend name to be either 'XFORMERS' or " \
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model( self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata) batch_size=batch_size, attn_metadata=attn_metadata)
@ -356,11 +359,13 @@ class CommonAttentionState(AttentionState):
"block_tables": attn_metadata.decode_metadata.block_tables, "block_tables": attn_metadata.decode_metadata.block_tables,
} }
if is_encoder_decoder_model: if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend. # The encoder decoder model works only with XFormers and
# Assert the same. # Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() == "XFORMERS", \ assert self.runner.attn_backend.get_name() in\
f"Expected attn_backend name to be 'XFORMERS', but "\ ["XFORMERS", "FLASH_ATTN"], \
f" got '{self.runner.attn_backend.get_name()}'" f"Expected attn_backend name to be either 'XFORMERS' or "\
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._add_additonal_input_buffers_for_enc_dec_model( self._add_additonal_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers) attn_metadata=attn_metadata, input_buffers=input_buffers)
return input_buffers return input_buffers
@ -375,11 +380,13 @@ class CommonAttentionState(AttentionState):
input_buffers["block_tables"].copy_( input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True) attn_metadata.decode_metadata.block_tables, non_blocking=True)
if is_encoder_decoder_model: if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend. # The encoder decoder model works only with XFormers and
# Assert the same. # Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() == "XFORMERS", \ assert self.runner.attn_backend.get_name() in\
f"Expected attn_backend name to be 'XFORMERS', but "\ ["XFORMERS", "FLASH_ATTN"], \
f" got '{self.runner.attn_backend.get_name()}'" f"Expected attn_backend name to be either 'XFORMERS' or "\
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._prepare_input_buffers_for_enc_dec_model( self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers) attn_metadata, input_buffers)
@ -411,6 +418,7 @@ class CommonAttentionState(AttentionState):
attn_metadata.encoder_seq_lens_tensor = torch.full( attn_metadata.encoder_seq_lens_tensor = torch.full(
(batch_size, ), 1, dtype=torch.int).cuda() (batch_size, ), 1, dtype=torch.int).cuda()
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
attn_metadata.num_encoder_tokens = 0
def _add_additonal_input_buffers_for_enc_dec_model( def _add_additonal_input_buffers_for_enc_dec_model(
self, attn_metadata, input_buffers: Dict[str, Any]): self, attn_metadata, input_buffers: Dict[str, Any]):
@ -453,3 +461,122 @@ class CommonAttentionState(AttentionState):
input_buffers["cross_block_tables"].copy_( input_buffers["cross_block_tables"].copy_(
attn_metadata.decode_metadata.cross_block_tables, attn_metadata.decode_metadata.cross_block_tables,
non_blocking=True) non_blocking=True)
def is_all_encoder_attn_metadata_set(attn_metadata):
'''
All attention metadata required for encoder attention is set.
'''
return ((attn_metadata.encoder_seq_lens is not None)
and (attn_metadata.encoder_seq_lens_tensor is not None)
and (attn_metadata.max_encoder_seq_len is not None))
def is_all_cross_attn_metadata_set(attn_metadata):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (attn_metadata.is_all_encoder_attn_metadata_set
and (attn_metadata.cross_slot_mapping is not None)
and (attn_metadata.cross_block_tables is not None))
def get_seq_len_block_table_args(
attn_metadata,
is_prompt: bool,
attn_type: AttentionType,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len,
attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_num_prefill_decode_query_kv_tokens(
attn_metadata,
attn_type: AttentionType,
) -> Tuple[int, int, int]:
"""
Calculate the number of prefill and decode tokens for query, key/value
based on the attention metadata and the specified attention type.
Args:
attn_metadata (FlashAttentionMetadata): Attention Metadata object.
attn_type (AttentionType): The type of attention being used.
Returns:
Tuple[int, int, int]: A tuple containing three integers:
- The number of prefill query tokens.
- The number of prefill key/value tokens.
- The number of decode query tokens.
Raises:
AssertionError: If the number of encoder tokens in `attn_metadata`
is `None` when required for the calculations.
"""
num_prefill_query_tokens = 0
num_decode_query_tokens = 0
num_prefill_kv_tokens = 0
if attn_type == AttentionType.ENCODER:
# Encoder attention is only invoked during prefill phase.
# The same input servers a both query and key.
assert attn_metadata.num_encoder_tokens is not None
num_prefill_query_tokens = attn_metadata.num_encoder_tokens
num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
num_decode_query_tokens = 0
elif attn_type == AttentionType.ENCODER_DECODER:
assert attn_metadata.num_encoder_tokens is not None
num_prefill_query_tokens = attn_metadata.num_prefill_tokens
# The key is the encoder/cross-attention.
num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
num_decode_query_tokens = attn_metadata.num_decode_tokens
else: # attn_type == AttentionType.DECODER or
# attn_type == AttentionType.ENCODER_ONLY
num_prefill_query_tokens = attn_metadata.num_prefill_tokens
num_prefill_kv_tokens = attn_metadata.num_prefill_tokens
num_decode_query_tokens = attn_metadata.num_decode_tokens
return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens)

View File

@ -11,8 +11,10 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType) AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState, from vllm.attention.backends.utils import (
CommonMetadataBuilder) CommonAttentionState, CommonMetadataBuilder,
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
@ -135,6 +137,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Encoder sequence lengths representation # Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences # Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None max_encoder_seq_len: Optional[int] = None
@ -162,9 +169,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
''' '''
All attention metadata required for encoder attention is set. All attention metadata required for encoder attention is set.
''' '''
return ((self.encoder_seq_lens is not None) return is_all_encoder_attn_metadata_set(self)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
@property @property
def is_all_cross_attn_metadata_set(self): def is_all_cross_attn_metadata_set(self):
@ -173,9 +178,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
Superset of encoder attention required metadata. Superset of encoder attention required metadata.
''' '''
return (self.is_all_encoder_attn_metadata_set return is_all_cross_attn_metadata_set(self)
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
@property @property
def prefill_metadata(self) -> Optional["XFormersMetadata"]: def prefill_metadata(self) -> Optional["XFormersMetadata"]:
@ -329,64 +332,6 @@ def _set_attn_bias(
raise AttributeError(f"Invalid attention type {str(attn_type)}") raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_seq_len_block_table_args(
attn_metadata: XFormersMetadata,
is_prompt: bool,
attn_type: AttentionType,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len,
attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
elif attn_type == AttentionType.ENCODER_ONLY:
assert is_prompt, "Should not have decode for encoder only model."
# No block tables associated with encoder attention
return (attn_metadata.seq_lens_tensor,
attn_metadata.max_prefill_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
_metadata_cls = XFormersMetadata _metadata_cls = XFormersMetadata
@ -574,45 +519,21 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
updated_slot_mapping, updated_slot_mapping,
self.kv_cache_dtype, self.kv_cache_dtype,
k_scale, v_scale) k_scale, v_scale)
(num_prefill_query_tokens, num_prefill_kv_tokens,
if attn_type == AttentionType.ENCODER: num_decode_query_tokens) = \
# Encoder attention - chunked prefill is not applicable; get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_encoder_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
elif attn_type == AttentionType.DECODER:
# Decoder self-attention supports chunked prefill.
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_encoder_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
else: # attn_type == AttentionType.ENCODER_DECODER
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
if attn_metadata.num_encoder_tokens is not None:
num_encoder_tokens = attn_metadata.num_encoder_tokens
else:
num_encoder_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
output = torch.empty_like(query) output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached. # Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:] decode_query = query[num_prefill_query_tokens:]
# QKV for prefill. # QKV for prefill.
query = query[:num_prefill_tokens] query = query[:num_prefill_query_tokens]
if key is not None and value is not None: if key is not None and value is not None:
key = key[:num_encoder_tokens] key = key[:num_prefill_kv_tokens]
value = value[:num_encoder_tokens] value = value[:num_prefill_kv_tokens]
assert query.shape[0] == num_prefill_tokens assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_tokens assert decode_query.shape[0] == num_decode_query_tokens
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
@ -622,8 +543,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# prefix. # prefix.
out = self._run_memory_efficient_xformers_forward( out = self._run_memory_efficient_xformers_forward(
query, key, value, prefill_meta, attn_type=attn_type) query, key, value, prefill_meta, attn_type=attn_type)
assert out.shape == output[:num_prefill_tokens].shape assert out.shape == output[:num_prefill_query_tokens].shape
output[:num_prefill_tokens] = out output[:num_prefill_query_tokens] = out
else: else:
assert attn_type != AttentionType.ENCODER_ONLY, ( assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have prefix attention.") "Encoder-only models should not have prefix attention.")
@ -652,8 +573,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
k_scale, k_scale,
v_scale, v_scale,
) )
assert output[:num_prefill_tokens].shape == out.shape assert output[:num_prefill_query_tokens].shape == out.shape
output[:num_prefill_tokens] = out output[:num_prefill_query_tokens] = out
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, ( assert attn_type != AttentionType.ENCODER_ONLY, (
@ -663,9 +584,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
seq_lens_arg, seq_lens_arg,
max_seq_len_arg, max_seq_len_arg,
block_tables_arg, block_tables_arg,
) = _get_seq_len_block_table_args(decode_meta, False, attn_type) ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
output[num_prefill_tokens:] = PagedAttention.forward_decode( output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
decode_query, decode_query,
key_cache, key_cache,
value_cache, value_cache,

View File

@ -98,7 +98,6 @@ def get_attn_backend(
is_blocksparse: bool = False, is_blocksparse: bool = False,
) -> Type[AttentionBackend]: ) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
if is_blocksparse: if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.") logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import ( from vllm.attention.backends.blocksparse_attn import (
@ -108,6 +107,7 @@ def get_attn_backend(
backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size, backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
is_attention_free) is_attention_free)
if backend == _Backend.FLASH_ATTN: if backend == _Backend.FLASH_ATTN:
logger.info("Using Flash Attention backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401 from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend) FlashAttentionBackend)
return FlashAttentionBackend return FlashAttentionBackend

View File

@ -624,8 +624,6 @@ class BartEncoder(nn.Module):
Decoder output torch.Tensor Decoder output torch.Tensor
""" """
# retrieve input_ids and inputs_embeds # retrieve input_ids and inputs_embeds
input_ids = input_ids.view(-1, input_ids.shape[-1])
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions( embed_pos = self.embed_positions(

View File

@ -80,8 +80,8 @@ STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
"currently supported with encoder/" "currently supported with encoder/"
"decoder models.") "decoder models.")
STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend " STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only "
"currently supported with encoder/" "backends currently supported with encoder/"
"decoder models.") "decoder models.")
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not " STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "

View File

@ -19,6 +19,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.utils import get_architecture_class_name
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
MultiModalRegistry) MultiModalRegistry)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
@ -36,6 +37,11 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger = init_logger(__name__) logger = init_logger(__name__)
# The Mllama model has PagedAttention specific logic because of which it
# can only be run with the XFORMERS backend
# TODO Make Mllama model work with Flash Attention backend.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
@ -101,9 +107,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
models) but these arguments are present here for compatibility with models) but these arguments are present here for compatibility with
the base-class constructor. the base-class constructor.
''' '''
self._maybe_force_supported_attention_backend(model_config)
self._maybe_force_supported_attention_backend()
super().__init__( super().__init__(
model_config, model_config,
parallel_config, parallel_config,
@ -119,7 +123,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# Crash for unsupported encoder/scenarios # Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario(self) assert_enc_dec_mr_supported_scenario(self)
def _maybe_force_supported_attention_backend(self): def _is_xformers_only_encoder_decoder_model(self,
model: ModelConfig) -> bool:
return get_architecture_class_name(
model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS
def _maybe_force_supported_attention_backend(self, model: ModelConfig):
''' '''
Force vLLM to use the XFormers attention backend, Force vLLM to use the XFormers attention backend,
which is currently the only supported option. which is currently the only supported option.
@ -135,22 +144,26 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
is_forced_by_global = maybe_global_forced_backend is not None is_forced_by_global = maybe_global_forced_backend is not None
is_forced_by_env_var = maybe_env_var_forced_backend is not None is_forced_by_env_var = maybe_env_var_forced_backend is not None
if not (is_forced_by_global or is_forced_by_env_var): if not (is_forced_by_global or is_forced_by_env_var) \
and self._is_xformers_only_encoder_decoder_model(model):
# The user has not already specified an attention backend # The user has not already specified an attention backend
# override # override
logger.info("EncoderDecoderModelRunner requires " logger.info(
"XFormers backend; overriding backend " "Encoder-Decoder Model Architecture %s requires XFormers "
"auto-selection and forcing XFormers.") "backend; overriding backend auto-selection and "
"forcing XFormers.", get_architecture_class_name(model))
global_force_attn_backend(_Backend.XFORMERS) global_force_attn_backend(_Backend.XFORMERS)
elif is_forced_by_global: elif is_forced_by_global:
# Backend override enforced by global variable takes # Backend override enforced by global variable takes
# precedence over vLLM backend environment variable. # precedence over vLLM backend environment variable.
if maybe_global_forced_backend != _Backend.XFORMERS: if maybe_global_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err() raise_backend_err()
elif is_forced_by_env_var: elif is_forced_by_env_var:
# Backend override enforced by vLLM backend # Backend override enforced by vLLM backend
# environment variable # environment variable
if maybe_env_var_forced_backend != _Backend.XFORMERS: if maybe_env_var_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err() raise_backend_err()
def _list_to_int32_tensor( def _list_to_int32_tensor(
@ -532,6 +545,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
attn_metadata.encoder_seq_lens, attn_metadata.encoder_seq_lens,
attn_metadata.encoder_seq_lens_tensor, attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_start_loc,
attn_metadata.cross_slot_mapping, attn_metadata.cross_slot_mapping,
attn_metadata.cross_block_tables, attn_metadata.cross_block_tables,
) = ( ) = (
@ -539,6 +553,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
encoder_seq_lens, encoder_seq_lens,
encoder_seq_lens_tensor, encoder_seq_lens_tensor,
max_encoder_seq_len, max_encoder_seq_len,
encoder_seq_start_loc,
cross_slot_mapping_tensor, cross_slot_mapping_tensor,
cross_block_tables, cross_block_tables,
) )