[Kernel] Correctly invoke prefill & decode kernels for cross-attention (towards eventual encoder/decoder model support) (#4888)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
afeldman-nm 2024-07-08 13:12:15 -04:00 committed by GitHub
parent f7a8fa39d8
commit 543aa48573
14 changed files with 2351 additions and 95 deletions


@@ -47,32 +47,32 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
- assert backend.name != "FLASH_ATTN"
+ assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported data type
backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
- assert backend.name != "FLASH_ATTN"
+ assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
- assert backend.name != "FLASH_ATTN"
+ assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported block size
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
- assert backend.name != "FLASH_ATTN"
+ assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported sliding window
backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
- assert backend.name != "FLASH_ATTN"
+ assert backend.name != STR_FLASH_ATTN_VAL
# flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
- assert backend.name != "FLASH_ATTN"
+ assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported head size
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
- assert backend.name != "FLASH_ATTN"
+ assert backend.name != STR_FLASH_ATTN_VAL
def test_invalid_env(monkeypatch):


@@ -0,0 +1,953 @@
"""
Tests:
* E2E test of Encoder attention + Decoder self-attention +
Encoder/decoder cross-attention (collectively
"encoder/decoder attention")
* Confirm enc/dec models will fail for chunked prefill
* Confirm enc/dec models will fail for prefix caching
"""
from typing import NamedTuple, Optional
import pytest
import torch
from tests.kernels.utils import *
from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.utils import is_hip
HEAD_SIZES = [64, 256]
NUM_HEADS = [1, 16]
BATCH_SIZES = [1, 16]
BLOCK_SIZES = [16]
BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
CUDA_DEVICE = "cuda:0"
MAX_DEC_SEQ_LENS = [128]
MAX_ENC_SEQ_LENS = [128]
# Narrow test-cases for unsupported-scenario
# tests
HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]]
class TestPoint(NamedTuple):
"""
Encapsulates the attributes which define a single invocation
of the test_e2e_enc_dec_attn() test
Attributes:
num_heads: The number of heads in the model.
head_size: Head dimension
backend_name: Name of the backend framework used.
batch_size: Number of samples per batch.
block_size: Number of offsets per KV cache block.
max_dec_seq_len: Maximum sequence length for the decoder.
max_enc_seq_len: Maximum sequence length for the encoder.
num_blocks: Number of KV cache blocks (no KV cache if None.)
"""
num_heads: int
head_size: int
backend_name: str
batch_size: int
block_size: int
max_dec_seq_len: int
max_enc_seq_len: int
num_blocks: int
class TestResources(NamedTuple):
'''
Encapsulates key components for performing an
encoder/decoder attention test
Note that
(1) attn automatically selects an attention backend
based on platform info & a set of canned
heuristics
(2) attn_backend is thus *not the same backend
instance* used by attn, but rather it is
intended to be a
*different instance* of the *same backend class*;
it is assumed that the user of TestResources
will leverage attn_backend for the purpose of
constructing backend-compatible attention
metadata instances
Attributes:
* scale: 1/sqrt(d) scale factor for attn
* attn_backend: implementation of the abstract attention interface
                using a particular kernel library, e.g. XFormers
* attn: Attention layer instance
* kv_cache: shared key/value cache for all attention
'''
scale: float
attn_backend: AttentionBackend
attn: Attention
kv_cache: torch.Tensor
def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
'''
Build key components for performing encoder/decoder attention test.
Note that
(1) The Attention instance constructed here, automatically selects
an attention backend class based on platform info & a set of canned
heuristics, so
(2) The attention backend instance constructed here is thus *not
the same backend instance* used by attn, but rather it is
intended to be a *different instance* of the *same backend class*;
therefore,
(3) This function requires that test_pt.backend_name matches the backend
class that Attention will automatically select when it is constructed.
Arguments:
* test_pt: TestPoint data structure; this function relies on the
following fields: num_heads, head_size, num_blocks,
block_size, backend_name
Returns:
* TestResources data structure.
'''
scale = float(1.0 / (test_pt.head_size**0.5))
attn_backend = make_backend(test_pt.backend_name)
attn = Attention(
test_pt.num_heads,
test_pt.head_size,
scale=scale,
)
if test_pt.num_blocks is None or test_pt.num_heads is None:
# Caller does not require a KV cache
return TestResources(scale, attn_backend, attn, None)
# Construct KV cache
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE)
return TestResources(scale, attn_backend, attn, kv_cache)
def _encoder_attn_setup(
test_pt: TestPoint,
test_rsrcs: TestResources,
) -> PhaseTestParameters:
'''
Set up test vectors & data structures for encoder attention test.
A triplet of synthetic query/key/value tensors are constructed.
Given this is an encoder attention test, the key & value
sequences will have the same length as the corresponding queries.
The query/key/value tensors are passed to an ideal reference
self-attention implementation to generate an ideal output tensor.
Encoder inference does not populate the KV cache, therefore
no KV cache memory mapping is constructed
Arguments:
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
Returns:
* PhaseTestParameters data structure comprising (1) packed query/key/value
tensors, (2) the ideal output of attention computed using a naive
implementation, and (3) KVCache field set to None
'''
(
num_heads,
head_size,
_,
batch_size,
_,
_,
max_q_seq_len,
_,
) = test_pt
scale = test_rsrcs.scale
max_kv_seq_len = max_q_seq_len
# Make test tensors
qkv_in, _, _ = make_qkv(batch_size,
max_q_seq_len,
max_kv_seq_len,
num_heads,
head_size,
attn_type=AttentionType.ENCODER,
device=CUDA_DEVICE)
# Compute correct answer using naive non-causal attention
# implementation
ideal_output = ref_masked_attention(qkv_in.query,
qkv_in.key,
qkv_in.value,
scale=scale,
q_seq_lens=qkv_in.q_seq_lens,
kv_seq_lens=qkv_in.kv_seq_lens)
packed_ideal_output, _ = pack_tensor(ideal_output,
qkv_in.q_seq_lens,
device=CUDA_DEVICE)
packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE)
return PhaseTestParameters(
PackedQKVO(packed_qkv, packed_ideal_output),
None # No KV cache
)
def _decoder_attn_setup(
test_pt: TestPoint,
test_rsrcs: TestResources,
block_base_addr: int = 0,
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
'''
Set up test vectors & data structures for self-attention test.
A triplet of synthetic query/key/value tensors are constructed ("baseline"
query/key/value). Given this is a self-attention test, the key & value
sequences will have the same length as the corresponding queries.
"Prefill" query/key/value tensors are derived by masking out the last value
in each baseline query/key/value. These tensors are used to test prefill &
populate KV cache for a subsequent decode test.
"Decode" query/key/value tensors are derived by extracting *only* the last
value from each baseline query/key/value (i.e. complement of the prefill
tensors.) These tensors are used to test decode, conditional on the kv cache
being populated during the prefill test.
The baseline query/key/value tensors are passed to an ideal reference
self-attention implementation to generate a "Baseline" ideal output tensor.
This tensor is split into the "Prefill" ideal output tensor (all but the
last element of each output sequence) and the "Decode" ideal output tensor
(*only* the last element of each output sequence); the "Prefill" and
"Decode" ideal output tensors can be used to validate the prefill and decode
test results, respectively.
This function also constructs the self-attention KV cache memory mapping
(slot mapping and block table), ensuring that the block table starts at
block_base_addr
Arguments:
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
* block_base_addr: decoder self-attention block-table base address
Returns:
* qkv: Unpacked (batch_size x padded_seq_len x num_heads x
head_size) query/key/value tensors
* Prefill-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data
structures appropriate for prefill phase.
* Decode-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data
structures appropriate for decode phase.
* max_block_idx: max physical address in decoder self-attention block-table
(intended to be used as the base address for the encoder/
decoder cross-attention block-table, which is not
constructed in this function)
'''
(
num_heads,
head_size,
_,
batch_size,
block_size,
max_q_seq_len,
_,
_,
) = test_pt
scale = test_rsrcs.scale
max_kv_seq_len = max_q_seq_len
# Build test tensors
(
qkv,
prefill_qkv,
decode_qkv,
) = make_qkv(batch_size,
max_q_seq_len,
max_kv_seq_len,
num_heads,
head_size,
attn_type=AttentionType.DECODER,
device=CUDA_DEVICE)
# Compute correct answer using naive attention implementation
# with causal attention mask
causal_mask = make_causal_mask(max_q_seq_len,
max_kv_seq_len).to(CUDA_DEVICE)
ideal_output = ref_masked_attention(qkv.query,
qkv.key,
qkv.value,
scale=scale,
custom_mask=causal_mask,
q_seq_lens=qkv.q_seq_lens,
kv_seq_lens=qkv.kv_seq_lens)
# Split out the prefill- & decode-phase ideal answers & pack them
prefill_ideal_output = torch.zeros_like(ideal_output)
decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens):
prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
bdx, :prefill_q_seq_len]
decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
prefill_q_seq_len + 1)]
prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
prefill_qkv.q_seq_lens,
device=CUDA_DEVICE)
decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
[1 for _ in range(batch_size)],
device=CUDA_DEVICE)
# Build prefill- & decode-phase data structures
# for decoder self-attention. Block tables and
# slot mapping must be in a format compatible
# with KV caching & attention kernels
#
# Prefill-phase:
#
# * Empty block-tables tensor
# * Slot-mapping with entries for prompt tokens
#
# Decode-phase:
# * Block-tables tensor with minimum number of blocks
# required by total num. tokens in the entirety of all sequences
# (including both prefill & decode)
# * Slot-mapping with entries for tokens that will be decoded in the
# current decode iteration
#
# Note: the format described above is simply mirroring what ModelRunner
# produces
prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
(
decode_block_tables,
slot_mapping_list,
max_block_idx,
) = make_block_tables_slot_mapping(block_size,
qkv.q_seq_lens,
device=CUDA_DEVICE,
block_base_addr=block_base_addr)
(
prefill_slot_mapping,
decode_slot_mapping,
) = split_slot_mapping(slot_mapping_list,
qkv.q_seq_lens,
device=CUDA_DEVICE)
prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE)
decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE)
return (
qkv,
PhaseTestParameters( # Prefill test params
PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output),
KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
PhaseTestParameters( # Decode test params
PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output),
KVMemoryMap(decode_block_tables, decode_slot_mapping)),
max_block_idx)
def _enc_dec_cross_attn_setup_reuses_query(
decoder_qkv: QKVInputs,
encoder_test_params: PhaseTestParameters,
prefill_decoder_phase_test_params: PhaseTestParameters,
test_pt: TestPoint,
test_rsrcs: TestResources,
block_base_addr: int = 0,
) -> Tuple[PhaseTestParameters, PhaseTestParameters]:
'''
Set up test vectors & data structures for cross-attention test.
A triplet of synthetic cross-attention key/value tensors are constructed
("baseline" key/value). Given this is a cross-attention test, we assume
query tensors were already synthesized for a prior self-attention test and
will be reused for cross-attention. The key & value sequences generated here
may have a different length than the corresponding queries (as is often
the case for cross-attention between decoder and encoder sequences.)
Cross attention key & value tensors do not grow during autoregressive
inference; thus this function obtains a single key/value pair suitable for
both prefill and decode.
The "baseline" query tensor is received as an argument. The "baseline"
query/key/value tensors are passed to an ideal reference cross-attention
implementation to generate a "baseline" ideal output tensor. This tensor is
split into the "Prefill" ideal output tensor (all but the last element of
each output sequence) and the "Decode" ideal output tensor (*only* the last
element of each output sequence); the "Prefill" and "Decode" ideal output
tensors can be used to validate the prefill and decode test results,
respectively.
This function also constructs the cross-attention KV cache memory mapping
(slot mapping and block table), ensuring that the block table starts at
block_base_addr.
Arguments:
* decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
num_heads x head_size) decoder self-attention inputs;
this function relies on the query and q_seq_lens
fields
* encoder_test_params: PhaseTestParameters data structure which was
used for encoder inference; KV cache field
is not used by this function
* prefill_decoder_phase_test_params: PhaseTestParameters data structure
used for prefill-phase decoder
self-attention; all fields
including KV cache required
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
* block_base_addr: decoder self-attention block-table base address
Returns:
* Prefill-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate
for prefill phase.
* Decode-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate
for decode phase.
'''
assert encoder_test_params.packed_qkvo.packed_qkv is not None
assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None
(
num_heads,
head_size,
_,
batch_size,
block_size,
max_decoder_seq_len,
max_encoder_seq_len,
_,
) = test_pt
scale = test_rsrcs.scale
decoder_query = decoder_qkv.query
decoder_seq_lens = decoder_qkv.q_seq_lens
encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
prefill_q_seq_lens = (
prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens)
assert prefill_q_seq_lens is not None
(
cross_kv,
_,
_,
) = make_qkv(batch_size,
max_decoder_seq_len,
max_encoder_seq_len,
num_heads,
head_size,
force_kv_seq_lens=encoder_seq_lens,
attn_type=AttentionType.ENCODER_DECODER,
device=CUDA_DEVICE)
ideal_output = ref_masked_attention(decoder_query,
cross_kv.key,
cross_kv.value,
scale=scale,
q_seq_lens=decoder_seq_lens,
kv_seq_lens=cross_kv.kv_seq_lens)
prefill_ideal_output = torch.zeros_like(ideal_output)
decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens):
prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
bdx, :prefill_q_seq_len]
decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
prefill_q_seq_len + 1)]
prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
prefill_q_seq_lens,
device=CUDA_DEVICE)
decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
[1 for _ in range(batch_size)],
device=CUDA_DEVICE)
# Build prefill- & decode-phase data structures
# for encoder/decoder cross-attention. Block tables and
# slot mapping must be in a format compatible
# with KV caching & attention kernels
#
# Whereas decoder self-attention extracts relationships between
# equal-length Q/K/V sequences, which mutually grow in length
# with each decoded token, cross-attention relates the Q sequence
# - which grows with each new decoded token - to fixed-length
# K and V sequences derived from the encoder hidden states.
#
# Prefill-phase:
#
# * Empty block-tables tensor
# * Slot-mapping with as many entries as there are tokens in the encoder
# prompt.
#
# Decode-phase:
# * Block-tables tensor with minimum number of blocks to
# accommodate K & V tensors which are equal in length
# to the encoder prompt length
# * Empty slot-mapping tensor (since K & V are fixed in size,
# new decoded tokens are not KV-cached and require no slot-
# mapping)
#
# Note: the format above is simply an extension of what ModelRunner
# produces for decoder-only models
prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE)
(
decode_block_tables,
prefill_slot_mapping_list,
_,
) = make_block_tables_slot_mapping(block_size,
cross_kv.kv_seq_lens,
block_base_addr=block_base_addr,
device=CUDA_DEVICE)
prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list,
device=CUDA_DEVICE)
# Packed key/value (query is already provided)
packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE)
return (
PhaseTestParameters( # Prefill-phase test params
PackedQKVO(packed_cross_kv, prefill_packed_ideal_output),
KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
PhaseTestParameters( # Decode-phase test params
PackedQKVO(None, decode_packed_ideal_output),
KVMemoryMap(decode_block_tables, decode_slot_mapping)))
def _run_encoder_attention_test(
attn: Attention,
encoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
Run encoder attention.
attn.forward() is passed attn_type=AttentionType.ENCODER in order
to configure the kernel invocation for encoder attention
Requires attn_metadata.num_decode_tokens == 0
(There is no encoder execution in the decode-phase)
Arguments:
* attn: Attention wrapper instance
* encoder_test_params: encoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
query/key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
Returns:
* Attention.forward() applied to packed {query,key,value} and
attn_metadata
'''
assert attn_metadata.num_decode_tokens == 0
attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
return attn.forward(packed_qkv.query,
packed_qkv.key,
packed_qkv.value,
None,
attn_metadata,
attn_type=attn_type)
def _run_decoder_self_attention_test(
test_rsrcs: TestResources,
decoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
Run decoder self-attention test.
attn.forward() is passed attn_type=AttentionType.DECODER
in order to configure the kernel invocation for decoder self-attention.
Arguments:
* test_rsrcs: TestResources instance; this function relies on the kv_cache
and attn (Attention wrapper instance) fields
* decoder_test_params: decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
query/key/value fields
* attn_metadata: attention metadata for decoder-self attention
(contains KV cache memory-mapping)
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
& attn_metadata
'''
attn_type = AttentionType.DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
return attn.forward(packed_qkv.query,
packed_qkv.key,
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)
def _run_encoder_decoder_cross_attention_test(
test_rsrcs: TestResources,
decoder_test_params: PhaseTestParameters,
cross_test_params: Optional[PhaseTestParameters],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
Run encoder/decoder cross-attention test.
Via PhaseTestParameters data structures, consumes the same query utilized
for decoder self-attention, plus a key/value specific to cross-attention.
if cross_test_params is None or cross_test_params.packed_qkvo.packed_qkv
is None, this reflects that in decode-phase cross attention there
is no growth in the key and value tensors.
attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER
in order to configure the kernel invocation for encoder/decoder cross-
attention.
Arguments:
* test_rsrcs: TestResources instance; this function relies on the kv_cache
and attn (Attention wrapper instance) fields
* decoder_test_params: decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
query field
* cross_test_params: encoder/decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
& attn_metadata
'''
assert decoder_test_params.packed_qkvo.packed_qkv is not None
attn_type = AttentionType.ENCODER_DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
if cross_test_params is None:
key = None
value = None
else:
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query,
key,
value,
kv_cache,
attn_metadata,
attn_type=attn_type)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
batch_size: int, block_size: int, max_dec_seq_len: int,
max_enc_seq_len: int, monkeypatch):
# Force Attention wrapper backend
override_backend_env_variable(monkeypatch, backend_name)
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used
# during prefill)
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
# Shared prefill metadata structure
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
True,
None,
decoder_test_params=None,
encoder_test_params=enc_test_params,
cross_test_params=None,
device=CUDA_DEVICE)
# PREFILL: encoder attention
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_e2e_enc_dec_attn(
num_heads: int,
head_size: int,
backend_name: str,
batch_size: int,
block_size: int,
max_dec_seq_len: int,
max_enc_seq_len: int,
monkeypatch,
) -> None:
'''
End-to-end encoder/decoder test:
* Construct fake test vectors for (1) encoder attention,
(2) decoder self-attention, and (3) encoder/decoder cross-attention
* Construct (1) attention metadata structure with self- and cross-attention
attributes for prefill-phase, and (2) an analogous attention metadata
structure but for decode-phase
* Test attention steps in the following order
* Encoder attention
* Prefill self-attention
* Prefill cross-attention
* Decode self-attention
* Decode cross-attention
* Besides being reflective of realistic use-cases, this order would
exacerbate any accidental overlap in the self-/cross-attention
block tables, which one hopes to avoid
* Validate output correctness against ideal reference attention
implementation
Block tables are constructed such that cross-attention KV cache is in a
higher, non-intersecting address-space than self-attention KV cache.
Self- and cross-attention share the same query tensor but not the K/V
tensors. Self-attention K/Vs must have the same seq len as Q while
cross-attention K/Vs are allowed to differ in seq len, as is often the case
for cross-attention.
This test utilizes PyTest monkey patching to force the attention backend
via an environment variable.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
Note on metadata: there is a single attention metadata structure shared by
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
and a single one shared by all decode-phase attention operations
(decoder & enc/dec cross.) This is intended to reflect the behavior
of ModelRunner, which constructs a single attention metadata structure for
each prefill or decode run. A realistic scenario would rely on the
attention backend to utilize the appropriate attention metadata fields
according to the value of attn_metadata.attention_type. Thus, this test is
organized so as to confirm that the backend-under-test can handle a
shared prefill attention metadata structure & a shared decode attention
metadata structure.
'''
# Force Attention wrapper backend
override_backend_env_variable(monkeypatch, backend_name)
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used
# during prefill)
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
# Construct Decoder self-attention prefill-phase & decode-phase
# test params, including query/key/value tensors, decoder self-attention
# memory-mapping. cross_block_base_addr is the uppermost address in the
# decoder self-attention block-table, i.e. a base address which the
# encoder/decoder cross-attention block-table may build downward toward.
(
dec_qkv,
prephase_dec_test_params,
decphase_dec_test_params,
cross_block_base_addr,
) = _decoder_attn_setup(test_pt, test_rsrcs)
# Construct encoder/decoder cross-attention prefill-phase & decode-phase
# test params, including key/value tensors, cross-attention memory-mapping
(
prephase_cross_test_params,
decphase_cross_test_params,
) = _enc_dec_cross_attn_setup_reuses_query(
dec_qkv,
enc_test_params,
prephase_dec_test_params,
test_pt,
test_rsrcs,
block_base_addr=cross_block_base_addr)
# Shared prefill metadata structure
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
True,
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
decoder_test_params=prephase_dec_test_params,
encoder_test_params=enc_test_params,
cross_test_params=prephase_cross_test_params,
device=CUDA_DEVICE)
# PREFILL: encoder attention
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata)
# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
# PREFILL: decoder self-attention test
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
# - Is prefill decoder self-attention correct?
assert_actual_matches_ideal(prephase_dec_test_params,
prephase_dec_pckd_act_out)
# PREFILL: encoder/decoder cross-attention test
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
prephase_attn_metadata)
# - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal(prephase_cross_test_params,
prephase_cross_pckd_act_out)
# DECODE: build decode-phase attention metadata
decphase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
False,
dec_qkv.q_seq_lens,
decoder_test_params=decphase_dec_test_params,
encoder_test_params=enc_test_params,
cross_test_params=decphase_cross_test_params,
device=CUDA_DEVICE)
# DECODE: decoder self-attention test
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
# - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal(decphase_dec_test_params,
decphase_dec_pckd_act_out)
# DECODE: encoder/decoder cross-attention test
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
# - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal(decphase_cross_test_params,
decphase_cross_pckd_act_out)


@@ -1,12 +1,211 @@
"""Kernel test utils"""
import itertools
import random
from numbers import Number
from typing import Any, List, NamedTuple, Optional, Tuple, Union
import pytest
import torch
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.xformers import XFormersBackend
from vllm.utils import make_tensor_with_pad
# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
class QKVInputs(NamedTuple):
'''
Data structure for representing unpacked attention inputs,
query/key/values and their sequence lengths.
Attributes:
* {query,key,value}: unpacked (batch_size x padded_seq_len x
num_heads x head_size) attention inputs
* q_seq_lens: query sequence lengths list
* kv_seq_lens: shared key/value sequence lengths list
'''
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
q_seq_lens: List[int]
kv_seq_lens: List[int]
class QKVO(NamedTuple):
'''
Data structure for representing unpacked attention inputs,
alongside unpacked known-correct attention output
Attributes:
* qkv: unpacked (batch_size x padded_seq_len x
num_heads x head_size) attention inputs
* ideal_output: unpacked (batch_size x padded_seq_len x
num_heads x head_size) known-correct attention output
'''
qkv: QKVInputs
ideal_output: torch.Tensor
class PackedQKVInputs(NamedTuple):
'''
Data structure for representing packed attention inputs
Attributes:
* {query,key,value}: packed (number_of_tokens x num_heads
x head_size) attention inputs
* q_start_loc_list: list of query start locations within packed tensor
* kv_start_loc_list: shared list of key/value start locations within
packed tensor
* q_seq_lens: query sequence lengths list
* kv_seq_lens: shared key/value sequence lengths list
'''
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
q_start_loc_list: Optional[List[int]]
kv_start_loc_list: Optional[List[int]]
q_seq_lens: Optional[List[int]]
kv_seq_lens: Optional[List[int]]
class PackedQKVO(NamedTuple):
'''
Data structure for representing packed attention inputs,
alongside packed known-correct attention output
Attributes:
* packed_qkv: packed (number_of_tokens x num_heads
x head_size) attention inputs
* ideal_output: packed (number_of_tokens x num_heads
x head_size) known-correct attention output
'''
packed_qkv: Optional[PackedQKVInputs]
ideal_output: torch.Tensor
class KVMemoryMap(NamedTuple):
'''
Data structure for encapsulating KV cache memory mapping.
Attributes:
* block_tables: KV cache block tables
* slot_mapping: mapping of sequence offset to physical address
'''
block_tables: torch.Tensor
slot_mapping: torch.Tensor
class PhaseTestParameters(NamedTuple):
'''
Data structure for encapsulating the test parameters
for a given test "phase" (prefill or decode phase) and attention
scenario (encoder, decoder-self, encoder/decoder-cross)
Attributes:
* packed_qkvo: packed (number_of_tokens x num_heads
x head_size) attention inputs & known-correct
output
* kv_mmap: KV cache memory mapping, specific to this test phase &
attention scenario
'''
packed_qkvo: PackedQKVO
kv_mmap: Optional[KVMemoryMap]
def maybe_make_int_tensor(
_list: Optional[List[int]],
device: Union[torch.device, str],
) -> torch.Tensor:
'''
Convert Python int list to a 1D int torch.Tensor on `device`
Returns:
* If _list is not None: 1D int torch.Tensor on `device`
* None otherwise
'''
return None if _list is None else torch.tensor(
_list, dtype=torch.int, device=device)
def maybe_make_long_tensor(
_list: Optional[List[int]],
device: Union[torch.device, str],
) -> torch.Tensor:
'''
Convert Python int list to a 1D long torch.Tensor on `device`
Returns:
* If _list is not None: 1D long torch.Tensor on `device`
* None otherwise
'''
return None if _list is None else torch.tensor(
_list, dtype=torch.long, device=device)
def maybe_max(_list: Optional[List]) -> Optional[Number]:
'''
Returns:
* If _list is not None: max(_list)
* None otherwise
'''
return None if _list is None else max(_list)
def make_causal_mask(
q_max_seq_len: int,
kv_max_seq_len: int,
) -> torch.Tensor:
'''
Create a q_max_seq_len x kv_max_seq_len causal mask
Arguments:
* q_max_seq_len: query max seq len
* kv_max_seq_len: key/value max seq len
Returns:
* 2D tensor, q_max_seq_len x kv_max_seq_len
'''
# Create a matrix where entry (i, j) is True if i >= j
mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1)
# Replace True with float('-inf') and False with 0
mask = mask.masked_fill(mask == 1,
float('-inf')).masked_fill(mask == 0, 0.0)
return mask
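For orientation, a small illustrative example (not part of this commit) of the mask this helper produces: entries strictly above the diagonal become -inf, so adding the mask to attention logits prevents each query position from attending to later key positions.
mask = make_causal_mask(3, 4)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf]])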
def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
'''
@@ -20,3 +219,724 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
* backend_name: attention backend name to force
'''
mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)
def ref_masked_attention(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
custom_mask: Optional[torch.Tensor] = None,
q_seq_lens: Optional[List] = None,
kv_seq_lens: Optional[List] = None) -> torch.Tensor:
'''
"Golden" masked attention reference. Supports two types of masking:
* Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
padding elements
* Custom attention mask, which can force an arbitrary mask tensor, i.e.
causal
Arguments:
* query: batch_size x q_padded_seq_len x num_heads x head_size
* key: batch_size x kv_padded_seq_len x num_heads x head_size
* value: batch_size x kv_padded_seq_len x num_heads x head_size
* scale: Attention scale factor
* custom_mask: custom attention mask; good place to inject a causal
attention mask
* q_seq_lens: list of unpadded query seq_lens for each batch index
* kv_seq_lens: list of unpadded key/value seq_lens for each batch index
Returns:
* Attention result, batch_size x q_padded_seq_len x num_heads x head_size
'''
assert q_seq_lens is not None
assert kv_seq_lens is not None
batch_size = query.shape[0]
assert (len(q_seq_lens) == batch_size)
assert (len(kv_seq_lens) == batch_size)
attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()
# Basic attention mask, derived from seq lens
if (q_seq_lens is not None) or (kv_seq_lens is not None):
attn_mask = torch.zeros_like(attn_weights)
if q_seq_lens is not None:
for bdx, plen in enumerate(q_seq_lens):
attn_mask[bdx, :, plen:, :] = -torch.inf
if kv_seq_lens is not None:
for bdx, plen in enumerate(kv_seq_lens):
attn_mask[bdx, :, :, plen:] = -torch.inf
attn_weights = attn_weights + attn_mask.float()
# Custom attention mask
if custom_mask is not None:
attn_weights = attn_weights + custom_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
return out
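A minimal usage sketch of the reference attention above (illustrative only; the shapes and sequence lengths are arbitrary). Key/value positions beyond each kv_seq_len are masked to -inf before the softmax; rows for padded query positions are discarded later when outputs are packed.
q = torch.rand(2, 4, 1, 8)  # batch_size x padded_seq_len x num_heads x head_size
k = torch.rand(2, 4, 1, 8)
v = torch.rand(2, 4, 1, 8)
out = ref_masked_attention(q, k, v, scale=8**-0.5,
                           q_seq_lens=[3, 4], kv_seq_lens=[3, 4])
assert out.shape == (2, 4, 1, 8)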
def make_qkv(
batch_size: int,
max_q_seq_len: int,
max_kv_seq_len: Optional[int],
num_heads: int,
head_size: int,
device: Union[torch.device, str],
force_kv_seq_lens: Optional[List[int]] = None,
attn_type: AttentionType = AttentionType.ENCODER_DECODER,
force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
'''
Construct QKV test tensors for self- and cross-attention.
Generates three query/key/value triplets:
* "Baseline" query/key/value (for input to reference attention function)
* "Prefill" query/key/value (last sequence offset zero'd out, for use as
input to prefill kernel)
* "Decode" query/key/value (only the last sequence offset from baseline,
for use as input to decode kernel)
Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
seqlens
Arguments:
* batch_size
* max_q_seq_len: max query seq len
* max_kv_seq_len: max key/value seq len
* num_heads
* head_size
* force_kv_seq_lens: if not None, overrides kv sequence lengths
* attn_type: encoder, decoder self, or enc/dec cross attention; for
    ENCODER_DECODER attention, key/value seqlens may differ from the
    corresponding query seqlens (as is often the case for cross-attention);
    o/w, query/key/value seqlens match at each batch index
    (and max_kv_seq_len is unused)
* force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
    seqlens are random in [2, max_q_seq_len]. Same for key/value seqlens
    and max_kv_seq_len, unless attn_type forces key/value seqlens to
    match the query seqlens
* device: CPU or CUDA device
Returns:
* Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
* Prefill QKVInputs structure (containing all but the last sequence offset)
* Decode QKVInputs structure (containing only the last sequence offset)
'''
if force_max_len:
q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
else:
q_seq_lens = [
random.randint(2, max_q_seq_len) for _ in range(batch_size)
]
kv_seq_lens = None
if force_kv_seq_lens is not None:
kv_seq_lens = force_kv_seq_lens
elif attn_type != AttentionType.ENCODER_DECODER:
# K,V seq lens match Q for self-attention
kv_seq_lens = q_seq_lens
else:
# K,V seq lens are distinct from Q seq lens & random
assert max_kv_seq_len is not None
if force_max_len:
kv_seq_lens = [max_kv_seq_len] * batch_size
else:
kv_seq_lens = [
random.randint(2, max_kv_seq_len) for _ in range(batch_size)
]
query = torch.rand(
(batch_size, max_q_seq_len, num_heads, head_size)).to(device)
key = torch.rand(
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
value = torch.rand(
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
prefill_query = torch.zeros(
(batch_size, max_q_seq_len, num_heads, head_size)).to(device)
prefill_key = torch.zeros(
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
prefill_value = torch.zeros(
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
decode_query = torch.zeros(
(batch_size, 1, num_heads, head_size)).to(device)
decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device)
decode_value = torch.zeros(
(batch_size, 1, num_heads, head_size)).to(device)
for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
kv_seq_lens)):
query[bdx, q_seq_len:, :, :] = 0
key[bdx, kv_seq_len:, :, :] = 0
value[bdx, kv_seq_len:, :, :] = 0
prefill_query[bdx,
0:(q_seq_len - 1), :, :] = query[bdx,
0:(q_seq_len - 1), :, :]
prefill_key[bdx,
0:(kv_seq_len - 1), :, :] = key[bdx,
0:(kv_seq_len - 1), :, :]
prefill_value[bdx, 0:(kv_seq_len -
1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :]
decode_query[bdx, :, :, :] = query[bdx,
(q_seq_len - 1):q_seq_len, :, :]
decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :]
decode_value[bdx, :, :, :] = value[bdx,
(kv_seq_len - 1):kv_seq_len, :, :]
prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]
decode_q_seq_lens = [1 for _ in q_seq_lens]
decode_kv_seq_lens = [1 for _ in kv_seq_lens]
return (
QKVInputs(
query, # Overall QKV inputs
key,
value,
q_seq_lens,
kv_seq_lens),
QKVInputs(
prefill_query, # Prefill subset of QKV sequences
prefill_key,
prefill_value,
prefill_q_seq_lens,
prefill_kv_seq_lens),
QKVInputs(
decode_query, # Decode subset of KV sequences
decode_key,
decode_value,
decode_q_seq_lens,
decode_kv_seq_lens))
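A brief usage sketch of make_qkv() (illustrative only; the concrete sizes are arbitrary): the "decode" view holds exactly one token per sequence, and the "prefill" view holds everything except that final token.
qkv, prefill_qkv, decode_qkv = make_qkv(2, 8, 8, 1, 4,
                                        attn_type=AttentionType.DECODER,
                                        device="cpu")
assert decode_qkv.query.shape[1] == 1  # one token per sequence
assert prefill_qkv.q_seq_lens == [n - 1 for n in qkv.q_seq_lens]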
def pack_tensor(
unpacked_tensor: torch.Tensor, seq_lens: List[int],
device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
'''
Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
unpadded number_of_tokens x num_heads x head_size tensor, where
number_of_tokens = sum(seq_lens)
Arguments:
* unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
* seq_lens: list of token counts for each seq
* device: CPU or CUDA device
Returns
* packed_tensor: number_of_tokens x num_heads x head_size
* start_loc_list: start idx of each batch elt in packed_tensor; [0] +
list(itertools.accumulate(seq_lens))
'''
num_tok = sum(seq_lens)
num_heads = unpacked_tensor.shape[-2]
head_size = unpacked_tensor.shape[-1]
start_loc_list = [0] + list(itertools.accumulate(seq_lens))
packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device)
for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):
packed_tensor[start_loc:(
start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]
return packed_tensor, start_loc_list
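A short sketch of pack_tensor() (illustrative only): two padded sequences of lengths 3 and 2 pack into a 5-token tensor, with start locations [0, 3, 5].
unpacked = torch.rand(2, 4, 1, 8)  # batch_size x padded_seq_len x num_heads x head_size
packed, start_locs = pack_tensor(unpacked, [3, 2], device="cpu")
assert packed.shape == (5, 1, 8)
assert start_locs == [0, 3, 5]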
def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
str]) -> PackedQKVInputs:
'''
Individually pack each of Q, K and V, each with dimensions batch_size x
padded_seq_len x num_heads x head_size, into respective number_of_tokens x
num_heads x head_size tensors.
For Q, number_of_tokens = sum(q_seq_lens).
For K and V, number_of_tokens = sum(kv_seq_lens)
Arguments:
* qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
attention inputs
* device: CPU or CUDA device
Returns
* Packed (number_of_tokens x num_heads x head_size) QKV inputs
derived from unpacked inputs
'''
if qkv.query is None:
packed_query = None
q_start_loc_list = None
else:
packed_query, q_start_loc_list = pack_tensor(qkv.query,
qkv.q_seq_lens,
device=device)
packed_key, kv_start_loc_list = pack_tensor(qkv.key,
qkv.kv_seq_lens,
device=device)
packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)
return PackedQKVInputs(
packed_query, packed_key, packed_value, q_start_loc_list,
kv_start_loc_list,
(None if q_start_loc_list is None else qkv.q_seq_lens),
qkv.kv_seq_lens)
def make_backend(backend_name: str) -> AttentionBackend:
'''
Construct the backend instance determined by the backend_name string
argument.
"XFORMERS" -> construct xformers backend
TODO: other backends
Note: at time of writing the Attention wrapper automatically selects
its own backend for Attention.forward(); so the backend instance which
you generate with this function is not meant to be used for *running*
inference, but rather for generating compatible metadata structures
using backend.make_metadata()
Returns:
* Backend instance
'''
if backend_name == STR_XFORMERS_ATTN_VAL:
return XFormersBackend()
raise AssertionError(
f"Unrecognized backend_name {backend_name} for unit test")
def _make_metadata_tensors(
seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
torch.Tensor, Optional[int]]:
'''
Build scalar & tensor values required to build attention metadata structure.
Arguments:
* seq_lens: list of token-counts for each decoder input seq
* context_lens: list of context length values for each seq
* encoder_seq_lens: list of token-counts for each encoder input seq
* device: CPU or CUDA device
Returns:
* seq_lens_tensor: decoder seq_lens list, as tensor
* context_lens_tensor: context_lens list, as tensor
* max_context_len: max(context_lens)
* max_seq_len: max(seq_lens)
* seq_start_loc: start idx of each sequence
* encoder_seq_lens_tensor: encoder seq_lens list, as tensor
* max_encoder_seq_len: max(encoder_seq_lens)
'''
seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
context_lens_tensor = maybe_make_int_tensor(context_lens, device)
max_context_len = maybe_max(context_lens)
max_seq_len = maybe_max(seq_lens)
encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
max_encoder_seq_len = (None if encoder_seq_lens is None else
max(encoder_seq_lens))
seq_start_loc = None
return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)
def make_kv_cache(num_blocks: int,
num_heads: int,
head_size: int,
block_size: int,
device: Union[torch.device, str],
default_val: float = 0.0) -> torch.Tensor:
'''
Create a fake KV cache.
Arguments:
* num_blocks: number of blocks in the KV cache
* num_heads: number of attention heads
* head_size: head dimension
* block_size: number of offsets within a block
* device: CPU or CUDA device
* default_val: initialization value for KV cache elements
Returns:
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
'''
kv_cache = torch.rand(
(2, num_blocks, block_size * num_heads * head_size)).to(device)
if default_val is not None:
kv_cache[:, :, :] = default_val
return kv_cache
def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
'''
Compute the minimum number of blocks required to hold num_tokens tokens,
given block_size
'''
return (num_tokens + block_size) // block_size
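Worked arithmetic for the helper above (illustrative only): the formula always rounds up, and it allocates one extra block when num_tokens is an exact multiple of block_size, which merely over-provisions the fake cache.
assert _num_tokens_to_min_blocks(17, 16) == 2  # (17 + 16) // 16
assert _num_tokens_to_min_blocks(16, 16) == 2  # an exact fit still yields two blocks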
def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
return maybe_make_long_tensor([], device)
def make_empty_block_tables_tensor(device: Union[torch.device, str]):
return torch.tensor([], device=device)
def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
device: Union[torch.device, str]):
'''
Split a slot mapping into valid prefill- and decode-phase slot mappings.
Context:
* Your goal is to test (1) prefill of N prompts, with prompt-lengths
{K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
for all N prompts (N tokens total); the resultant sequence lengths
after decode would be {K_i + 1 for i \\in [0,N)}
* The test you want to do requires (1) having the prefill slot mapping
for all tokens present during prefill, the number of which is
M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
decoded tokens
This function consumes a single 1D slot mapping, which is the
concatenation of N slot mappings each of length K_i + 1 (corresponding
to the sequence lengths after decode), with a total length of
P = \\sum_i{K_i + 1} = M + N
The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
from each of the N subsequences in the slot mapping (i.e. omitting the
decoded token's mapping.)
The N excised entries are appended to obtain the decode-phase slot mapping
Arguments:
* slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
post-decode sequences
* seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
description above)
* device: cuda, cpu, etc.
Returns:
* prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
reflecting all N prefill prompts
* decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
all N decoded tokens
'''
prefill_slot_mapping = []
decode_slot_mapping = []
base_idx = 0
for seq_len in seq_lens:
prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx +
seq_len - 1)])
decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1])
base_idx += seq_len
return (maybe_make_long_tensor(prefill_slot_mapping, device),
maybe_make_long_tensor(decode_slot_mapping, device))
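A small sketch of split_slot_mapping() (illustrative only; the slot values are made up): for two post-decode sequences of lengths 3 and 2, the last slot of each sequence forms the decode-phase mapping and the rest forms the prefill-phase mapping.
slots = [10, 11, 12, 40, 41]  # per-sequence slot mappings, concatenated
prefill, decode = split_slot_mapping(slots, [3, 2], device="cpu")
assert prefill.tolist() == [10, 11, 40]
assert decode.tolist() == [12, 41]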
def make_block_tables_slot_mapping(
block_size: int,
seq_lens: List[int],
device: Union[torch.device, str],
block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
'''
Construct fake block tables & slot mappings.
For a sequence with num_tokens tokens the minimum number
of required KV cache blocks is
num_blocks = (num_tokens + block_size) // block_size
Then the minimum KV cache size in blocks is
total_cache_blocks = sum(num_blocks for all seqs)
Then, the blocktable mapping counts downward from
block_base_addr + total_cache_blocks
to
block_base_addr
The constructed block-tables and slot-mapping are sized to the
lengths of the sequences in their entirety (as reflected by seq_lens),
i.e. the total of prefill prompt tokens + decoded tokens.
Arguments:
* block_size: number of offsets per block
* seq_lens: list of token-counts for each sequence
* block_base_addr: the block table base address
* device: CPU or CUDA device
Return:
* block_tables_tensor: block table for sequence
* slot_mapping_list: slot mapping for sequence
* max_block_idx: the highest block address within this block table
'''
# Provision minimum number of KV cache blocks
num_blocks_list = [
_num_tokens_to_min_blocks(num_tokens, block_size)
for num_tokens in seq_lens
]
max_block_table_len = max(num_blocks_list)
block_table_pad_tokens = 10
block_tables = []
slot_mapping_list = []
# Compute uppermost address of block table
total_cache_blocks = sum(num_blocks_list)
block_base_idx = block_base_addr + total_cache_blocks
max_block_idx = block_base_idx
for sdx, num_tokens in enumerate(seq_lens):
num_blocks = num_blocks_list[sdx]
block_table = list(
range(block_base_idx, block_base_idx - num_blocks, -1))
for idx in range(num_tokens):
mapping_value = (
idx % block_size) + block_table[idx // block_size] * block_size
slot_mapping_list.append(mapping_value)
block_base_idx -= num_blocks
block_tables.append(block_table)
block_tables_tensor = make_tensor_with_pad(
block_tables,
max_len=max_block_table_len + block_table_pad_tokens,
pad=0,
dtype=torch.int,
device=device,
)
return (block_tables_tensor, slot_mapping_list, max_block_idx)
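Worked example of the slot-mapping arithmetic above (illustrative only; the block-table values are made up): with block_size=16 and a per-sequence block table [7, 6], token index 20 falls in the second block (20 // 16 == 1) at offset 4 (20 % 16), so its slot is 6 * 16 + 4 == 100.
block_size, block_table, idx = 16, [7, 6], 20
slot = (idx % block_size) + block_table[idx // block_size] * block_size
assert slot == 100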
def make_test_metadata(
attn_backend: AttentionBackend,
is_prompt: bool,
seq_lens: Optional[List[int]],
decoder_test_params: Optional[PhaseTestParameters],
device: Union[torch.device, str],
encoder_test_params: Optional[PhaseTestParameters] = None,
cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
'''
Construct fake attention metadata for a given test phase
(prefill-phase or decode-phase).
encoder_test_params and cross_test_params arguments allow encoder
attention and enc/dec cross-attention (respectively) to use distinct
metadata values from decoder self-attention (decoder_test_params.)
If encoder_test_params and cross_test_params are None, the attention
metadata will support a decoder-only scenario.
Assumptions:
* No chunked prefill -> a batch is 100% prefill or 100% decode, never both
Arguments:
* attn_backend: Backend for sourcing attention kernels
* is_prompt: prefill if True, o/w decode
* seq_lens: list of token counts for each sequence
* decoder_test_params: decoder self-attention test params;
this function requires
kv_mmap (memory mapping) field
* device: CPU or CUDA device
* encoder_test_params: encoder attention test params;
this function requires encoder query
sequence lengths field. If None,
encoder query sequence lengths are
treated as None
* cross_test_params: enc/dec cross-attention test params;
this function requires kv_mmap field.
If None, KV cache memory map data
structures are treated as None
Return:
* AttentionMetadata structure
'''
# Decoder self-attention memory mapping
# decoder_test_params is None signals encoder-only
# scenario, so kv_mmap is None
kv_mmap = (None
if decoder_test_params is None else decoder_test_params.kv_mmap)
# This function constructs metadata assuming no chunked prefill,
# i.e. 100% prefill tokens or 100% decode tokens
#
# - If is_prompt, num_prefills_or_decodes is the number of prefills
# and num_prefill_or_decode_tokens is the number of prefill tokens
# - If not is_prompt, num_prefills_or_decodes is the number of decodes
# and num_prefill_or_decode_tokens is the number of decode tokens
#
# seq_lens is None signals encoder-only
# scenario, in which case num_prefills_or_decodes and
# num_prefill_or_decode_tokens are unused
num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens))
num_prefill_or_decode_tokens = (None if seq_lens is None else (
sum(seq_lens) if is_prompt else len(seq_lens)))
# Seems for non-prefix-caching scenarios context_lens
# is never needed
context_lens = None
if encoder_test_params is None:
encoder_seq_lens = None
num_encoder_tokens = None
else:
# Encoder/decoder or encoder-only models only:
# * Extract encoder input sequence lengths
assert encoder_test_params.packed_qkvo.packed_qkv is not None
encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
num_encoder_tokens = (None if encoder_seq_lens is None else
(sum(encoder_seq_lens)))
if cross_test_params is None:
cross_kv_mmap = None
else:
# Encoder/decoder or encoder-only models only:
# * Extract *cross-attention* slot_mapping and block table
# (kv_mmap)
cross_kv_mmap = cross_test_params.kv_mmap
if is_prompt:
# Prefill-phase scenario
num_prefills = num_prefills_or_decodes
num_prefill_tokens = num_prefill_or_decode_tokens
num_decode_tokens = 0
(
seq_lens_tensor,
context_lens_tensor,
_,
_,
_,
encoder_seq_lens_tensor,
max_encoder_seq_len,
) = _make_metadata_tensors(seq_lens,
context_lens,
encoder_seq_lens,
device=device)
return attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
max_decode_seq_len=0,
context_lens_tensor=context_lens_tensor,
block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
use_cuda_graph=False,
num_encoder_tokens=num_encoder_tokens,
encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
max_encoder_seq_len=max_encoder_seq_len,
cross_slot_mapping=(None if cross_kv_mmap is None else
cross_kv_mmap.slot_mapping),
cross_block_tables=(None if cross_kv_mmap is None else
cross_kv_mmap.block_tables))
else: # not is_prompt
# Decode-phase scenario
assert kv_mmap is not None
assert num_prefill_or_decode_tokens is not None
assert seq_lens is not None
num_prefills = 0
num_prefill_tokens = 0
num_decode_tokens = num_prefill_or_decode_tokens
(
seq_lens_tensor,
context_lens_tensor,
_,
_,
_,
encoder_seq_lens_tensor,
max_encoder_seq_len,
) = _make_metadata_tensors(seq_lens,
context_lens,
encoder_seq_lens,
device=device)
return attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=0,
max_decode_seq_len=max(seq_lens),
context_lens_tensor=context_lens_tensor,
block_tables=kv_mmap.block_tables,
use_cuda_graph=False,
num_encoder_tokens=num_encoder_tokens,
encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
max_encoder_seq_len=max_encoder_seq_len,
cross_slot_mapping=(None if cross_kv_mmap is None else
cross_kv_mmap.slot_mapping),
cross_block_tables=(None if cross_kv_mmap is None else
cross_kv_mmap.block_tables))
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
output_under_test: torch.Tensor) -> None:
'''
Assert that the observed output matches the ideal output
contained in the test-parameters data structure.
Arguments:
* test_params: Test parameters, including the packed ideal output
* output_under_test: the observed output value
'''
ideal_output = test_params.packed_qkvo.ideal_output
assert torch.allclose(ideal_output,
output_under_test.view_as(ideal_output))
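For reference, a hedged sketch of how a test might drive the two helpers above; every input name here (attn_backend, prefill_seq_lens, the *_phase_params PhaseTestParameters objects, actual_prefill_output) is a placeholder assumed to be built elsewhere in the test and is not defined in this diff:

prefill_attn_metadata = make_test_metadata(
    attn_backend,            # AttentionBackend under test
    True,                    # is_prompt -> prefill phase
    prefill_seq_lens,        # decoder-side token counts per sequence
    decoder_phase_params,    # decoder self-attention params (kv_mmap required)
    device="cuda:0",
    encoder_test_params=encoder_phase_params,
    cross_test_params=cross_phase_params)
# ...run the attention layer under test, then compare against the ideal output:
assert_actual_matches_ideal(decoder_phase_params, actual_prefill_output)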

View File

@ -1,11 +1,18 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
TypeVar) TypeVar)
import torch import torch
class AttentionType(Enum):
DECODER = auto() # Decoder attention between previous layer Q/K/V
ENCODER = auto() # Encoder attention between previous layer Q/K/V
ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V
class AttentionBackend(ABC): class AttentionBackend(ABC):
"""Abstract class for attention backends.""" """Abstract class for attention backends."""
@ -128,5 +135,6 @@ class AttentionImpl(ABC, Generic[T]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: T, attn_metadata: T,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
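Each backend modified below follows the same pattern: accept the new attn_type keyword and, for now, reject anything other than decoder self-attention. A minimal sketch of that guard, mirroring the hunks that follow (the surrounding parameters are abbreviated here):

def forward(self, query, key, value, kv_cache, attn_metadata,
            kv_scale: float = 1.0,
            attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor:
    if attn_type != AttentionType.DECODER:
        # Encoder attention and enc/dec cross-attention are xFormers-only so far
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for this backend")
    ...  # existing decoder self-attention path continues unchanged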

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
from vllm.attention.ops.blocksparse_attention.interface import ( from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step) LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
@ -328,6 +328,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata, attn_metadata: BlocksparseFlashAttentionMetadata,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
@ -340,6 +341,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)

View File

@ -7,7 +7,7 @@ from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
@ -257,6 +257,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
@ -269,6 +270,12 @@ class FlashAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
# NOTE(woosuk): FlashAttention does not support FP8 KV cache. # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."

View File

@ -14,7 +14,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
@ -224,8 +224,14 @@ class FlashInferImpl(AttentionImpl):
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
assert kv_scale == 1.0 assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)

View File

@ -7,7 +7,7 @@ import torch
from vllm._ipex_ops import ipex_ops from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
@ -157,6 +157,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: IpexAttnMetadata, # type: ignore attn_metadata: IpexAttnMetadata, # type: ignore
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention. """Forward pass with IPEX varlen_attention and PagedAttention.
@ -170,6 +171,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)

View File

@ -6,7 +6,7 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops.
import torch_xla.experimental.dynamo_set_buffer_donor import torch_xla.experimental.dynamo_set_buffer_donor
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
class PallasAttentionBackend(AttentionBackend): class PallasAttentionBackend(AttentionBackend):
@ -132,6 +132,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Pallas attention. """Forward pass with Pallas attention.
@ -146,6 +147,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
shape = [batch_size, seq_len, num_heads * head_size] shape = [batch_size, seq_len, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
batch_size, seq_len, hidden_size = query.shape batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size) query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)

View File

@ -6,7 +6,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
@ -297,6 +297,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: ROCmFlashAttentionMetadata,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
@ -309,6 +310,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)

View File

@ -7,7 +7,7 @@ import torch
from torch.nn.functional import scaled_dot_product_attention from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu from vllm.utils import is_cpu
@ -145,6 +145,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
@ -158,6 +159,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TorchSDPABackendImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)

View File

@ -0,0 +1,7 @@
"""Attention backend utils"""
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
"with encoder/decoder models.")

View File

@ -6,10 +6,11 @@ import torch
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (AttentionBias, from xformers.ops.fmha.attn_bias import (AttentionBias,
BlockDiagonalCausalMask, BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMaskWithTensorBias) LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
@ -66,11 +67,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
dynamically, it should be stored in tensor. The tensor has to be dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API. updated from `CUDAGraphRunner.forward` API.
""" """
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------| # |---------------- N iteration ---------------------|
@ -79,8 +75,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# |-------------------- seq_len ----------------------| # |-------------------- seq_len ----------------------|
# |-- query_len ---| # |-- query_len ---|
# Maximum query length in the batch. None for decoding. # seq_lens stored as a tensor.
max_query_len: Optional[int] seq_lens_tensor: Optional[torch.Tensor]
# FIXME: It is for flash attn. # FIXME: It is for flash attn.
# Maximum sequence length among prefill batch. 0 if there are decoding # Maximum sequence length among prefill batch. 0 if there are decoding
# requests only. # requests only.
@ -88,26 +85,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum sequence length among decode batch. 0 if there are prefill # Maximum sequence length among decode batch. 0 if there are prefill
# requests only. # requests only.
max_decode_seq_len: int max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled. # Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only. # Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool use_cuda_graph: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]] = None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor] = None
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata: Optional["XFormersMetadata"] = None _cached_prefill_metadata: Optional["XFormersMetadata"] = None
_cached_decode_metadata: Optional["XFormersMetadata"] = None _cached_decode_metadata: Optional["XFormersMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt # It is a list because it is needed to set per prompt
@ -115,6 +141,28 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# from xformer API. # from xformer API.
# will not appear in the __repr__ and __init__ # will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[AttentionBias]] = None self.attn_bias: Optional[List[AttentionBias]] = None
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
self.cross_attn_bias: Optional[List[AttentionBias]] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
@property @property
def prefill_metadata(self) -> Optional["XFormersMetadata"]: def prefill_metadata(self) -> Optional["XFormersMetadata"]:
@ -122,30 +170,50 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
return None return None
if self._cached_prefill_metadata is not None: if self._cached_prefill_metadata is not None:
# Recover cached prefill-phase attention
# metadata structure
return self._cached_prefill_metadata return self._cached_prefill_metadata
assert self.seq_lens is not None assert ((self.seq_lens is not None)
assert self.seq_lens_tensor is not None or (self.encoder_seq_lens is not None))
assert self.query_start_loc is not None assert ((self.seq_lens_tensor is not None)
assert self.context_lens_tensor is not None or (self.encoder_seq_lens_tensor is not None))
assert self.block_tables is not None
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = XFormersMetadata( self._cached_prefill_metadata = XFormersMetadata(
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens], slot_mapping=slot_mapping,
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=seq_lens,
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len, max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0, max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1], query_start_loc=query_start_loc,
seq_start_loc=None, context_lens_tensor=context_lens_tensor,
context_lens_tensor=self.context_lens_tensor[:self.num_prefills], block_tables=block_tables,
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False, use_cuda_graph=False,
) # Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata return self._cached_prefill_metadata
@property @property
@ -154,29 +222,146 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
return None return None
if self._cached_decode_metadata is not None: if self._cached_decode_metadata is not None:
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata return self._cached_decode_metadata
assert self.block_tables is not None assert ((self.seq_lens_tensor is not None)
assert self.seq_lens_tensor is not None or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = XFormersMetadata( self._cached_decode_metadata = XFormersMetadata(
num_prefills=0, num_prefills=0,
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:], slot_mapping=slot_mapping,
seq_lens=None, seq_lens_tensor=seq_lens_tensor,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0, max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len, max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None, block_tables=block_tables,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph, use_cuda_graph=self.use_cuda_graph,
) # Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_decode_metadata return self._cached_decode_metadata
def _get_attn_bias(
attn_metadata: XFormersMetadata,
attn_type: AttentionType,
) -> Optional[AttentionBias]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if attn_type == AttentionType.DECODER:
return attn_metadata.attn_bias
elif attn_type == AttentionType.ENCODER:
return attn_metadata.encoder_attn_bias
else:
# attn_type == AttentionType.ENCODER_DECODER
return attn_metadata.cross_attn_bias
def _set_attn_bias(
attn_metadata: XFormersMetadata,
attn_bias: List[Optional[AttentionBias]],
attn_type: AttentionType,
) -> None:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if attn_type == AttentionType.DECODER:
attn_metadata.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
attn_metadata.encoder_attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
attn_metadata.cross_attn_bias = attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_seq_len_block_table_args(
attn_metadata: XFormersMetadata,
is_prompt: bool,
attn_type: AttentionType,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len,
attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
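# Illustration (hypothetical decode_meta variable, not part of this diff):
# for decode-phase cross-attention, the selector above returns the
# encoder-side lengths and the cross-attention block tables:
#   _get_seq_len_block_table_args(decode_meta, False,
#                                 AttentionType.ENCODER_DECODER)
#   -> (decode_meta.encoder_seq_lens_tensor,
#       decode_meta.max_encoder_seq_len,
#       decode_meta.cross_block_tables)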
class XFormersImpl(AttentionImpl[XFormersMetadata]): class XFormersImpl(AttentionImpl[XFormersMetadata]):
""" """
If the input tensors contain prompt tokens, the layout is as follows: If the input tensors contain prompt tokens, the layout is as follows:
@ -238,51 +423,144 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
def forward( def forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: Optional[torch.Tensor],
value: torch.Tensor, value: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: "XFormersMetadata", attn_metadata: "XFormersMetadata",
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* XFormersImpl.forward() may be invoked for both self- and cross-
attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
Args: Args:
query: shape = [num_tokens, num_heads * head_size] query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None: # Check that appropriate attention metadata attributes are
# selected for the desired attention type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
# Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if (attn_type != AttentionType.ENCODER and kv_cache is not None):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache. if (key is not None) and (value is not None):
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype, kv_scale)
num_prefill_tokens = attn_metadata.num_prefill_tokens if attn_type == AttentionType.ENCODER_DECODER:
num_decode_tokens = attn_metadata.num_decode_tokens # Update cross-attention KV cache (prefill-only)
assert key.shape[0] == num_prefill_tokens + num_decode_tokens # During cross-attention decode, key & value will be None,
assert value.shape[0] == num_prefill_tokens + num_decode_tokens # preventing this IF-statement branch from running
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
kv_scale)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
else:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query) output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached. # Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:] decode_query = query[num_prefill_tokens:]
# QKV for prefill. # QKV for prefill.
query = query[:num_prefill_tokens] query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens] if key is not None and value is not None:
value = value[:num_prefill_tokens] key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens assert decode_query.shape[0] == num_decode_tokens
@ -294,10 +572,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# block tables are empty if the prompt does not have a cached # block tables are empty if the prompt does not have a cached
# prefix. # prefix.
out = self._run_memory_efficient_xformers_forward( out = self._run_memory_efficient_xformers_forward(
query, key, value, prefill_meta) query, key, value, prefill_meta, attn_type=attn_type)
assert out.shape == output[:num_prefill_tokens].shape assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out output[:num_prefill_tokens] = out
else: else:
assert prefill_meta.query_start_loc is not None
assert prefill_meta.max_query_len is not None
# prefix-enabled attention # prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to # TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache, # deal with different data types between KV and FP8 KV cache,
@ -320,13 +602,20 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
output[:num_prefill_tokens] = out output[:num_prefill_tokens] = out
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
(
seq_lens_arg,
max_seq_len_arg,
block_tables_arg,
) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
output[num_prefill_tokens:] = PagedAttention.forward_decode( output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query, decode_query,
key_cache, key_cache,
value_cache, value_cache,
decode_meta.block_tables, block_tables_arg,
decode_meta.seq_lens_tensor, seq_lens_arg,
decode_meta.max_decode_seq_len, max_seq_len_arg,
self.kv_cache_dtype, self.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
@ -343,6 +632,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt """Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input. tokens are flattened in to `query` input.
@ -356,8 +646,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: shape = [num_prefill_tokens, num_kv_heads, head_size] key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
""" """
assert attn_metadata.seq_lens is not None
original_query = query original_query = query
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K]. # GQA/MQA requires the shape [B, M, G, H, K].
@ -375,18 +669,39 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Set attention bias if not provided. This typically happens at # Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration. # the very attention layer of every iteration.
# FIXME(woosuk): This is a hack. # FIXME(woosuk): This is a hack.
if attn_metadata.attn_bias is None: attn_bias = _get_attn_bias(attn_metadata, attn_type)
if attn_bias is None:
if self.alibi_slopes is None: if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens( if (attn_type == AttentionType.ENCODER_DECODER):
attn_metadata.seq_lens) assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens is not None
# Default enc/dec cross-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
elif attn_type == AttentionType.ENCODER:
assert attn_metadata.encoder_seq_lens is not None
# Default encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.encoder_seq_lens)
else:
assert attn_metadata.seq_lens is not None
# Default decoder self-attention mask is causal
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.seq_lens)
if self.sliding_window is not None: if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention( attn_bias = attn_bias.make_local_attention(
self.sliding_window) self.sliding_window)
attn_metadata.attn_bias = [attn_bias] attn_bias = [attn_bias]
else: else:
attn_metadata.attn_bias = _make_alibi_bias( assert attn_metadata.seq_lens is not None
self.alibi_slopes, self.num_kv_heads, query.dtype, attn_bias = _make_alibi_bias(self.alibi_slopes,
attn_metadata.seq_lens) self.num_kv_heads, query.dtype,
attn_metadata.seq_lens)
_set_attn_bias(attn_metadata, attn_bias, attn_type)
# No alibi slopes. # No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce # TODO(woosuk): Too many view operations. Let's try to reduce
@ -400,7 +715,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query, query,
key, key,
value, value,
attn_bias=attn_metadata.attn_bias[0], attn_bias=attn_bias[0],
p=0.0, p=0.0,
scale=self.scale) scale=self.scale)
return out.view_as(original_query) return out.view_as(original_query)
@ -409,6 +724,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# FIXME(woosuk): Because xformers does not support dynamic sequence # FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by # lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts. # one. This is inefficient, especially when we have many short prompts.
assert attn_metadata.seq_lens is not None
output = torch.empty_like(original_query) output = torch.empty_like(original_query)
start = 0 start = 0
for i, seq_len in enumerate(attn_metadata.seq_lens): for i, seq_len in enumerate(attn_metadata.seq_lens):
@ -417,7 +733,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query[None, start:end], query[None, start:end],
key[None, start:end], key[None, start:end],
value[None, start:end], value[None, start:end],
attn_bias=attn_metadata.attn_bias[i], attn_bias=attn_bias[i],
p=0.0, p=0.0,
scale=self.scale) scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize. # TODO(woosuk): Unnecessary copy. Optimize.
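To make the calling convention described in XFormersImpl.forward() concrete, here is a hedged sketch of how an encoder/decoder model's cross-attention layer could drive it; the variable names (cross_attn_impl, decoder_hidden_q, encoder_k/encoder_v, cross_kv_cache, the metadata objects) are illustrative and not part of this diff:

# Prefill: encoder-derived K/V are provided and written to the cross-attention
# KV cache via cross_slot_mapping.
out = cross_attn_impl.forward(decoder_hidden_q, encoder_k, encoder_v,
                              cross_kv_cache, prefill_attn_metadata,
                              attn_type=AttentionType.ENCODER_DECODER)
# Decode: cross-attention K/V were cached during prefill and do not grow,
# so key and value may be passed as None.
out = cross_attn_impl.forward(decoder_hidden_q, None, None,
                              cross_kv_cache, decode_attn_metadata,
                              attn_type=AttentionType.ENCODER_DECODER)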

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
@ -90,9 +90,16 @@ class Attention(nn.Module):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
self._kv_scale) return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._kv_scale,
attn_type=attn_type)
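# With attn_type now threaded through, a model layer can pick the attention
# type at call time, e.g. (hypothetical encoder/decoder model code, not part
# of this diff):
#   enc_out = self.attn(q, k, v, kv_cache=None,
#                       attn_metadata=attn_metadata,
#                       attn_type=AttentionType.ENCODER)
#   x_out = self.cross_attn(q, enc_k, enc_v, cross_kv_cache,
#                           attn_metadata,
#                           attn_type=AttentionType.ENCODER_DECODER)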
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore s = f"head_size={self.impl.head_size}" # type: ignore