mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:05:02 +08:00
[Kernel] Correctly invoke prefill & decode kernels for cross-attention (towards eventual encoder/decoder model support) (#4888)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
f7a8fa39d8
commit
543aa48573
@ -47,32 +47,32 @@ def test_flash_attn(monkeypatch):
|
||||
# Unsupported CUDA arch
|
||||
with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
|
||||
assert backend.name != "FLASH_ATTN"
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported data type
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
|
||||
assert backend.name != "FLASH_ATTN"
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported kv cache data type
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
|
||||
assert backend.name != "FLASH_ATTN"
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported block size
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
|
||||
assert backend.name != "FLASH_ATTN"
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported sliding window
|
||||
backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
|
||||
assert backend.name != "FLASH_ATTN"
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# flash-attn is not installed
|
||||
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
|
||||
assert backend.name != "FLASH_ATTN"
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported head size
|
||||
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
|
||||
assert backend.name != "FLASH_ATTN"
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
|
||||
def test_invalid_env(monkeypatch):
|
||||
|
||||
953
tests/kernels/test_encoder_decoder_attn.py
Normal file
953
tests/kernels/test_encoder_decoder_attn.py
Normal file
@ -0,0 +1,953 @@
|
||||
"""
|
||||
Tests:
|
||||
|
||||
* E2E test of Encoder attention + Decoder self-attention +
|
||||
Encoder/decoder cross-attention (collectively
|
||||
"encoder/decoder attention")
|
||||
* Confirm enc/dec models will fail for chunked prefill
|
||||
* Confirm enc/dec models will fail for prefix caching
|
||||
|
||||
"""
|
||||
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import *
|
||||
from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
||||
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
||||
from vllm.utils import is_hip
|
||||
|
||||
HEAD_SIZES = [64, 256]
|
||||
|
||||
NUM_HEADS = [1, 16]
|
||||
|
||||
BATCH_SIZES = [1, 16]
|
||||
BLOCK_SIZES = [16]
|
||||
BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
|
||||
CUDA_DEVICE = "cuda:0"
|
||||
|
||||
MAX_DEC_SEQ_LENS = [128]
|
||||
MAX_ENC_SEQ_LENS = [128]
|
||||
|
||||
# Narrow teest-cases for unsupported-scenario
|
||||
# tests
|
||||
HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]]
|
||||
|
||||
|
||||
class TestPoint(NamedTuple):
|
||||
"""
|
||||
Encapsulates the attributes which define a single invocation
|
||||
of the test_e2e_enc_dec_attn() test
|
||||
|
||||
Attributes:
|
||||
num_heads: The number of heads in the model.
|
||||
head_size: Head dimension
|
||||
backend_name: Name of the backend framework used.
|
||||
batch_size: Number of samples per batch.
|
||||
block_size: Size of each block of data processed.
|
||||
max_dec_seq_len: Maximum sequence length for the decoder.
|
||||
max_enc_seq_len: Maximum sequence length for the encoder.
|
||||
num_blocks: Number of blocks in the model.
|
||||
"""
|
||||
|
||||
num_heads: int
|
||||
head_size: int
|
||||
backend_name: str
|
||||
batch_size: int
|
||||
block_size: int
|
||||
max_dec_seq_len: int
|
||||
max_enc_seq_len: int
|
||||
num_blocks: int
|
||||
|
||||
|
||||
class TestResources(NamedTuple):
|
||||
'''
|
||||
Encapsulates key components for performing an
|
||||
encoder/decoder attention test
|
||||
|
||||
Note that
|
||||
(1) attn automatically selects an attention backend
|
||||
based on platform info & a set of canned
|
||||
heuristics
|
||||
(2) attn_backend is thus *not the same backend
|
||||
instance* used by attn, but rather it is
|
||||
intended to be a
|
||||
*different instance* of the *same backend class*;
|
||||
it is assumed that the user of TestResources
|
||||
will leverage attn_backend for the purpose of
|
||||
constructing backend-compatible attention
|
||||
metadata instances
|
||||
|
||||
Attributes:
|
||||
|
||||
* scale: 1/sqrt(d) scale factor for attn
|
||||
* attn_backend: implementatino of abstraction
|
||||
attention interface using
|
||||
a particular kernel library
|
||||
i.e. XFormers
|
||||
* attn: Attention layer instance
|
||||
* kv_cache: shared key/value cache for all attention
|
||||
'''
|
||||
|
||||
scale: float
|
||||
attn_backend: AttentionBackend
|
||||
attn: Attention
|
||||
kv_cache: torch.Tensor
|
||||
|
||||
|
||||
def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
|
||||
'''
|
||||
Build key components for performing encoder/decoder attention test.
|
||||
|
||||
Note that
|
||||
(1) The Attention instance constructed here, automatically selects
|
||||
an attention backend class based on platform info & a set of canned
|
||||
heuristics, so
|
||||
(2) The attention backend instance constructed here is thus *not
|
||||
the same backend instance* used by attn, but rather it is
|
||||
intended to be a *different instance* of the *same backend class*;
|
||||
therefore,
|
||||
(3) This function requires that test_pt.backend_name matches the backend
|
||||
class that Attention will automatically select when it is constructed.
|
||||
|
||||
|
||||
Arguments:
|
||||
|
||||
* test_pt: TestPoint data structure; this function relies on the
|
||||
following fields: num_heads, head_size, num_blocks,
|
||||
block_size, backend_name
|
||||
|
||||
Returns:
|
||||
|
||||
* TestResources data structure.
|
||||
'''
|
||||
|
||||
scale = float(1.0 / (test_pt.head_size**0.5))
|
||||
attn_backend = make_backend(test_pt.backend_name)
|
||||
attn = Attention(
|
||||
test_pt.num_heads,
|
||||
test_pt.head_size,
|
||||
scale=scale,
|
||||
)
|
||||
if test_pt.num_blocks is None or test_pt.num_heads is None:
|
||||
# Caller does not require a KV cache
|
||||
return TestResources(scale, attn_backend, attn, None)
|
||||
|
||||
# Construct KV cache
|
||||
kv_cache = make_kv_cache(test_pt.num_blocks,
|
||||
test_pt.num_heads,
|
||||
test_pt.head_size,
|
||||
test_pt.block_size,
|
||||
device=CUDA_DEVICE)
|
||||
return TestResources(scale, attn_backend, attn, kv_cache)
|
||||
|
||||
|
||||
def _encoder_attn_setup(
|
||||
test_pt: TestPoint,
|
||||
test_rsrcs: TestResources,
|
||||
) -> PhaseTestParameters:
|
||||
'''
|
||||
Set up test vectors & data structures for encoder attention test.
|
||||
|
||||
A triplet of synthetic query/key/value tensors are constructed.
|
||||
Given this is an encoder attention test, the key & value
|
||||
sequences will have the same length as the corresponding queries.
|
||||
|
||||
The query/key/value tensors are passed to an ideal reference
|
||||
self-attention implementation to generate an ideal output tensor.
|
||||
|
||||
Encoder inference does not populate the KV cache, therefore
|
||||
no KV cache memory mapping is constructed
|
||||
|
||||
Arguments:
|
||||
|
||||
* test_pt: TestPoint data structure; this function relies on the
|
||||
following fields: batch_size, num_heads, head_size,
|
||||
block_size, max_q_seq_len
|
||||
* test_rsrcs: TestResources data structure; this function relies on the
|
||||
scale field
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
* PhaseTestParameters data structure comprising (1) packed query/key/value
|
||||
tensors, (2) the ideal output of attention computed using a naive
|
||||
implementation, and (3) KVCache field set to None
|
||||
'''
|
||||
|
||||
(
|
||||
num_heads,
|
||||
head_size,
|
||||
_,
|
||||
batch_size,
|
||||
_,
|
||||
_,
|
||||
max_q_seq_len,
|
||||
_,
|
||||
) = test_pt
|
||||
|
||||
scale = test_rsrcs.scale
|
||||
|
||||
max_kv_seq_len = max_q_seq_len
|
||||
|
||||
# Make test tensors
|
||||
|
||||
qkv_in, _, _ = make_qkv(batch_size,
|
||||
max_q_seq_len,
|
||||
max_kv_seq_len,
|
||||
num_heads,
|
||||
head_size,
|
||||
attn_type=AttentionType.ENCODER,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# Compute correct answer using naive non-causal attention
|
||||
# implementation
|
||||
|
||||
ideal_output = ref_masked_attention(qkv_in.query,
|
||||
qkv_in.key,
|
||||
qkv_in.value,
|
||||
scale=scale,
|
||||
q_seq_lens=qkv_in.q_seq_lens,
|
||||
kv_seq_lens=qkv_in.kv_seq_lens)
|
||||
|
||||
packed_ideal_output, _ = pack_tensor(ideal_output,
|
||||
qkv_in.q_seq_lens,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE)
|
||||
|
||||
return PhaseTestParameters(
|
||||
PackedQKVO(packed_qkv, packed_ideal_output),
|
||||
None # No KV cache
|
||||
)
|
||||
|
||||
|
||||
def _decoder_attn_setup(
|
||||
test_pt: TestPoint,
|
||||
test_rsrcs: TestResources,
|
||||
block_base_addr: int = 0,
|
||||
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
|
||||
'''
|
||||
Set up test vectors & data structures for self-attention test.
|
||||
|
||||
A triplet of synthetic query/key/value tensors are constructed ("baseline"
|
||||
query/key/value). Given this is a self-attention test, the key & value
|
||||
sequences will have the same length as the corresponding queries.
|
||||
|
||||
"Prefill" query/key/value tensors are derived by masking out the last value
|
||||
in each baseline query/key/value. These tensors are used to test prefill &
|
||||
populate KV cache for a subsequent decode test.
|
||||
|
||||
"Decode" query/key/value tensors are derived by extracting *only* the last
|
||||
value from each baseline query/key/value (i.e. complement of the prefill
|
||||
tensors.) These tensors are used to test decode, conditional on the kv cache
|
||||
being populated during the prefill test.
|
||||
|
||||
The baseline query/key/value tensors are passed to an ideal reference
|
||||
self-attention implementation to generate a "Baseline" ideal output tensor.
|
||||
This tensor is split into the "Prefill" ideal output tensor (all but the
|
||||
last element of each output sequence) and the "Decode" ideal output tensor
|
||||
(*only* the last element of each output sequence); the "Prefill" and
|
||||
"Decode" ideal output tensors can be used to validate the prefill and decode
|
||||
test results, respectively.
|
||||
|
||||
This function also constructs the self-attention KV cache memory mapping
|
||||
(slot mapping and block table), ensuring that the block table starts at
|
||||
block_base_addr
|
||||
|
||||
Arguments:
|
||||
|
||||
* test_pt: TestPoint data structure; this function relies on the
|
||||
following fields: batch_size, num_heads, head_size,
|
||||
block_size, max_q_seq_len
|
||||
* test_rsrcs: TestResources data structure; this function relies on the
|
||||
scale field
|
||||
* block_base_addr: decoder self-attention block-table base address
|
||||
|
||||
Returns:
|
||||
* qkv: Unpacked (batch_size x padded_seq_len x num_heads x
|
||||
head_size) query/key/value tensors
|
||||
* Prefill-phase decoder self-attention PhaseTestParameters data structure,
|
||||
including (1) packed (number_of_tokens x num_heads x head_size)
|
||||
query/key/value tensors along with (2) ideal attention output
|
||||
computed using a naive implementation, and (3) memory-mapping data
|
||||
structures appropriate for prefill phase.
|
||||
* Decode-phase decoder self-attention PhaseTestParameters data structure,
|
||||
including (1) packed (number_of_tokens x num_heads x head_size)
|
||||
query/key/value tensors along with (2) ideal attention output
|
||||
computed using a naive implementation, and (3) memory-mapping data
|
||||
structures appropriate for decode phase.
|
||||
* max_block_idx: max physical address in decoder self-attention block-table
|
||||
(intended to be used as the base address for the encoder/
|
||||
decoder cross-attention block-table, which is not
|
||||
constructed in this function)
|
||||
'''
|
||||
|
||||
(
|
||||
num_heads,
|
||||
head_size,
|
||||
_,
|
||||
batch_size,
|
||||
block_size,
|
||||
max_q_seq_len,
|
||||
_,
|
||||
_,
|
||||
) = test_pt
|
||||
|
||||
scale = test_rsrcs.scale
|
||||
|
||||
max_kv_seq_len = max_q_seq_len
|
||||
|
||||
# Build test tensors
|
||||
|
||||
(
|
||||
qkv,
|
||||
prefill_qkv,
|
||||
decode_qkv,
|
||||
) = make_qkv(batch_size,
|
||||
max_q_seq_len,
|
||||
max_kv_seq_len,
|
||||
num_heads,
|
||||
head_size,
|
||||
attn_type=AttentionType.DECODER,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# Compute correct answer using naive attention implementation
|
||||
# with causal attention mask
|
||||
|
||||
causal_mask = make_causal_mask(max_q_seq_len,
|
||||
max_kv_seq_len).to(CUDA_DEVICE)
|
||||
|
||||
ideal_output = ref_masked_attention(qkv.query,
|
||||
qkv.key,
|
||||
qkv.value,
|
||||
scale=scale,
|
||||
custom_mask=causal_mask,
|
||||
q_seq_lens=qkv.q_seq_lens,
|
||||
kv_seq_lens=qkv.kv_seq_lens)
|
||||
|
||||
# Split out the prefill- & decode-phase ideal answers & pack them
|
||||
|
||||
prefill_ideal_output = torch.zeros_like(ideal_output)
|
||||
decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
|
||||
for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens):
|
||||
prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
|
||||
bdx, :prefill_q_seq_len]
|
||||
decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
|
||||
prefill_q_seq_len + 1)]
|
||||
|
||||
prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
|
||||
prefill_qkv.q_seq_lens,
|
||||
device=CUDA_DEVICE)
|
||||
decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
|
||||
[1 for _ in range(batch_size)],
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# Build prefill- & decode-phase data structures
|
||||
# for decoder self-attention. Block tables and
|
||||
# slot mapping must be in a format compatible
|
||||
# with KV caching & attention kernels
|
||||
#
|
||||
# Prefill-phase:
|
||||
#
|
||||
# * Empty block-tables tensor
|
||||
# * Slot-mapping with entries for prompt tokens
|
||||
#
|
||||
# Decode-phase:
|
||||
# * Block-tables tensor with minimum number of blocks
|
||||
# required by total num. tokens in the entirety of all sequences
|
||||
# (including both prefill & decode)
|
||||
# * Slot-mapping with entries for tokens that will be decoded in the
|
||||
# current decode iteration
|
||||
#
|
||||
# Note: the format described above is simply mirroring what ModelRunner
|
||||
# produces
|
||||
|
||||
prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
|
||||
|
||||
(
|
||||
decode_block_tables,
|
||||
slot_mapping_list,
|
||||
max_block_idx,
|
||||
) = make_block_tables_slot_mapping(block_size,
|
||||
qkv.q_seq_lens,
|
||||
device=CUDA_DEVICE,
|
||||
block_base_addr=block_base_addr)
|
||||
|
||||
(
|
||||
prefill_slot_mapping,
|
||||
decode_slot_mapping,
|
||||
) = split_slot_mapping(slot_mapping_list,
|
||||
qkv.q_seq_lens,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE)
|
||||
|
||||
decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE)
|
||||
|
||||
return (
|
||||
qkv,
|
||||
PhaseTestParameters( # Prefill test params
|
||||
PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output),
|
||||
KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
|
||||
PhaseTestParameters( # Decode test params
|
||||
PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output),
|
||||
KVMemoryMap(decode_block_tables, decode_slot_mapping)),
|
||||
max_block_idx)
|
||||
|
||||
|
||||
def _enc_dec_cross_attn_setup_reuses_query(
|
||||
decoder_qkv: QKVInputs,
|
||||
encoder_test_params: PhaseTestParameters,
|
||||
prefill_decoder_phase_test_params: PhaseTestParameters,
|
||||
test_pt: TestPoint,
|
||||
test_rsrcs: TestResources,
|
||||
block_base_addr: int = 0,
|
||||
) -> Tuple[PhaseTestParameters, PhaseTestParameters]:
|
||||
'''
|
||||
Set up test vectors & data structures for cross-attention test.
|
||||
|
||||
A triplet of synthetic cross-attention key/value tensors are constructed
|
||||
("baseline" key/value). Given this is a cross-attention test, we assume
|
||||
query tensors were already synthesized for a prior self-attention test and
|
||||
will be reused for cross-attention. The key & value sequences generated here
|
||||
may have a different length than the corresponding queries (as is often
|
||||
the case for cross-attention between decoder and encoder sequences.)
|
||||
|
||||
Cross attention key & value tensors do not grow during autoregressive
|
||||
inference; thus this function obtains a single key/value pair suitable for
|
||||
both prefill and decode.
|
||||
|
||||
The "baseline" query tensor is received as an argument. The "baseline"
|
||||
query/key/value tensors are passed to an ideal reference cross-attention
|
||||
implementation to generate a "baseline" ideal output tensor. This tensor is
|
||||
split into the "Prefill" ideal output tensor (all but the last element of
|
||||
each output sequence) and the "Decode" ideal output tensor (*only* the last
|
||||
element of each output sequence); the "Prefill" and "Decode" ideal output
|
||||
tensors can be used to validate the prefill and decode test results,
|
||||
respectively.
|
||||
|
||||
This function also constructs the cross-attention KV cache memory mapping
|
||||
(slot mapping and block table), ensuring that the block table starts at
|
||||
block_base_addr.
|
||||
|
||||
Arguments:
|
||||
|
||||
* decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
|
||||
num_heads x head_size) decoder self-attention inputs;
|
||||
this function relies on the query and q_seq_lens
|
||||
fields
|
||||
* encoder_test_params: PhaseTestParameters data structure which was
|
||||
used for encoder inference; KV cache field
|
||||
is not used by this function
|
||||
* prefill_decoder_phase_test_params: PhaseTestParameters data structure
|
||||
used for prefill-phase decoder
|
||||
self-attention; all fields
|
||||
including KV cache required
|
||||
* test_pt: TestPoint data structure; this function relies on the
|
||||
following fields: batch_size, num_heads, head_size,
|
||||
block_size, max_q_seq_len
|
||||
* test_rsrcs: TestResources data structure; this function relies on the
|
||||
scale field
|
||||
* block_base_addr: decoder self-attention block-table base address
|
||||
|
||||
Returns:
|
||||
|
||||
* Prefill-phase encoder/decoder cross-attention PhaseTestParameters data
|
||||
structure, including (1) packed
|
||||
(number_of_tokens x num_heads x head_size) query/key/value tensors
|
||||
along with (2) ideal attention output computed using a
|
||||
naive implementation, and (3) memory-mapping data structures appropriate
|
||||
for prefill phase.
|
||||
* Decode-phase encoder/decoder cross-attention PhaseTestParameters data
|
||||
structure, including (1) packed
|
||||
(number_of_tokens x num_heads x head_size) query/key/value tensors
|
||||
along with (2) ideal attention output computed using a
|
||||
naive implementation, and (3) memory-mapping data structures appropriate
|
||||
for decode phase.
|
||||
'''
|
||||
|
||||
assert encoder_test_params.packed_qkvo.packed_qkv is not None
|
||||
assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None
|
||||
|
||||
(
|
||||
num_heads,
|
||||
head_size,
|
||||
_,
|
||||
batch_size,
|
||||
block_size,
|
||||
max_decoder_seq_len,
|
||||
max_encoder_seq_len,
|
||||
_,
|
||||
) = test_pt
|
||||
|
||||
scale = test_rsrcs.scale
|
||||
|
||||
decoder_query = decoder_qkv.query
|
||||
decoder_seq_lens = decoder_qkv.q_seq_lens
|
||||
encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
|
||||
prefill_q_seq_lens = (
|
||||
prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens)
|
||||
|
||||
assert prefill_q_seq_lens is not None
|
||||
|
||||
(
|
||||
cross_kv,
|
||||
_,
|
||||
_,
|
||||
) = make_qkv(batch_size,
|
||||
max_decoder_seq_len,
|
||||
max_encoder_seq_len,
|
||||
num_heads,
|
||||
head_size,
|
||||
force_kv_seq_lens=encoder_seq_lens,
|
||||
attn_type=AttentionType.ENCODER_DECODER,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
ideal_output = ref_masked_attention(decoder_query,
|
||||
cross_kv.key,
|
||||
cross_kv.value,
|
||||
scale=scale,
|
||||
q_seq_lens=decoder_seq_lens,
|
||||
kv_seq_lens=cross_kv.kv_seq_lens)
|
||||
|
||||
prefill_ideal_output = torch.zeros_like(ideal_output)
|
||||
decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
|
||||
for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens):
|
||||
prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
|
||||
bdx, :prefill_q_seq_len]
|
||||
decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
|
||||
prefill_q_seq_len + 1)]
|
||||
|
||||
prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
|
||||
prefill_q_seq_lens,
|
||||
device=CUDA_DEVICE)
|
||||
decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
|
||||
[1 for _ in range(batch_size)],
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# Build prefill- & decode-phase data structures
|
||||
# for encoder/decoder cross-attention. Block tables and
|
||||
# slot mapping must be in a format compatible
|
||||
# with KV caching & attention kernels
|
||||
#
|
||||
# Whereas decoder self-attention extracts relationships between
|
||||
# equal-length Q/K/V sequences, which mutually grow in length
|
||||
# with each decoded token, cross-attention relates the Q sequence
|
||||
# - which grows with each new decoded token - to fixed-length
|
||||
# K and V sequences derived from the encoder hidden states.
|
||||
#
|
||||
# Prefill-phase:
|
||||
#
|
||||
# * Empty block-tables tensor
|
||||
# * Slot-mapping with as many entries as there are tokens in the encoder
|
||||
# prompt.
|
||||
#
|
||||
# Decode-phase:
|
||||
# * Block-tables tensor with minimum number of blocks to
|
||||
# accommodate K & V tensors which are equal in lnegth
|
||||
# to the encoder prompt length
|
||||
# * Empty slot-mapping tensor (since K & V are fixed in size,
|
||||
# new decoded tokens are not KV-cached and require no slot-
|
||||
# mapping)
|
||||
#
|
||||
# Note: the format above is simply an extension of what ModelRunner
|
||||
# produces for decoder-only models
|
||||
|
||||
prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
|
||||
decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE)
|
||||
|
||||
(
|
||||
decode_block_tables,
|
||||
prefill_slot_mapping_list,
|
||||
_,
|
||||
) = make_block_tables_slot_mapping(block_size,
|
||||
cross_kv.kv_seq_lens,
|
||||
block_base_addr=block_base_addr,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# Packed key/value (query is already provided)
|
||||
packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE)
|
||||
|
||||
return (
|
||||
PhaseTestParameters( # Prefill-phase test params
|
||||
PackedQKVO(packed_cross_kv, prefill_packed_ideal_output),
|
||||
KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
|
||||
PhaseTestParameters( # Decode-phase test params
|
||||
PackedQKVO(None, decode_packed_ideal_output),
|
||||
KVMemoryMap(decode_block_tables, decode_slot_mapping)))
|
||||
|
||||
|
||||
def _run_encoder_attention_test(
|
||||
attn: Attention,
|
||||
encoder_test_params: PhaseTestParameters,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run encoder attention.
|
||||
|
||||
attn.forward() is passed attn_type=AttentionType.ENCODER in order
|
||||
to configure the kernel invocation for encoder attention
|
||||
|
||||
Requires attn_metadata.num_decode_tokens == 0
|
||||
(There is no encoder execution in the decode-phase)
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn: Attention wrapper instance
|
||||
* encoder_test_params: encoder PhaseTestParameters data structure;
|
||||
this function relies on the packed
|
||||
(number_of_tokens x num_heads x head_size)
|
||||
query/key/value fields
|
||||
* attn_metadata: attention metadata for encoder/decoder-self attention
|
||||
|
||||
Returns:
|
||||
* Attention.forward() applied to packed {query,key,value} and
|
||||
& attn_metadata
|
||||
'''
|
||||
assert attn_metadata.num_decode_tokens == 0
|
||||
attn_type = AttentionType.ENCODER
|
||||
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
|
||||
assert packed_qkv is not None
|
||||
return attn.forward(packed_qkv.query,
|
||||
packed_qkv.key,
|
||||
packed_qkv.value,
|
||||
None,
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
|
||||
|
||||
def _run_decoder_self_attention_test(
|
||||
test_rsrcs: TestResources,
|
||||
decoder_test_params: PhaseTestParameters,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run decoder self-attention test.
|
||||
|
||||
attn.forward() is passed attn_type=AttentionType.DECODER
|
||||
in order to configure the kernel invocation for decoder self-attention.
|
||||
|
||||
Arguments:
|
||||
|
||||
* test_rsrcs: TestResources instance; this function relies on the kv_cache
|
||||
and attn (Attention wrapper instance) fields
|
||||
* decoder_test_params: decoder PhaseTestParameters data structure;
|
||||
this function relies on the packed
|
||||
(number_of_tokens x num_heads x head_size)
|
||||
query/key/value fields
|
||||
* attn_metadata: attention metadata for decoder-self attention
|
||||
(contains KV cache memory-mapping)
|
||||
|
||||
Returns:
|
||||
* Attention.forward() applied to packed_{query,key,value}, kv_cache
|
||||
& attn_metadata
|
||||
'''
|
||||
attn_type = AttentionType.DECODER
|
||||
attn = test_rsrcs.attn
|
||||
kv_cache = test_rsrcs.kv_cache
|
||||
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
|
||||
assert packed_qkv is not None
|
||||
return attn.forward(packed_qkv.query,
|
||||
packed_qkv.key,
|
||||
packed_qkv.value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
|
||||
|
||||
def _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs: TestResources,
|
||||
decoder_test_params: PhaseTestParameters,
|
||||
cross_test_params: Optional[PhaseTestParameters],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run encoder/decoder cross-attention test.
|
||||
|
||||
Via PhaseTestParameters data structures, consumes the same query utilized
|
||||
for decoder self-attention, plus a key/value specific to cross-attention.
|
||||
|
||||
if cross_test_params is None or cross_test_params.packed_qkvo.packed_qkv
|
||||
is None, this reflects that in decode-phase cross attention there
|
||||
is no growth in the key and value tensors.
|
||||
|
||||
attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER
|
||||
in order to configure the kernel invocation for encoder/decoder cross-
|
||||
attention.
|
||||
|
||||
Arguments:
|
||||
|
||||
* test_rsrcs: TestResources instance; this function relies on the kv_cache
|
||||
and attn (Attention wrapper instance) fields
|
||||
* decoder_test_params: decoder PhaseTestParameters data structure;
|
||||
this function relies on the packed
|
||||
(number_of_tokens x num_heads x head_size)
|
||||
query field
|
||||
* cross_test_params: encoder/decoder PhaseTestParameters data structure;
|
||||
this function relies on the packed
|
||||
(number_of_tokens x num_heads x head_size)
|
||||
key/value fields
|
||||
* attn_metadata: attention metadata for encoder/decoder-self attention
|
||||
|
||||
Returns:
|
||||
* Attention.forward() applied to packed_{query,key,value}, kv_cache
|
||||
& attn_metadata
|
||||
'''
|
||||
assert decoder_test_params.packed_qkvo.packed_qkv is not None
|
||||
|
||||
attn_type = AttentionType.ENCODER_DECODER
|
||||
attn = test_rsrcs.attn
|
||||
kv_cache = test_rsrcs.kv_cache
|
||||
if cross_test_params is None:
|
||||
key = None
|
||||
value = None
|
||||
else:
|
||||
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
|
||||
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
|
||||
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
|
||||
return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
|
||||
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
|
||||
def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
|
||||
batch_size: int, block_size: int, max_dec_seq_len: int,
|
||||
max_enc_seq_len: int, monkeypatch):
|
||||
|
||||
# Force Attention wrapper backend
|
||||
override_backend_env_variable(monkeypatch, backend_name)
|
||||
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
|
||||
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
|
||||
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||
|
||||
# Shared prefill metadata structure
|
||||
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
True,
|
||||
None,
|
||||
decoder_test_params=None,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=None,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# PREFILL: encoder attention
|
||||
|
||||
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
|
||||
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
|
||||
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
|
||||
def test_e2e_enc_dec_attn(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
backend_name: str,
|
||||
batch_size: int,
|
||||
block_size: int,
|
||||
max_dec_seq_len: int,
|
||||
max_enc_seq_len: int,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
'''
|
||||
End-to-end encoder/decoder test:
|
||||
|
||||
* Construct fake test vectors for (1) encoder attention,
|
||||
(2) decoder self-attention, and (3) encoder/decoder cross-attention
|
||||
* Construct (1) attention metadata structure with self- and cross-attention
|
||||
attributes for prefill-phase, and (2) an analogous attention metadata
|
||||
structure but for decode-phase
|
||||
* Test attention steps in the following order
|
||||
|
||||
* Encoder attention
|
||||
* Prefill self-attention
|
||||
* Prefill cross-attention
|
||||
* Decode self-attention
|
||||
* Decode cross-attention
|
||||
* Besides being reflective of realistic use-cases, this order would
|
||||
exacerbate any accidental overlap in the self-/cross-attention
|
||||
block tables, which one hopes to avoid
|
||||
|
||||
|
||||
* Validate output correctness against ideal reference attention
|
||||
implementation
|
||||
|
||||
Block tables are constructed such that cross-attention KV cache is in a
|
||||
higher, non-intersecting address-space than self-attention KV cache.
|
||||
|
||||
Self- and cross-attention share the same query tensor but not the K/V
|
||||
tensors. Self-attention K/Vs must have the same seq len as Q while
|
||||
cross-attention K/Vs are allowed to differ in seq len, as is often the case
|
||||
for cross-attention.
|
||||
|
||||
This test utilizes PyTest monkey patching to force the attention backend
|
||||
via an environment variable.
|
||||
|
||||
Note on ROCm/HIP: currently encoder/decoder models are not supported on
|
||||
AMD GPUs, therefore this test simply is skipped if is_hip().
|
||||
|
||||
Note on metadata: there is a single attention metadata structure shared by
|
||||
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
|
||||
and a single one shared by all decode-phase attention operations
|
||||
(decoder & enc/dec cross.) This is intended to reflect the behavior
|
||||
of ModelRunner, which constructs a single attention metadata structure for
|
||||
each prefill or decode run. A realistic scenario would rely on the
|
||||
attention backend to utilize the appropriate attention metadata fields
|
||||
according to the value of attn_metadata.attention_type. Thus, this test is
|
||||
organized so as to confirm that the backend-under-test can handle a
|
||||
shared prefill attention metadata structure & a shared decode attention
|
||||
metadata structure.
|
||||
'''
|
||||
|
||||
# Force Attention wrapper backend
|
||||
override_backend_env_variable(monkeypatch, backend_name)
|
||||
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
|
||||
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
|
||||
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||
|
||||
# Construct Decoder self-attention prefill-phase & decode-phase
|
||||
# test params, including query/key/value tensors, decoder self-attention
|
||||
# memory-mapping. cross_block_base_addr is the uppermost address in the
|
||||
# decoder self-attention block-table, i.e. a base address which the
|
||||
# encoder/decoder cross-attention block-table may build downward toward.
|
||||
|
||||
(
|
||||
dec_qkv,
|
||||
prephase_dec_test_params,
|
||||
decphase_dec_test_params,
|
||||
cross_block_base_addr,
|
||||
) = _decoder_attn_setup(test_pt, test_rsrcs)
|
||||
|
||||
# Construct encoder/decoder cross-attention prefill-phase & decode-phase
|
||||
# test params, including key/value tensors, cross-attention memory-mapping
|
||||
|
||||
(
|
||||
prephase_cross_test_params,
|
||||
decphase_cross_test_params,
|
||||
) = _enc_dec_cross_attn_setup_reuses_query(
|
||||
dec_qkv,
|
||||
enc_test_params,
|
||||
prephase_dec_test_params,
|
||||
test_pt,
|
||||
test_rsrcs,
|
||||
block_base_addr=cross_block_base_addr)
|
||||
|
||||
# Shared prefill metadata structure
|
||||
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
True,
|
||||
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
|
||||
decoder_test_params=prephase_dec_test_params,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=prephase_cross_test_params,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# PREFILL: encoder attention
|
||||
|
||||
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata)
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
|
||||
# PREFILL: decoder self-attention test
|
||||
|
||||
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
|
||||
|
||||
# - Is prefill decoder self-attention correct?
|
||||
assert_actual_matches_ideal(prephase_dec_test_params,
|
||||
prephase_dec_pckd_act_out)
|
||||
|
||||
# PREFILL: encoder/decoder cross-attention test
|
||||
|
||||
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
|
||||
prephase_attn_metadata)
|
||||
|
||||
# - Is prefill encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(prephase_cross_test_params,
|
||||
prephase_cross_pckd_act_out)
|
||||
|
||||
# DECODE: build decode-phase attention metadata
|
||||
|
||||
decphase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
False,
|
||||
dec_qkv.q_seq_lens,
|
||||
decoder_test_params=decphase_dec_test_params,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=decphase_cross_test_params,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# DECODE: decoder self-attention test
|
||||
|
||||
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
|
||||
|
||||
# - Is decode-phase decoder self-attention correct?
|
||||
assert_actual_matches_ideal(decphase_dec_test_params,
|
||||
decphase_dec_pckd_act_out)
|
||||
|
||||
# DECODE: encoder/decoder cross-attention test
|
||||
|
||||
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
|
||||
|
||||
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(decphase_cross_test_params,
|
||||
decphase_cross_pckd_act_out)
|
||||
@ -1,12 +1,211 @@
|
||||
"""Kernel test utils"""
|
||||
|
||||
import pytest
|
||||
import itertools
|
||||
import random
|
||||
from numbers import Number
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.xformers import XFormersBackend
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
|
||||
# String name of register which may be set in order to
|
||||
# force auto-selection of attention backend by Attention
|
||||
# wrapper
|
||||
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
|
||||
|
||||
# Possible string values of STR_BACKEND_ENV_VAR
|
||||
# register, corresponding to possible backends
|
||||
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
|
||||
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
|
||||
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
|
||||
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
|
||||
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
|
||||
STR_INVALID_VAL: str = "INVALID"
|
||||
|
||||
|
||||
class QKVInputs(NamedTuple):
|
||||
'''
|
||||
Data structure for representing unpacked attention inputs,
|
||||
query/key/values and their sequence lengths.
|
||||
|
||||
Attributes:
|
||||
|
||||
* {query,key,value}: unpacked (batch_size x padded_seq_len x
|
||||
num_heads x head_size) attention inputs
|
||||
* q_seq_lens: query sequence lengths list
|
||||
* kv_seq_lens: shared key/value sequence lengths list
|
||||
'''
|
||||
|
||||
query: torch.Tensor
|
||||
key: torch.Tensor
|
||||
value: torch.Tensor
|
||||
q_seq_lens: List[int]
|
||||
kv_seq_lens: List[int]
|
||||
|
||||
|
||||
class QKVO(NamedTuple):
|
||||
'''
|
||||
Data structure for representing unpacked attention inputs,
|
||||
alongside unpacked known-correct attention output
|
||||
|
||||
Attributes:
|
||||
|
||||
* qkv: unpacked (batch_size x padded_seq_len x
|
||||
num_heads x head_size) attention inputs
|
||||
* ideal_output: unpacked (batch_size x padded_seq_len x
|
||||
num_heads x head_size) known-correct attention output
|
||||
'''
|
||||
|
||||
qkv: QKVInputs
|
||||
ideal_output: torch.Tensor
|
||||
|
||||
|
||||
class PackedQKVInputs(NamedTuple):
|
||||
'''
|
||||
Data structure for representing packed attention inputs
|
||||
|
||||
Attributes:
|
||||
|
||||
* {query,key,value}: packed (number_of_tokens x num_heads
|
||||
x head_size) attention inputs
|
||||
* q_start_loc_list: list of query start locations within packed tensor
|
||||
* kv_start_loc_list: shared list of key/value start locations within
|
||||
packed tensor
|
||||
* q_seq_lens: query sequence lengths list
|
||||
* kv_seq_lens: shared key/value sequence lengths list
|
||||
'''
|
||||
|
||||
query: torch.Tensor
|
||||
key: torch.Tensor
|
||||
value: torch.Tensor
|
||||
q_start_loc_list: Optional[List[int]]
|
||||
kv_start_loc_list: Optional[List[int]]
|
||||
q_seq_lens: Optional[List[int]]
|
||||
kv_seq_lens: Optional[List[int]]
|
||||
|
||||
|
||||
class PackedQKVO(NamedTuple):
|
||||
'''
|
||||
Data structure for representing packed attention inputs,
|
||||
alongside packed known-correct attention output
|
||||
|
||||
Attributes:
|
||||
|
||||
* packed_qkv: packed (number_of_tokens x num_heads
|
||||
x head_size) attention inputs
|
||||
* ideal_output: packed (number_of_tokens x num_heads
|
||||
x head_size) known-correct attention output
|
||||
'''
|
||||
|
||||
packed_qkv: Optional[PackedQKVInputs]
|
||||
ideal_output: torch.Tensor
|
||||
|
||||
|
||||
class KVMemoryMap(NamedTuple):
|
||||
'''
|
||||
Data structure for encapsulating KV cache memory mapping.
|
||||
|
||||
Attributes:
|
||||
|
||||
* block_tables: KV cache block tables
|
||||
* slot_mapping: mapping of sequence offset to physical address
|
||||
'''
|
||||
|
||||
block_tables: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
|
||||
class PhaseTestParameters(NamedTuple):
|
||||
'''
|
||||
Data structure for encapsulating the test parameters
|
||||
for a given test "phase" (prefill or decode phase) and attention
|
||||
scenario (encoder, decoder-self, encoder/decoder-cross)
|
||||
|
||||
Attributes:
|
||||
|
||||
* packed_qkvo: packed (number_of_tokens x num_heads
|
||||
x head_size) attention inputs & known-correct
|
||||
output
|
||||
* kv_mmap: KV cache memory mapping, specific to this test phase &
|
||||
attention scenario
|
||||
'''
|
||||
|
||||
packed_qkvo: PackedQKVO
|
||||
kv_mmap: Optional[KVMemoryMap]
|
||||
|
||||
|
||||
def maybe_make_int_tensor(
|
||||
_list: Optional[List[int]],
|
||||
device: Union[torch.device, str],
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Convert Python int list to a 1D int torch.Tensor on `device`
|
||||
|
||||
Returns:
|
||||
|
||||
* If _list is not None: 1D int torch.Tensor on `device`
|
||||
* None otherwise
|
||||
'''
|
||||
return None if _list is None else torch.tensor(
|
||||
_list, dtype=torch.int, device=device)
|
||||
|
||||
|
||||
def maybe_make_long_tensor(
|
||||
_list: Optional[List[int]],
|
||||
device: Union[torch.device, str],
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Convert Python int list to a 1D long torch.Tensor on `device`
|
||||
|
||||
Returns:
|
||||
|
||||
* If _list is not None: 1D long torch.Tensor on `device`
|
||||
* None otherwise
|
||||
'''
|
||||
return None if _list is None else torch.tensor(
|
||||
_list, dtype=torch.long, device=device)
|
||||
|
||||
|
||||
def maybe_max(_list: Optional[List]) -> Optional[Number]:
|
||||
'''
|
||||
Returns:
|
||||
|
||||
* If _list is not None: max(_list)
|
||||
* None otherwise
|
||||
'''
|
||||
return None if _list is None else max(_list)
|
||||
|
||||
|
||||
def make_causal_mask(
|
||||
q_max_seq_len: int,
|
||||
kv_max_seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Create a q_max_seq_len x kv_max_seq_len causal mask
|
||||
|
||||
Arguments:
|
||||
|
||||
* q_max_seq_len: query max seq len
|
||||
* kv_max_seq_len: key/value max seq len
|
||||
|
||||
Returns:
|
||||
|
||||
* 2D tensor, q_max_seq_len x kv_max_seq_len
|
||||
'''
|
||||
|
||||
# Create a matrix where entry (i, j) is True if i >= j
|
||||
mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1)
|
||||
# Replace True with float('-inf') and False with 0
|
||||
mask = mask.masked_fill(mask == 1,
|
||||
float('-inf')).masked_fill(mask == 0, 0.0)
|
||||
return mask
|
||||
|
||||
|
||||
def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
|
||||
backend_name: str) -> None:
|
||||
'''
|
||||
@ -20,3 +219,724 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
|
||||
* backend_name: attention backend name to force
|
||||
'''
|
||||
mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)
|
||||
|
||||
|
||||
def ref_masked_attention(query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
custom_mask: Optional[torch.Tensor] = None,
|
||||
q_seq_lens: Optional[List] = None,
|
||||
kv_seq_lens: Optional[List] = None) -> torch.Tensor:
|
||||
'''
|
||||
"Golden" masked attention reference. Supports two types of masking:
|
||||
|
||||
* Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
|
||||
padding elements
|
||||
* Custom attention mask, which can force an arbitrary mask tensor, i.e.
|
||||
causal
|
||||
|
||||
Arguments:
|
||||
|
||||
* query: batch_size x q_padded_seq_len x num_heads x head_size
|
||||
* key: batch_size x kv_padded_seq_len x num_heads x head_size
|
||||
* value: batch_size x kv_padded_seq_len x num_heads x head_size
|
||||
* scale: Attention scale factor
|
||||
* custom_mask: custom attention mask; good place to inject a causal
|
||||
attention mask
|
||||
* q_seq_lens: list of unpadded query seq_lens for each batch index
|
||||
* kv_seq_lens: list of unpadded key/value seq_lens for each batch index
|
||||
|
||||
Returns:
|
||||
|
||||
* Attention result, batch_size x q_padded_seq_len x num_heads x head_size
|
||||
'''
|
||||
|
||||
assert q_seq_lens is not None
|
||||
assert kv_seq_lens is not None
|
||||
|
||||
batch_size = query.shape[0]
|
||||
assert (len(q_seq_lens) == batch_size)
|
||||
assert (len(kv_seq_lens) == batch_size)
|
||||
|
||||
attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()
|
||||
|
||||
# Basic attention mask, derived from seq lens
|
||||
if (q_seq_lens is not None) or (kv_seq_lens is not None):
|
||||
attn_mask = torch.zeros_like(attn_weights)
|
||||
if q_seq_lens is not None:
|
||||
for bdx, plen in enumerate(q_seq_lens):
|
||||
attn_mask[bdx, :, plen:, :] = -torch.inf
|
||||
if kv_seq_lens is not None:
|
||||
for bdx, plen in enumerate(kv_seq_lens):
|
||||
attn_mask[bdx, :, :, plen:] = -torch.inf
|
||||
|
||||
attn_weights = attn_weights + attn_mask.float()
|
||||
|
||||
# Custom attention mask
|
||||
if custom_mask is not None:
|
||||
attn_weights = attn_weights + custom_mask.float()
|
||||
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||
out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
|
||||
return out
|
||||
|
||||
|
||||
def make_qkv(
|
||||
batch_size: int,
|
||||
max_q_seq_len: int,
|
||||
max_kv_seq_len: Optional[int],
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
device: Union[torch.device, str],
|
||||
force_kv_seq_lens: Optional[List[int]] = None,
|
||||
attn_type: AttentionType = AttentionType.ENCODER_DECODER,
|
||||
force_max_len: bool = False,
|
||||
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
|
||||
'''
|
||||
Construct QKV test tensors for self- and cross-attention.
|
||||
|
||||
Generates three query/key/value triplets:
|
||||
|
||||
* "Baseline" query/key/value (for input to reference attention function)
|
||||
* "Prefill" query/key/value (last sequence offset zero'd out, for use as
|
||||
input to prefill kernel)
|
||||
* "Decode" query/key/value (only the last sequence offset from baseline,
|
||||
for use as input to decode kernel)
|
||||
|
||||
Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
|
||||
seqlens
|
||||
|
||||
Arguments:
|
||||
|
||||
* batch_size
|
||||
* max_q_seq_len: max query seq len
|
||||
* max_kv_seq_len: max key/value seq len
|
||||
* num_heads
|
||||
* head_size
|
||||
* is_encoder_decoder_attn: if True, query seqlen may differ from
|
||||
key/value seqlen (as is often the case for cross-attention);
|
||||
o/w, query/key/value seqlens match at each batch index
|
||||
(max_kv_seq_len is unused)
|
||||
* force_kv_seq_lens: if not None, overrides kv sequence lengths
|
||||
* attn_type: encoder, decoder self, or enc/dec cross attention
|
||||
* force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
|
||||
seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens
|
||||
and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False
|
||||
* device: CPU or CUDA device
|
||||
|
||||
Returns:
|
||||
|
||||
* Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
|
||||
* Prefill QKVInputs structure (containing all but the last sequence offset)
|
||||
* Decode QKVInputs structure (containing all only the last sequence offset)
|
||||
'''
|
||||
|
||||
if force_max_len:
|
||||
q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
|
||||
else:
|
||||
q_seq_lens = [
|
||||
random.randint(2, max_q_seq_len) for _ in range(batch_size)
|
||||
]
|
||||
kv_seq_lens = None
|
||||
if force_kv_seq_lens is not None:
|
||||
kv_seq_lens = force_kv_seq_lens
|
||||
elif attn_type != AttentionType.ENCODER_DECODER:
|
||||
# K,V seq lens match Q for self-attention
|
||||
kv_seq_lens = q_seq_lens
|
||||
else:
|
||||
# K,V seq lens are distinct from Q seq lens & random
|
||||
assert max_kv_seq_len is not None
|
||||
if force_max_len:
|
||||
kv_seq_lens = [max_kv_seq_len] * batch_size
|
||||
else:
|
||||
kv_seq_lens = [
|
||||
random.randint(2, max_kv_seq_len) for _ in range(batch_size)
|
||||
]
|
||||
|
||||
query = torch.rand(
|
||||
(batch_size, max_q_seq_len, num_heads, head_size)).to(device)
|
||||
key = torch.rand(
|
||||
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
|
||||
value = torch.rand(
|
||||
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
|
||||
|
||||
prefill_query = torch.zeros(
|
||||
(batch_size, max_q_seq_len, num_heads, head_size)).to(device)
|
||||
prefill_key = torch.zeros(
|
||||
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
|
||||
prefill_value = torch.zeros(
|
||||
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
|
||||
|
||||
decode_query = torch.zeros(
|
||||
(batch_size, 1, num_heads, head_size)).to(device)
|
||||
decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device)
|
||||
decode_value = torch.zeros(
|
||||
(batch_size, 1, num_heads, head_size)).to(device)
|
||||
|
||||
for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
|
||||
kv_seq_lens)):
|
||||
query[bdx, q_seq_len:, :, :] = 0
|
||||
key[bdx, kv_seq_len:, :, :] = 0
|
||||
value[bdx, kv_seq_len:, :, :] = 0
|
||||
|
||||
prefill_query[bdx,
|
||||
0:(q_seq_len - 1), :, :] = query[bdx,
|
||||
0:(q_seq_len - 1), :, :]
|
||||
prefill_key[bdx,
|
||||
0:(kv_seq_len - 1), :, :] = key[bdx,
|
||||
0:(kv_seq_len - 1), :, :]
|
||||
prefill_value[bdx, 0:(kv_seq_len -
|
||||
1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :]
|
||||
|
||||
decode_query[bdx, :, :, :] = query[bdx,
|
||||
(q_seq_len - 1):q_seq_len, :, :]
|
||||
decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :]
|
||||
decode_value[bdx, :, :, :] = value[bdx,
|
||||
(kv_seq_len - 1):kv_seq_len, :, :]
|
||||
|
||||
prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
|
||||
prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]
|
||||
|
||||
decode_q_seq_lens = [1 for _ in q_seq_lens]
|
||||
decode_kv_seq_lens = [1 for _ in kv_seq_lens]
|
||||
|
||||
return (
|
||||
QKVInputs(
|
||||
query, # Overall QKV inputs
|
||||
key,
|
||||
value,
|
||||
q_seq_lens,
|
||||
kv_seq_lens),
|
||||
QKVInputs(
|
||||
prefill_query, # Prefill subset of QKV sequences
|
||||
prefill_key,
|
||||
prefill_value,
|
||||
prefill_q_seq_lens,
|
||||
prefill_kv_seq_lens),
|
||||
QKVInputs(
|
||||
decode_query, # Decode subset of KV sequences
|
||||
decode_key,
|
||||
decode_value,
|
||||
decode_q_seq_lens,
|
||||
decode_kv_seq_lens))
|
||||
|
||||
|
||||
def pack_tensor(
|
||||
unpacked_tensor: torch.Tensor, seq_lens: List[int],
|
||||
device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
|
||||
'''
|
||||
Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
|
||||
unpadded number_of_tokens x num_heads x head_size tensor, where
|
||||
number_of_tokens = sum(seq_lens)
|
||||
|
||||
Arguments:
|
||||
|
||||
* unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
|
||||
* seq_lens: list of token counts for each seq
|
||||
* device: CPU or CUDA device
|
||||
|
||||
Returns
|
||||
|
||||
* packed_tensor: number_of_tokens x num_heads x head_size
|
||||
* start_loc_list: start idx of each batch elt in packed_tensor; [0] +
|
||||
list(itertools.accumulate(seq_lens))
|
||||
'''
|
||||
|
||||
num_tok = sum(seq_lens)
|
||||
num_heads = unpacked_tensor.shape[-2]
|
||||
head_size = unpacked_tensor.shape[-1]
|
||||
start_loc_list = [0] + list(itertools.accumulate(seq_lens))
|
||||
packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device)
|
||||
|
||||
for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):
|
||||
|
||||
packed_tensor[start_loc:(
|
||||
start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]
|
||||
|
||||
return packed_tensor, start_loc_list
|
||||
|
||||
|
||||
def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
|
||||
str]) -> PackedQKVInputs:
|
||||
'''
|
||||
Individually pack each of Q, K and V, each with dimensions batch_size x
|
||||
padded_seq_len x num_heads x head_size, into respective number_of_tokens x
|
||||
num_heads x head_size tensors.
|
||||
|
||||
For Q, number_of_tokens = sum(q_seq_lens).
|
||||
|
||||
For K and V, number_of_tokens = sum(kv_seq_lens)
|
||||
|
||||
Arguments:
|
||||
|
||||
* qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
|
||||
attention inputs
|
||||
* device: CPU or CUDA device
|
||||
|
||||
Returns
|
||||
|
||||
* Packed (number_of_tokens x num_heads x head_size) QKV inputs
|
||||
derived from unpacked inputs
|
||||
'''
|
||||
|
||||
if qkv.query is None:
|
||||
packed_query = None
|
||||
q_start_loc_list = None
|
||||
else:
|
||||
packed_query, q_start_loc_list = pack_tensor(qkv.query,
|
||||
qkv.q_seq_lens,
|
||||
device=device)
|
||||
packed_key, kv_start_loc_list = pack_tensor(qkv.key,
|
||||
qkv.kv_seq_lens,
|
||||
device=device)
|
||||
packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)
|
||||
return PackedQKVInputs(
|
||||
packed_query, packed_key, packed_value, q_start_loc_list,
|
||||
kv_start_loc_list,
|
||||
(None if q_start_loc_list is None else qkv.q_seq_lens),
|
||||
qkv.kv_seq_lens)
|
||||
|
||||
|
||||
def make_backend(backend_name: str) -> AttentionBackend:
|
||||
'''
|
||||
Construct the backend instance determined by the backend_name string
|
||||
argument.
|
||||
|
||||
"XFORMERS" -> construct xformers backend
|
||||
|
||||
TODO: other backends
|
||||
|
||||
Note: at time of writing the Attention wrapper automatically selects
|
||||
its own backend for Attention.forward(); so the backend instance which
|
||||
you generate with this function is not meant to be used for *running*
|
||||
inference, but rather for generating compatible metadata structures
|
||||
using backend.make_metadata()
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
* Backend instance
|
||||
'''
|
||||
if backend_name == STR_XFORMERS_ATTN_VAL:
|
||||
return XFormersBackend()
|
||||
raise AssertionError(
|
||||
f"Unrecognized backend_name {backend_name} for unit test")


def _make_metadata_tensors(
    seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
    encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
           torch.Tensor, Optional[int]]:
    '''
    Build scalar & tensor values required to build attention metadata structure.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns:

    * seq_lens_tensor: decoder seq_lens list, as tensor
    * context_lens_tensor: context_lens list, as tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: start idx of each sequence
    * encoder_seq_lens_tensor: encoder_seq_lens list, as tensor
    * max_encoder_seq_len: max(encoder_seq_lens)
    '''
    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = (None if encoder_seq_lens is None else
                           max(encoder_seq_lens))

    seq_start_loc = None

    return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
            seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)
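

# Illustrative sketch (hypothetical helper): for two decoder sequences of
# lengths 2 and 3 and two encoder sequences of length 4, the helper above
# yields max_seq_len == 3, max_encoder_seq_len == 4, and a None seq_start_loc.
def _example_metadata_tensors() -> None:
    (_, _, _, max_seq_len, seq_start_loc, _,
     max_encoder_seq_len) = _make_metadata_tensors([2, 3], None, [4, 4],
                                                   device="cpu")
    assert max_seq_len == 3
    assert max_encoder_seq_len == 4
    assert seq_start_loc is None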


def make_kv_cache(num_blocks: int,
                  num_heads: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  default_val: float = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * default_val: initialization value for KV cache elements

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
    '''

    kv_cache = torch.rand(
        (2, num_blocks, block_size * num_heads * head_size)).to(device)
    if default_val is not None:
        kv_cache[:, :, :] = default_val
    return kv_cache
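

# Illustrative usage sketch (hypothetical helper): the fake cache stacks K
# and V along the leading dimension, with each block flattened to
# block_size * num_heads * head_size elements and filled with default_val.
def _example_make_kv_cache() -> None:
    kv_cache = make_kv_cache(num_blocks=4,
                             num_heads=2,
                             head_size=8,
                             block_size=16,
                             device="cpu")
    assert kv_cache.shape == (2, 4, 16 * 2 * 8)
    assert bool((kv_cache == 0.0).all())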


def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
    '''
    Compute the number of KV cache blocks to provision for num_tokens tokens,
    given block_size. The formula rounds up and may over-provision one extra
    block when num_tokens is a multiple of block_size, which is harmless for
    these tests.
    '''
    return (num_tokens + block_size) // block_size
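

# Worked example (illustrative, hypothetical helper): with block_size=16,
# 17 tokens -> (17 + 16) // 16 = 2 blocks; 16 tokens -> 2 blocks as well.
def _example_min_blocks() -> None:
    assert _num_tokens_to_min_blocks(17, 16) == 2
    assert _num_tokens_to_min_blocks(16, 16) == 2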


def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
    return maybe_make_long_tensor([], device)


def make_empty_block_tables_tensor(device: Union[torch.device, str]):
    return torch.tensor([], device=device)


def split_slot_mapping(slot_mapping_list: List[int], seq_lens: List[int],
                       device: Union[torch.device, str]):
    '''
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:
    * Your goal is to test (1) prefill of N prompts, with prompt-lengths
      {K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
      for all N prompts (N tokens total); the resultant sequence lengths
      after decode would be {K_i + 1 for i \\in [0,N)}
    * The test you want to do requires (1) having the prefill slot mapping
      for all tokens present during prefill, the number of which is
      M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
      decoded tokens

    This function consumes a single 1D slot mapping, which is the
    concatenation of N slot mappings each of length K_i + 1 (corresponding
    to the sequence lengths after decode), with a total length of
    P = \\sum_i{K_i + 1} = M + N

    The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
    from each of the N subsequences in the slot mapping (i.e. omitting the
    decoded token's mapping).

    The N excised entries are appended to obtain the decode-phase slot mapping.

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
      post-decode sequences
    * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
      description above)
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
      all N decoded tokens
    '''

    prefill_slot_mapping = []
    decode_slot_mapping = []

    base_idx = 0
    for seq_len in seq_lens:
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx +
                                                                seq_len - 1)])
        decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1])
        base_idx += seq_len

    return (maybe_make_long_tensor(prefill_slot_mapping, device),
            maybe_make_long_tensor(decode_slot_mapping, device))
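

# Illustrative sketch (hypothetical helper): two post-decode sequences of
# lengths 3 and 2; the last slot of each sequence is excised into the
# decode-phase mapping, leaving the prefill-phase mapping.
def _example_split_slot_mapping() -> None:
    prefill, decode = split_slot_mapping([10, 11, 12, 20, 21],
                                         seq_lens=[3, 2],
                                         device="cpu")
    assert prefill.tolist() == [10, 11, 20]
    assert decode.tolist() == [12, 21]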


def make_block_tables_slot_mapping(
        block_size: int,
        seq_lens: List[int],
        device: Union[torch.device, str],
        block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
    '''
    Construct fake block tables & slot mappings.

    For a sequence with num_tokens tokens, the number of provisioned
    KV cache blocks is

    num_blocks = (num_tokens + block_size) // block_size

    (rounded up, so at most one spare block is provisioned per sequence).
    The total provisioned KV cache size in blocks is then

    total_cache_blocks = sum(num_blocks for all seqs)

    The block-table mapping counts downward from

    block_base_addr + total_cache_blocks

    to

    block_base_addr


    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * device: CPU or CUDA device
    * block_base_addr: the block table base address

    Return:

    * block_tables_tensor: block tables for all sequences
    * slot_mapping_list: flat slot mapping covering all sequences
    * max_block_idx: the highest block address within this block table
    '''

    # Provision minimum number of KV cache blocks
    num_blocks_list = [
        _num_tokens_to_min_blocks(num_tokens, block_size)
        for num_tokens in seq_lens
    ]
    max_block_table_len = max(num_blocks_list)
    block_table_pad_tokens = 10

    block_tables = []
    slot_mapping_list = []
    # Compute uppermost address of block table
    total_cache_blocks = sum(num_blocks_list)
    block_base_idx = block_base_addr + total_cache_blocks
    max_block_idx = block_base_idx
    for sdx, num_tokens in enumerate(seq_lens):
        num_blocks = num_blocks_list[sdx]
        block_table = list(
            range(block_base_idx, block_base_idx - num_blocks, -1))
        for idx in range(num_tokens):
            mapping_value = (
                idx % block_size) + block_table[idx // block_size] * block_size
            slot_mapping_list.append(mapping_value)

        block_base_idx -= num_blocks
        block_tables.append(block_table)

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=max_block_table_len + block_table_pad_tokens,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)
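

# Illustrative sketch (hypothetical helper): one 17-token sequence with
# block_size=16 provisions 2 blocks numbered downward from the base
# address, plus a 17-entry slot mapping.
def _example_block_tables_slot_mapping() -> None:
    block_tables, slot_mapping, max_block_idx = make_block_tables_slot_mapping(
        block_size=16, seq_lens=[17], device="cpu", block_base_addr=0)
    assert max_block_idx == 2
    assert block_tables.shape[0] == 1
    assert len(slot_mapping) == 17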
def make_test_metadata(
|
||||
attn_backend: AttentionBackend,
|
||||
is_prompt: bool,
|
||||
seq_lens: Optional[List[int]],
|
||||
decoder_test_params: Optional[PhaseTestParameters],
|
||||
device: Union[torch.device, str],
|
||||
encoder_test_params: Optional[PhaseTestParameters] = None,
|
||||
cross_test_params: Optional[PhaseTestParameters] = None
|
||||
) -> AttentionMetadata:
|
||||
'''
|
||||
Construct fake attention metadata for a given test phase
|
||||
(prefill-phase or decode-phase).
|
||||
|
||||
encoder_test_params and cross_test_params arguments allow encoder
|
||||
attention and enc/dec cross-attention (respectively) to use distinct
|
||||
metadata values from decoder self-attention (decoder_test_params.)
|
||||
|
||||
if encoder_test_params and cross_test_params are None, the attention
|
||||
metadata will support decoder-only scenario.
|
||||
|
||||
Assumptions:
|
||||
|
||||
* No chunked prefill -> a batch is 100% prefill or 100% decode, never both
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_backend: Backend for sourcing attention kernels
|
||||
* is_prompt: True for the prefill phase, otherwise decode phase
|
||||
* seq_lens: list of token counts for each sequence
|
||||
* decoder_test_params: decoder self-attention test params;
|
||||
this function requires
|
||||
kv_mmap (memory mapping) field
|
||||
* device: CPU or CUDA device
|
||||
* encoder_test_params: encoder attention test params;
|
||||
this function requires encoder query
|
||||
sequence lengths field. If None,
|
||||
encoder query sequence lengths are
|
||||
treated as None
|
||||
* cross_test_params: enc/dec cross-attention test params;
|
||||
this function requires kv_mmap field.
|
||||
If None, KV cache memory map data
|
||||
structures are treated as None
|
||||
|
||||
Return:
|
||||
|
||||
* AttentionMetadata structure
|
||||
'''
|
||||
|
||||
# Decoder self-attention memory mapping
|
||||
# decoder_test_params is None signals encoder-only
|
||||
# scenario, so kv_mmap is None
|
||||
kv_mmap = (None
|
||||
if decoder_test_params is None else decoder_test_params.kv_mmap)
|
||||
|
||||
# This function constructs metadata assuming no chunked prefill,
|
||||
# i.e. 100% prefill tokens or 100% decode tokens
|
||||
#
|
||||
# - If is_prompt, num_prefills_or_decodes is the number of prefills
|
||||
# and num_prefill_or_decode_tokens is the number of prefill tokens
|
||||
# - If not is_prompt, num_prefills_or_decodes is the number of decodes
|
||||
# and num_prefill_or_decode_tokens is the number of decode tokens
|
||||
#
|
||||
# seq_lens is None signals encoder-only
|
||||
# scenario, in which case num_prefills_or_decodes and
|
||||
# num_prefill_or_decode_tokens are unused
|
||||
num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens))
|
||||
|
||||
num_prefill_or_decode_tokens = (None if seq_lens is None else (
|
||||
sum(seq_lens) if is_prompt else len(seq_lens)))
|
||||
|
||||
# context_lens is only needed for prefix-caching scenarios; these tests do
# not exercise prefix caching, so it is left as None
context_lens = None
|
||||
|
||||
if encoder_test_params is None:
|
||||
encoder_seq_lens = None
|
||||
num_encoder_tokens = None
|
||||
else:
|
||||
# Encoder/decoder or encoder-only models only:
|
||||
# * Extract encoder input sequence lengths
|
||||
assert encoder_test_params.packed_qkvo.packed_qkv is not None
|
||||
encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
|
||||
num_encoder_tokens = (None if encoder_seq_lens is None else
|
||||
(sum(encoder_seq_lens)))
|
||||
|
||||
if cross_test_params is None:
|
||||
cross_kv_mmap = None
|
||||
else:
|
||||
# Encoder/decoder or encoder-only models only:
|
||||
# * Extract *cross-attention* slot_mapping and block table
|
||||
# (kv_mmap)
|
||||
cross_kv_mmap = cross_test_params.kv_mmap
|
||||
|
||||
if is_prompt:
|
||||
# Prefill-phase scenario
|
||||
|
||||
num_prefills = num_prefills_or_decodes
|
||||
num_prefill_tokens = num_prefill_or_decode_tokens
|
||||
num_decode_tokens = 0
|
||||
|
||||
(
|
||||
seq_lens_tensor,
|
||||
context_lens_tensor,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len,
|
||||
) = _make_metadata_tensors(seq_lens,
|
||||
context_lens,
|
||||
encoder_seq_lens,
|
||||
device=device)
|
||||
|
||||
return attn_backend.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
|
||||
max_decode_seq_len=0,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
|
||||
use_cuda_graph=False,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=max_encoder_seq_len,
|
||||
cross_slot_mapping=(None if cross_kv_mmap is None else
|
||||
cross_kv_mmap.slot_mapping),
|
||||
cross_block_tables=(None if cross_kv_mmap is None else
|
||||
cross_kv_mmap.block_tables))
|
||||
|
||||
else: # not is_prompt
|
||||
# Decode-phase scenario
|
||||
|
||||
assert kv_mmap is not None
|
||||
assert num_prefill_or_decode_tokens is not None
|
||||
assert seq_lens is not None
|
||||
|
||||
num_prefills = 0
|
||||
num_prefill_tokens = 0
|
||||
num_decode_tokens = num_prefill_or_decode_tokens
|
||||
|
||||
(
|
||||
seq_lens_tensor,
|
||||
context_lens_tensor,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len,
|
||||
) = _make_metadata_tensors(seq_lens,
|
||||
context_lens,
|
||||
encoder_seq_lens,
|
||||
device=device)
|
||||
|
||||
return attn_backend.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=kv_mmap.slot_mapping,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=max(seq_lens),
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=kv_mmap.block_tables,
|
||||
use_cuda_graph=False,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=max_encoder_seq_len,
|
||||
cross_slot_mapping=(None if cross_kv_mmap is None else
|
||||
cross_kv_mmap.slot_mapping),
|
||||
cross_block_tables=(None if cross_kv_mmap is None else
|
||||
cross_kv_mmap.block_tables))
|
||||
|
||||
|
||||
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
|
||||
output_under_test: torch.Tensor) -> None:
|
||||
'''
|
||||
Assert that observed output matches the ideal output
|
||||
contained in the test parameters data structure.
|
||||
|
||||
Arguments:
|
||||
|
||||
* test_params: Test parameters including packed ideal output
|
||||
* output_under_test: actually observed output value
|
||||
'''
|
||||
ideal_output = test_params.packed_qkvo.ideal_output
|
||||
assert torch.allclose(ideal_output,
|
||||
output_under_test.view_as(ideal_output))
|
||||
|
||||
@ -1,11 +1,18 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum, auto
|
||||
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
|
||||
TypeVar)
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AttentionType(Enum):
|
||||
DECODER = auto() # Decoder attention between previous layer Q/K/V
|
||||
ENCODER = auto() # Encoder attention between previous layer Q/K/V
|
||||
ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V
|
||||
|
||||
|
||||
class AttentionBackend(ABC):
|
||||
"""Abstract class for attention backends."""
|
||||
|
||||
@ -128,5 +135,6 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
kv_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.blocksparse_attention.interface import (
|
||||
LocalStridedBlockSparseAttn, get_head_sliding_step)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
@ -328,6 +328,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: BlocksparseFlashAttentionMetadata,
|
||||
kv_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
|
||||
@ -340,6 +341,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"BlocksparseFlashAttentionImpl")
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
@ -7,7 +7,7 @@ from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata, AttentionType)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
@ -257,6 +257,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
kv_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
@ -269,6 +270,12 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl")
|
||||
|
||||
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
||||
assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata, AttentionType)
|
||||
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
@ -224,8 +224,14 @@ class FlashInferImpl(AttentionImpl):
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: FlashInferMetadata,
|
||||
kv_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
assert kv_scale == 1.0
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashInferImpl")
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
|
||||
@ -157,6 +157,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: IpexAttnMetadata, # type: ignore
|
||||
kv_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with IPEX varlen_attention and PagedAttention.
|
||||
|
||||
@ -170,6 +171,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert kv_scale == 1.0
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"IpexAttnBackendImpl")
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
@ -6,7 +6,7 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops.
|
||||
import torch_xla.experimental.dynamo_set_buffer_donor
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata, AttentionType)
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
@ -132,6 +132,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
|
||||
attn_metadata: PallasMetadata,
|
||||
kv_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Pallas attention.
|
||||
|
||||
@ -146,6 +147,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
shape = [batch_size, seq_len, num_heads * head_size]
|
||||
"""
|
||||
assert kv_scale == 1.0
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
batch_size, seq_len, hidden_size = query.shape
|
||||
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
|
||||
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
@ -6,7 +6,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
@ -297,6 +297,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: ROCmFlashAttentionMetadata,
|
||||
kv_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
|
||||
@ -309,6 +310,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"ROCmFlashAttentionImpl")
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
@ -7,7 +7,7 @@ import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
@ -145,6 +145,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: TorchSDPAMetadata, # type: ignore
|
||||
kv_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with torch SDPA and PagedAttention.
|
||||
|
||||
@ -158,6 +159,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert kv_scale == 1.0
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TorchSDPABackendImpl")
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
7
vllm/attention/backends/utils.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""Attention backend utils"""
|
||||
|
||||
# Error string(s) for encoder/decoder
|
||||
# unsupported attention scenarios
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
|
||||
"with encoder/decoder models.")
|
||||
@ -6,10 +6,11 @@ import torch
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import (AttentionBias,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMaskWithTensorBias)
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
@ -66,11 +67,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
@ -79,8 +75,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# FIXME: It is for flash attn.
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
@ -88,26 +85,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
# FIXME: It is for flash attn.
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
|
||||
# FIXME: It is for flash attn.
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# Self-attention prefill/decode metadata cache
|
||||
_cached_prefill_metadata: Optional["XFormersMetadata"] = None
|
||||
_cached_decode_metadata: Optional["XFormersMetadata"] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
@ -115,6 +141,28 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# from xformer API.
|
||||
# will not appear in the __repr__ and __init__
|
||||
self.attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.cross_attn_bias: Optional[List[AttentionBias]] = None
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
'''
|
||||
return ((self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None))
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
'''
|
||||
return (self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None))
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["XFormersMetadata"]:
|
||||
@ -122,30 +170,50 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
# Recover cached prefill-phase attention
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.query_start_loc is not None
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.block_tables is not None
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = XFormersMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||
block_tables=self.block_tables[:self.num_prefills],
|
||||
query_start_loc=query_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
@ -154,29 +222,146 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
# Recover cached decode-phase attention
|
||||
# metadata structure
|
||||
return self._cached_decode_metadata
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
|
||||
# Construct & cache decode-phase attention metadata structure
|
||||
self._cached_decode_metadata = XFormersMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
max_query_len=None,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self.block_tables[self.num_prefills:],
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
def _get_attn_bias(
|
||||
attn_metadata: XFormersMetadata,
|
||||
attn_type: AttentionType,
|
||||
) -> Optional[AttentionBias]:
|
||||
'''
|
||||
Extract appropriate attention bias from attention metadata
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
* Appropriate attention bias value given the attention type
|
||||
'''
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
return attn_metadata.attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
return attn_metadata.encoder_attn_bias
|
||||
else:
|
||||
# attn_type == AttentionType.ENCODER_DECODER
|
||||
return attn_metadata.cross_attn_bias
|
||||
|
||||
|
||||
def _set_attn_bias(
|
||||
attn_metadata: XFormersMetadata,
|
||||
attn_bias: List[Optional[AttentionBias]],
|
||||
attn_type: AttentionType,
|
||||
) -> None:
|
||||
'''
|
||||
Update appropriate attention bias field of attention metadata,
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_bias: The desired attention bias value
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
'''
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
attn_metadata.attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
attn_metadata.encoder_attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
attn_metadata.cross_attn_bias = attn_bias
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
def _get_seq_len_block_table_args(
|
||||
attn_metadata: XFormersMetadata,
|
||||
is_prompt: bool,
|
||||
attn_type: AttentionType,
|
||||
) -> tuple:
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
|
||||
Decoder attn -> select entirely decoder self-attention-related fields
|
||||
Encoder/decoder cross-attn -> select encoder sequence lengths &
|
||||
cross-attn block-tables fields
|
||||
Encoder attn -> select encoder sequence lengths fields & no block tables
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention op
|
||||
* is_prompt: True if prefill, False otherwise
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
|
||||
* Appropriate sequence-lengths tensor
|
||||
* Appropriate max sequence-length scalar
|
||||
* Appropriate block tables (or None)
|
||||
'''
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Decoder self-attention
|
||||
# Choose max_seq_len based on whether we are in prompt_run
|
||||
if is_prompt:
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
else:
|
||||
max_seq_len = attn_metadata.max_decode_seq_len
|
||||
return (attn_metadata.seq_lens_tensor, max_seq_len,
|
||||
attn_metadata.block_tables)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.cross_block_tables)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# No block tables associated with encoder attention
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len, None)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
@ -238,51 +423,144 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key: Optional[torch.Tensor],
|
||||
value: Optional[torch.Tensor],
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: "XFormersMetadata",
|
||||
kv_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with xFormers and PagedAttention.
|
||||
|
||||
For decoder-only models: query, key and value must be non-None.
|
||||
|
||||
For encoder/decoder models:
|
||||
* XFormersImpl.forward() may be invoked for both self- and cross-
|
||||
attention layers.
|
||||
* For self-attention: query, key and value must be non-None.
|
||||
* For cross-attention:
|
||||
* Query must be non-None
|
||||
* During prefill, key and value must be non-None; key and value
|
||||
get cached for use during decode.
|
||||
* During decode, key and value may be None, since:
|
||||
(1) key and value tensors were cached during prefill, and
|
||||
(2) cross-attention key and value tensors do not grow during
|
||||
decode
|
||||
|
||||
A note on how the attn_type (attention type enum) argument impacts
|
||||
attention forward() behavior:
|
||||
|
||||
* DECODER: normal decoder-only behavior;
|
||||
use decoder self-attention block table
|
||||
* ENCODER: no KV caching; pass encoder sequence
|
||||
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
|
||||
max_encoder_seq_len) to kernel, in lieu of decoder
|
||||
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
|
||||
* ENCODER_DECODER: cross-attention behavior;
|
||||
use cross-attention block table for caching KVs derived
|
||||
from encoder hidden states; since KV sequence lengths
|
||||
will match encoder sequence lengths, pass encoder sequence
|
||||
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
|
||||
max_encoder_seq_len)
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
attn_type: Select attention type, between encoder attention,
|
||||
decoder self-attention, or encoder/decoder cross-
|
||||
attention. Defaults to decoder self-attention,
|
||||
which is the vLLM default generally
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if kv_cache is not None:
|
||||
# Check that appropriate attention metadata attributes are
|
||||
# selected for the desired attention type
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
assert value is not None
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
else:
|
||||
assert value is None
|
||||
|
||||
# Self-attention vs. cross-attention will impact
|
||||
# which KV cache memory-mapping & which
|
||||
# seqlen datastructures we utilize
|
||||
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache is not None):
|
||||
# KV-cache during decoder-self- or
|
||||
# encoder-decoder-cross-attention, but not
|
||||
# during encoder attention.
|
||||
#
|
||||
# Even if there are no new key/value pairs to cache,
|
||||
# we still need to break out key_cache and value_cache
|
||||
# i.e. for later use by paged attention
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype, kv_scale)
|
||||
if (key is not None) and (value is not None):
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
if attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Update cross-attention KV cache (prefill-only)
|
||||
# During cross-attention decode, key & value will be None,
|
||||
# preventing this IF-statement branch from running
|
||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
else:
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory
|
||||
# profiling run.
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
kv_scale)
|
||||
|
||||
if attn_type != AttentionType.ENCODER:
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
# Encoder/decoder cross-attention requires no chunked
|
||||
# prefill (100% prefill or 100% decode tokens, no mix)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
else:
|
||||
# Encoder attention - chunked prefill is not applicable;
|
||||
# derive token-count from query shape & and treat them
|
||||
# as 100% prefill tokens
|
||||
assert attn_metadata.num_encoder_tokens is not None
|
||||
num_prefill_tokens = attn_metadata.num_encoder_tokens
|
||||
num_decode_tokens = 0
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Only enforce this shape-constraint for decoder
|
||||
# self-attention
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
if key is not None and value is not None:
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
@ -294,10 +572,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
# block tables are empty if the prompt does not have a cached
|
||||
# prefix.
|
||||
out = self._run_memory_efficient_xformers_forward(
|
||||
query, key, value, prefill_meta)
|
||||
query, key, value, prefill_meta, attn_type=attn_type)
|
||||
assert out.shape == output[:num_prefill_tokens].shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
|
||||
assert prefill_meta.query_start_loc is not None
|
||||
assert prefill_meta.max_query_len is not None
|
||||
|
||||
# prefix-enabled attention
|
||||
# TODO(Hai) this triton kernel has regression issue (broke) to
|
||||
# deal with different data types between KV and FP8 KV cache,
|
||||
@ -320,13 +602,20 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
output[:num_prefill_tokens] = out
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
|
||||
(
|
||||
seq_lens_arg,
|
||||
max_seq_len_arg,
|
||||
block_tables_arg,
|
||||
) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
|
||||
|
||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor,
|
||||
decode_meta.max_decode_seq_len,
|
||||
block_tables_arg,
|
||||
seq_lens_arg,
|
||||
max_seq_len_arg,
|
||||
self.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
@ -343,6 +632,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: XFormersMetadata,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Attention for 1D query of multiple prompts. Multiple prompt
|
||||
tokens are flattened in to `query` input.
|
||||
@ -356,8 +646,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
attn_type: Select attention type, between encoder attention,
|
||||
decoder self-attention, or encoder/decoder cross-
|
||||
attention. Defaults to decoder self-attention,
|
||||
which is the vLLM default generally
|
||||
"""
|
||||
assert attn_metadata.seq_lens is not None
|
||||
|
||||
original_query = query
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
# GQA/MQA requires the shape [B, M, G, H, K].
|
||||
@ -375,18 +669,39 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
# Set attention bias if not provided. This typically happens at
|
||||
# the very attention layer of every iteration.
|
||||
# FIXME(woosuk): This is a hack.
|
||||
if attn_metadata.attn_bias is None:
|
||||
attn_bias = _get_attn_bias(attn_metadata, attn_type)
|
||||
if attn_bias is None:
|
||||
if self.alibi_slopes is None:
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(
|
||||
attn_metadata.seq_lens)
|
||||
if (attn_type == AttentionType.ENCODER_DECODER):
|
||||
assert attn_metadata.seq_lens is not None
|
||||
assert attn_metadata.encoder_seq_lens is not None
|
||||
|
||||
# Default enc/dec cross-attention mask is non-causal
|
||||
attn_bias = BlockDiagonalMask.from_seqlens(
|
||||
attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
assert attn_metadata.encoder_seq_lens is not None
|
||||
|
||||
# Default encoder self-attention mask is non-causal
|
||||
attn_bias = BlockDiagonalMask.from_seqlens(
|
||||
attn_metadata.encoder_seq_lens)
|
||||
else:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
|
||||
# Default decoder self-attention mask is causal
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(
|
||||
attn_metadata.seq_lens)
|
||||
if self.sliding_window is not None:
|
||||
attn_bias = attn_bias.make_local_attention(
|
||||
self.sliding_window)
|
||||
attn_metadata.attn_bias = [attn_bias]
|
||||
attn_bias = [attn_bias]
|
||||
else:
|
||||
attn_metadata.attn_bias = _make_alibi_bias(
|
||||
self.alibi_slopes, self.num_kv_heads, query.dtype,
|
||||
attn_metadata.seq_lens)
|
||||
assert attn_metadata.seq_lens is not None
|
||||
attn_bias = _make_alibi_bias(self.alibi_slopes,
|
||||
self.num_kv_heads, query.dtype,
|
||||
attn_metadata.seq_lens)
|
||||
|
||||
_set_attn_bias(attn_metadata, attn_bias, attn_type)
|
||||
|
||||
# No alibi slopes.
|
||||
# TODO(woosuk): Too many view operations. Let's try to reduce
|
||||
@ -400,7 +715,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_metadata.attn_bias[0],
|
||||
attn_bias=attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale)
|
||||
return out.view_as(original_query)
|
||||
@ -409,6 +724,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
# FIXME(woosuk): Because xformers does not support dynamic sequence
|
||||
# lengths with custom attention bias, we process each prompt one by
|
||||
# one. This is inefficient, especially when we have many short prompts.
|
||||
assert attn_metadata.seq_lens is not None
|
||||
output = torch.empty_like(original_query)
|
||||
start = 0
|
||||
for i, seq_len in enumerate(attn_metadata.seq_lens):
|
||||
@ -417,7 +733,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
query[None, start:end],
|
||||
key[None, start:end],
|
||||
value[None, start:end],
|
||||
attn_bias=attn_metadata.attn_bias[i],
|
||||
attn_bias=attn_bias[i],
|
||||
p=0.0,
|
||||
scale=self.scale)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -90,9 +90,16 @@ class Attention(nn.Module):
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
|
||||
self._kv_scale)
|
||||
|
||||
return self.impl.forward(query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
self._kv_scale,
|
||||
attn_type=attn_type)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"head_size={self.impl.head_size}" # type: ignore