[Kernel] Correctly invoke prefill & decode kernels for cross-attention (towards eventual encoder/decoder model support) (#4888)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent f7a8fa39d8
commit 543aa48573
@@ -47,32 +47,32 @@ def test_flash_attn(monkeypatch):
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
         backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
-        assert backend.name != "FLASH_ATTN"
+        assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported data type
     backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported kv cache data type
     backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported block size
     backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported sliding window
     backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL

     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
         backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
-        assert backend.name != "FLASH_ATTN"
+        assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported head size
     backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL


 def test_invalid_env(monkeypatch):
tests/kernels/test_encoder_decoder_attn.py (new file, 953 lines)
@@ -0,0 +1,953 @@
"""
Tests:

* E2E test of Encoder attention + Decoder self-attention +
  Encoder/decoder cross-attention (collectively
  "encoder/decoder attention")
* Confirm enc/dec models will fail for chunked prefill
* Confirm enc/dec models will fail for prefix caching

"""

from typing import NamedTuple, Optional

import pytest
import torch

from tests.kernels.utils import *
from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.utils import is_hip

HEAD_SIZES = [64, 256]

NUM_HEADS = [1, 16]

BATCH_SIZES = [1, 16]
BLOCK_SIZES = [16]
BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
CUDA_DEVICE = "cuda:0"

MAX_DEC_SEQ_LENS = [128]
MAX_ENC_SEQ_LENS = [128]

# Narrow test-cases for unsupported-scenario
# tests
HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]]


class TestPoint(NamedTuple):
    """
    Encapsulates the attributes which define a single invocation
    of the test_e2e_enc_dec_attn() test

    Attributes:
        num_heads: The number of heads in the model.
        head_size: Head dimension
        backend_name: Name of the backend framework used.
        batch_size: Number of samples per batch.
        block_size: Size of each block of data processed.
        max_dec_seq_len: Maximum sequence length for the decoder.
        max_enc_seq_len: Maximum sequence length for the encoder.
        num_blocks: Number of blocks in the model.
    """

    num_heads: int
    head_size: int
    backend_name: str
    batch_size: int
    block_size: int
    max_dec_seq_len: int
    max_enc_seq_len: int
    num_blocks: int


class TestResources(NamedTuple):
    '''
    Encapsulates key components for performing an
    encoder/decoder attention test

    Note that
    (1) attn automatically selects an attention backend
        based on platform info & a set of canned
        heuristics
    (2) attn_backend is thus *not the same backend
        instance* used by attn, but rather it is
        intended to be a
        *different instance* of the *same backend class*;
        it is assumed that the user of TestResources
        will leverage attn_backend for the purpose of
        constructing backend-compatible attention
        metadata instances

    Attributes:

    * scale: 1/sqrt(d) scale factor for attn
    * attn_backend: implementation of the abstracted
                    attention interface using
                    a particular kernel library,
                    e.g. XFormers
    * attn: Attention layer instance
    * kv_cache: shared key/value cache for all attention
    '''

    scale: float
    attn_backend: AttentionBackend
    attn: Attention
    kv_cache: torch.Tensor


def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
|
||||||
|
'''
|
||||||
|
Build key components for performing encoder/decoder attention test.
|
||||||
|
|
||||||
|
Note that
|
||||||
|
(1) The Attention instance constructed here, automatically selects
|
||||||
|
an attention backend class based on platform info & a set of canned
|
||||||
|
heuristics, so
|
||||||
|
(2) The attention backend instance constructed here is thus *not
|
||||||
|
the same backend instance* used by attn, but rather it is
|
||||||
|
intended to be a *different instance* of the *same backend class*;
|
||||||
|
therefore,
|
||||||
|
(3) This function requires that test_pt.backend_name matches the backend
|
||||||
|
class that Attention will automatically select when it is constructed.
|
||||||
|
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* test_pt: TestPoint data structure; this function relies on the
|
||||||
|
following fields: num_heads, head_size, num_blocks,
|
||||||
|
block_size, backend_name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* TestResources data structure.
|
||||||
|
'''
|
||||||
|
|
||||||
|
scale = float(1.0 / (test_pt.head_size**0.5))
|
||||||
|
attn_backend = make_backend(test_pt.backend_name)
|
||||||
|
attn = Attention(
|
||||||
|
test_pt.num_heads,
|
||||||
|
test_pt.head_size,
|
||||||
|
scale=scale,
|
||||||
|
)
|
||||||
|
if test_pt.num_blocks is None or test_pt.num_heads is None:
|
||||||
|
# Caller does not require a KV cache
|
||||||
|
return TestResources(scale, attn_backend, attn, None)
|
||||||
|
|
||||||
|
# Construct KV cache
|
||||||
|
kv_cache = make_kv_cache(test_pt.num_blocks,
|
||||||
|
test_pt.num_heads,
|
||||||
|
test_pt.head_size,
|
||||||
|
test_pt.block_size,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
return TestResources(scale, attn_backend, attn, kv_cache)
|
||||||
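As an illustrative sketch (hypothetical parameter values; a CUDA device is assumed since the module-level CUDA_DEVICE is used for the KV cache), a TestPoint is materialized into TestResources roughly as follows, typically after forcing the backend with override_backend_env_variable():

    example_pt = TestPoint(num_heads=16,
                           head_size=64,
                           backend_name=STR_XFORMERS_ATTN_VAL,
                           batch_size=16,
                           block_size=16,
                           max_dec_seq_len=128,
                           max_enc_seq_len=128,
                           num_blocks=4096)
    example_rsrcs = _make_test_resources(example_pt)
    assert isinstance(example_rsrcs.attn, Attention)
    assert example_rsrcs.kv_cache is not None  # num_blocks & num_heads were given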
|
|
||||||
|
|
||||||
|
def _encoder_attn_setup(
|
||||||
|
test_pt: TestPoint,
|
||||||
|
test_rsrcs: TestResources,
|
||||||
|
) -> PhaseTestParameters:
|
||||||
|
'''
|
||||||
|
Set up test vectors & data structures for encoder attention test.
|
||||||
|
|
||||||
|
A triplet of synthetic query/key/value tensors are constructed.
|
||||||
|
Given this is an encoder attention test, the key & value
|
||||||
|
sequences will have the same length as the corresponding queries.
|
||||||
|
|
||||||
|
The query/key/value tensors are passed to an ideal reference
|
||||||
|
self-attention implementation to generate an ideal output tensor.
|
||||||
|
|
||||||
|
Encoder inference does not populate the KV cache, therefore
|
||||||
|
no KV cache memory mapping is constructed
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* test_pt: TestPoint data structure; this function relies on the
|
||||||
|
following fields: batch_size, num_heads, head_size,
|
||||||
|
block_size, max_q_seq_len
|
||||||
|
* test_rsrcs: TestResources data structure; this function relies on the
|
||||||
|
scale field
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* PhaseTestParameters data structure comprising (1) packed query/key/value
|
||||||
|
tensors, (2) the ideal output of attention computed using a naive
|
||||||
|
implementation, and (3) KVCache field set to None
|
||||||
|
'''
|
||||||
|
|
||||||
|
(
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
_,
|
||||||
|
batch_size,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
max_q_seq_len,
|
||||||
|
_,
|
||||||
|
) = test_pt
|
||||||
|
|
||||||
|
scale = test_rsrcs.scale
|
||||||
|
|
||||||
|
max_kv_seq_len = max_q_seq_len
|
||||||
|
|
||||||
|
# Make test tensors
|
||||||
|
|
||||||
|
qkv_in, _, _ = make_qkv(batch_size,
|
||||||
|
max_q_seq_len,
|
||||||
|
max_kv_seq_len,
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
attn_type=AttentionType.ENCODER,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
# Compute correct answer using naive non-causal attention
|
||||||
|
# implementation
|
||||||
|
|
||||||
|
ideal_output = ref_masked_attention(qkv_in.query,
|
||||||
|
qkv_in.key,
|
||||||
|
qkv_in.value,
|
||||||
|
scale=scale,
|
||||||
|
q_seq_lens=qkv_in.q_seq_lens,
|
||||||
|
kv_seq_lens=qkv_in.kv_seq_lens)
|
||||||
|
|
||||||
|
packed_ideal_output, _ = pack_tensor(ideal_output,
|
||||||
|
qkv_in.q_seq_lens,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
return PhaseTestParameters(
|
||||||
|
PackedQKVO(packed_qkv, packed_ideal_output),
|
||||||
|
None # No KV cache
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _decoder_attn_setup(
|
||||||
|
test_pt: TestPoint,
|
||||||
|
test_rsrcs: TestResources,
|
||||||
|
block_base_addr: int = 0,
|
||||||
|
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
|
||||||
|
'''
|
||||||
|
Set up test vectors & data structures for self-attention test.
|
||||||
|
|
||||||
|
A triplet of synthetic query/key/value tensors are constructed ("baseline"
|
||||||
|
query/key/value). Given this is a self-attention test, the key & value
|
||||||
|
sequences will have the same length as the corresponding queries.
|
||||||
|
|
||||||
|
"Prefill" query/key/value tensors are derived by masking out the last value
|
||||||
|
in each baseline query/key/value. These tensors are used to test prefill &
|
||||||
|
populate KV cache for a subsequent decode test.
|
||||||
|
|
||||||
|
"Decode" query/key/value tensors are derived by extracting *only* the last
|
||||||
|
value from each baseline query/key/value (i.e. complement of the prefill
|
||||||
|
tensors.) These tensors are used to test decode, conditional on the kv cache
|
||||||
|
being populated during the prefill test.
|
||||||
|
|
||||||
|
The baseline query/key/value tensors are passed to an ideal reference
|
||||||
|
self-attention implementation to generate a "Baseline" ideal output tensor.
|
||||||
|
This tensor is split into the "Prefill" ideal output tensor (all but the
|
||||||
|
last element of each output sequence) and the "Decode" ideal output tensor
|
||||||
|
(*only* the last element of each output sequence); the "Prefill" and
|
||||||
|
"Decode" ideal output tensors can be used to validate the prefill and decode
|
||||||
|
test results, respectively.
|
||||||
|
|
||||||
|
This function also constructs the self-attention KV cache memory mapping
|
||||||
|
(slot mapping and block table), ensuring that the block table starts at
|
||||||
|
block_base_addr
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* test_pt: TestPoint data structure; this function relies on the
|
||||||
|
following fields: batch_size, num_heads, head_size,
|
||||||
|
block_size, max_q_seq_len
|
||||||
|
* test_rsrcs: TestResources data structure; this function relies on the
|
||||||
|
scale field
|
||||||
|
* block_base_addr: decoder self-attention block-table base address
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
* qkv: Unpacked (batch_size x padded_seq_len x num_heads x
|
||||||
|
head_size) query/key/value tensors
|
||||||
|
* Prefill-phase decoder self-attention PhaseTestParameters data structure,
|
||||||
|
including (1) packed (number_of_tokens x num_heads x head_size)
|
||||||
|
query/key/value tensors along with (2) ideal attention output
|
||||||
|
computed using a naive implementation, and (3) memory-mapping data
|
||||||
|
structures appropriate for prefill phase.
|
||||||
|
* Decode-phase decoder self-attention PhaseTestParameters data structure,
|
||||||
|
including (1) packed (number_of_tokens x num_heads x head_size)
|
||||||
|
query/key/value tensors along with (2) ideal attention output
|
||||||
|
computed using a naive implementation, and (3) memory-mapping data
|
||||||
|
structures appropriate for decode phase.
|
||||||
|
* max_block_idx: max physical address in decoder self-attention block-table
|
||||||
|
(intended to be used as the base address for the encoder/
|
||||||
|
decoder cross-attention block-table, which is not
|
||||||
|
constructed in this function)
|
||||||
|
'''
|
||||||
|
|
||||||
|
(
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
_,
|
||||||
|
batch_size,
|
||||||
|
block_size,
|
||||||
|
max_q_seq_len,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
) = test_pt
|
||||||
|
|
||||||
|
scale = test_rsrcs.scale
|
||||||
|
|
||||||
|
max_kv_seq_len = max_q_seq_len
|
||||||
|
|
||||||
|
# Build test tensors
|
||||||
|
|
||||||
|
(
|
||||||
|
qkv,
|
||||||
|
prefill_qkv,
|
||||||
|
decode_qkv,
|
||||||
|
) = make_qkv(batch_size,
|
||||||
|
max_q_seq_len,
|
||||||
|
max_kv_seq_len,
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
attn_type=AttentionType.DECODER,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
# Compute correct answer using naive attention implementation
|
||||||
|
# with causal attention mask
|
||||||
|
|
||||||
|
causal_mask = make_causal_mask(max_q_seq_len,
|
||||||
|
max_kv_seq_len).to(CUDA_DEVICE)
|
||||||
|
|
||||||
|
ideal_output = ref_masked_attention(qkv.query,
|
||||||
|
qkv.key,
|
||||||
|
qkv.value,
|
||||||
|
scale=scale,
|
||||||
|
custom_mask=causal_mask,
|
||||||
|
q_seq_lens=qkv.q_seq_lens,
|
||||||
|
kv_seq_lens=qkv.kv_seq_lens)
|
||||||
|
|
||||||
|
# Split out the prefill- & decode-phase ideal answers & pack them
|
||||||
|
|
||||||
|
prefill_ideal_output = torch.zeros_like(ideal_output)
|
||||||
|
decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
|
||||||
|
for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens):
|
||||||
|
prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
|
||||||
|
bdx, :prefill_q_seq_len]
|
||||||
|
decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
|
||||||
|
prefill_q_seq_len + 1)]
|
||||||
|
|
||||||
|
prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
|
||||||
|
prefill_qkv.q_seq_lens,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
|
||||||
|
[1 for _ in range(batch_size)],
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
# Build prefill- & decode-phase data structures
|
||||||
|
# for decoder self-attention. Block tables and
|
||||||
|
# slot mapping must be in a format compatible
|
||||||
|
# with KV caching & attention kernels
|
||||||
|
#
|
||||||
|
# Prefill-phase:
|
||||||
|
#
|
||||||
|
# * Empty block-tables tensor
|
||||||
|
# * Slot-mapping with entries for prompt tokens
|
||||||
|
#
|
||||||
|
# Decode-phase:
|
||||||
|
# * Block-tables tensor with minimum number of blocks
|
||||||
|
# required by total num. tokens in the entirety of all sequences
|
||||||
|
# (including both prefill & decode)
|
||||||
|
# * Slot-mapping with entries for tokens that will be decoded in the
|
||||||
|
# current decode iteration
|
||||||
|
#
|
||||||
|
# Note: the format described above is simply mirroring what ModelRunner
|
||||||
|
# produces
|
||||||
|
|
||||||
|
prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
(
|
||||||
|
decode_block_tables,
|
||||||
|
slot_mapping_list,
|
||||||
|
max_block_idx,
|
||||||
|
) = make_block_tables_slot_mapping(block_size,
|
||||||
|
qkv.q_seq_lens,
|
||||||
|
device=CUDA_DEVICE,
|
||||||
|
block_base_addr=block_base_addr)
|
||||||
|
|
||||||
|
(
|
||||||
|
prefill_slot_mapping,
|
||||||
|
decode_slot_mapping,
|
||||||
|
) = split_slot_mapping(slot_mapping_list,
|
||||||
|
qkv.q_seq_lens,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
return (
|
||||||
|
qkv,
|
||||||
|
PhaseTestParameters( # Prefill test params
|
||||||
|
PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output),
|
||||||
|
KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
|
||||||
|
PhaseTestParameters( # Decode test params
|
||||||
|
PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output),
|
||||||
|
KVMemoryMap(decode_block_tables, decode_slot_mapping)),
|
||||||
|
max_block_idx)
|
||||||
|
|
||||||
|
|
||||||
|
def _enc_dec_cross_attn_setup_reuses_query(
|
||||||
|
decoder_qkv: QKVInputs,
|
||||||
|
encoder_test_params: PhaseTestParameters,
|
||||||
|
prefill_decoder_phase_test_params: PhaseTestParameters,
|
||||||
|
test_pt: TestPoint,
|
||||||
|
test_rsrcs: TestResources,
|
||||||
|
block_base_addr: int = 0,
|
||||||
|
) -> Tuple[PhaseTestParameters, PhaseTestParameters]:
|
||||||
|
'''
|
||||||
|
Set up test vectors & data structures for cross-attention test.
|
||||||
|
|
||||||
|
A triplet of synthetic cross-attention key/value tensors are constructed
|
||||||
|
("baseline" key/value). Given this is a cross-attention test, we assume
|
||||||
|
query tensors were already synthesized for a prior self-attention test and
|
||||||
|
will be reused for cross-attention. The key & value sequences generated here
|
||||||
|
may have a different length than the corresponding queries (as is often
|
||||||
|
the case for cross-attention between decoder and encoder sequences.)
|
||||||
|
|
||||||
|
Cross attention key & value tensors do not grow during autoregressive
|
||||||
|
inference; thus this function obtains a single key/value pair suitable for
|
||||||
|
both prefill and decode.
|
||||||
|
|
||||||
|
The "baseline" query tensor is received as an argument. The "baseline"
|
||||||
|
query/key/value tensors are passed to an ideal reference cross-attention
|
||||||
|
implementation to generate a "baseline" ideal output tensor. This tensor is
|
||||||
|
split into the "Prefill" ideal output tensor (all but the last element of
|
||||||
|
each output sequence) and the "Decode" ideal output tensor (*only* the last
|
||||||
|
element of each output sequence); the "Prefill" and "Decode" ideal output
|
||||||
|
tensors can be used to validate the prefill and decode test results,
|
||||||
|
respectively.
|
||||||
|
|
||||||
|
This function also constructs the cross-attention KV cache memory mapping
|
||||||
|
(slot mapping and block table), ensuring that the block table starts at
|
||||||
|
block_base_addr.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
|
||||||
|
num_heads x head_size) decoder self-attention inputs;
|
||||||
|
this function relies on the query and q_seq_lens
|
||||||
|
fields
|
||||||
|
* encoder_test_params: PhaseTestParameters data structure which was
|
||||||
|
used for encoder inference; KV cache field
|
||||||
|
is not used by this function
|
||||||
|
* prefill_decoder_phase_test_params: PhaseTestParameters data structure
|
||||||
|
used for prefill-phase decoder
|
||||||
|
self-attention; all fields
|
||||||
|
including KV cache required
|
||||||
|
* test_pt: TestPoint data structure; this function relies on the
|
||||||
|
following fields: batch_size, num_heads, head_size,
|
||||||
|
block_size, max_q_seq_len
|
||||||
|
* test_rsrcs: TestResources data structure; this function relies on the
|
||||||
|
scale field
|
||||||
|
* block_base_addr: decoder self-attention block-table base address
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* Prefill-phase encoder/decoder cross-attention PhaseTestParameters data
|
||||||
|
structure, including (1) packed
|
||||||
|
(number_of_tokens x num_heads x head_size) query/key/value tensors
|
||||||
|
along with (2) ideal attention output computed using a
|
||||||
|
naive implementation, and (3) memory-mapping data structures appropriate
|
||||||
|
for prefill phase.
|
||||||
|
* Decode-phase encoder/decoder cross-attention PhaseTestParameters data
|
||||||
|
structure, including (1) packed
|
||||||
|
(number_of_tokens x num_heads x head_size) query/key/value tensors
|
||||||
|
along with (2) ideal attention output computed using a
|
||||||
|
naive implementation, and (3) memory-mapping data structures appropriate
|
||||||
|
for decode phase.
|
||||||
|
'''
|
||||||
|
|
||||||
|
assert encoder_test_params.packed_qkvo.packed_qkv is not None
|
||||||
|
assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None
|
||||||
|
|
||||||
|
(
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
_,
|
||||||
|
batch_size,
|
||||||
|
block_size,
|
||||||
|
max_decoder_seq_len,
|
||||||
|
max_encoder_seq_len,
|
||||||
|
_,
|
||||||
|
) = test_pt
|
||||||
|
|
||||||
|
scale = test_rsrcs.scale
|
||||||
|
|
||||||
|
decoder_query = decoder_qkv.query
|
||||||
|
decoder_seq_lens = decoder_qkv.q_seq_lens
|
||||||
|
encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
|
||||||
|
prefill_q_seq_lens = (
|
||||||
|
prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens)
|
||||||
|
|
||||||
|
assert prefill_q_seq_lens is not None
|
||||||
|
|
||||||
|
(
|
||||||
|
cross_kv,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
) = make_qkv(batch_size,
|
||||||
|
max_decoder_seq_len,
|
||||||
|
max_encoder_seq_len,
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
force_kv_seq_lens=encoder_seq_lens,
|
||||||
|
attn_type=AttentionType.ENCODER_DECODER,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
ideal_output = ref_masked_attention(decoder_query,
|
||||||
|
cross_kv.key,
|
||||||
|
cross_kv.value,
|
||||||
|
scale=scale,
|
||||||
|
q_seq_lens=decoder_seq_lens,
|
||||||
|
kv_seq_lens=cross_kv.kv_seq_lens)
|
||||||
|
|
||||||
|
prefill_ideal_output = torch.zeros_like(ideal_output)
|
||||||
|
decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
|
||||||
|
for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens):
|
||||||
|
prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
|
||||||
|
bdx, :prefill_q_seq_len]
|
||||||
|
decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
|
||||||
|
prefill_q_seq_len + 1)]
|
||||||
|
|
||||||
|
prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
|
||||||
|
prefill_q_seq_lens,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
|
||||||
|
[1 for _ in range(batch_size)],
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
# Build prefill- & decode-phase data structures
|
||||||
|
# for encoder/decoder cross-attention. Block tables and
|
||||||
|
# slot mapping must be in a format compatible
|
||||||
|
# with KV caching & attention kernels
|
||||||
|
#
|
||||||
|
# Whereas decoder self-attention extracts relationships between
|
||||||
|
# equal-length Q/K/V sequences, which mutually grow in length
|
||||||
|
# with each decoded token, cross-attention relates the Q sequence
|
||||||
|
# - which grows with each new decoded token - to fixed-length
|
||||||
|
# K and V sequences derived from the encoder hidden states.
|
||||||
|
#
|
||||||
|
# Prefill-phase:
|
||||||
|
#
|
||||||
|
# * Empty block-tables tensor
|
||||||
|
# * Slot-mapping with as many entries as there are tokens in the encoder
|
||||||
|
# prompt.
|
||||||
|
#
|
||||||
|
# Decode-phase:
|
||||||
|
# * Block-tables tensor with minimum number of blocks to
#   accommodate K & V tensors which are equal in length
#   to the encoder prompt length
|
||||||
|
# * Empty slot-mapping tensor (since K & V are fixed in size,
|
||||||
|
# new decoded tokens are not KV-cached and require no slot-
|
||||||
|
# mapping)
|
||||||
|
#
|
||||||
|
# Note: the format above is simply an extension of what ModelRunner
|
||||||
|
# produces for decoder-only models
|
||||||
|
|
||||||
|
prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
|
||||||
|
decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
(
|
||||||
|
decode_block_tables,
|
||||||
|
prefill_slot_mapping_list,
|
||||||
|
_,
|
||||||
|
) = make_block_tables_slot_mapping(block_size,
|
||||||
|
cross_kv.kv_seq_lens,
|
||||||
|
block_base_addr=block_base_addr,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
# Packed key/value (query is already provided)
|
||||||
|
packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
return (
|
||||||
|
PhaseTestParameters( # Prefill-phase test params
|
||||||
|
PackedQKVO(packed_cross_kv, prefill_packed_ideal_output),
|
||||||
|
KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
|
||||||
|
PhaseTestParameters( # Decode-phase test params
|
||||||
|
PackedQKVO(None, decode_packed_ideal_output),
|
||||||
|
KVMemoryMap(decode_block_tables, decode_slot_mapping)))
|
||||||
|
|
||||||
|
|
||||||
|
def _run_encoder_attention_test(
|
||||||
|
attn: Attention,
|
||||||
|
encoder_test_params: PhaseTestParameters,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
'''
|
||||||
|
Run encoder attention.
|
||||||
|
|
||||||
|
attn.forward() is passed attn_type=AttentionType.ENCODER in order
|
||||||
|
to configure the kernel invocation for encoder attention
|
||||||
|
|
||||||
|
Requires attn_metadata.num_decode_tokens == 0
|
||||||
|
(There is no encoder execution in the decode-phase)
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* attn: Attention wrapper instance
|
||||||
|
* encoder_test_params: encoder PhaseTestParameters data structure;
|
||||||
|
this function relies on the packed
|
||||||
|
(number_of_tokens x num_heads x head_size)
|
||||||
|
query/key/value fields
|
||||||
|
* attn_metadata: attention metadata for encoder/decoder-self attention
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
* Attention.forward() applied to packed {query,key,value} and
  attn_metadata
|
||||||
|
'''
|
||||||
|
assert attn_metadata.num_decode_tokens == 0
|
||||||
|
attn_type = AttentionType.ENCODER
|
||||||
|
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
|
||||||
|
assert packed_qkv is not None
|
||||||
|
return attn.forward(packed_qkv.query,
|
||||||
|
packed_qkv.key,
|
||||||
|
packed_qkv.value,
|
||||||
|
None,
|
||||||
|
attn_metadata,
|
||||||
|
attn_type=attn_type)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_decoder_self_attention_test(
|
||||||
|
test_rsrcs: TestResources,
|
||||||
|
decoder_test_params: PhaseTestParameters,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
'''
|
||||||
|
Run decoder self-attention test.
|
||||||
|
|
||||||
|
attn.forward() is passed attn_type=AttentionType.DECODER
|
||||||
|
in order to configure the kernel invocation for decoder self-attention.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* test_rsrcs: TestResources instance; this function relies on the kv_cache
|
||||||
|
and attn (Attention wrapper instance) fields
|
||||||
|
* decoder_test_params: decoder PhaseTestParameters data structure;
|
||||||
|
this function relies on the packed
|
||||||
|
(number_of_tokens x num_heads x head_size)
|
||||||
|
query/key/value fields
|
||||||
|
* attn_metadata: attention metadata for decoder-self attention
|
||||||
|
(contains KV cache memory-mapping)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
* Attention.forward() applied to packed_{query,key,value}, kv_cache
|
||||||
|
& attn_metadata
|
||||||
|
'''
|
||||||
|
attn_type = AttentionType.DECODER
|
||||||
|
attn = test_rsrcs.attn
|
||||||
|
kv_cache = test_rsrcs.kv_cache
|
||||||
|
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
|
||||||
|
assert packed_qkv is not None
|
||||||
|
return attn.forward(packed_qkv.query,
|
||||||
|
packed_qkv.key,
|
||||||
|
packed_qkv.value,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
attn_type=attn_type)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_encoder_decoder_cross_attention_test(
|
||||||
|
test_rsrcs: TestResources,
|
||||||
|
decoder_test_params: PhaseTestParameters,
|
||||||
|
cross_test_params: Optional[PhaseTestParameters],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
'''
|
||||||
|
Run encoder/decoder cross-attention test.
|
||||||
|
|
||||||
|
Via PhaseTestParameters data structures, consumes the same query utilized
|
||||||
|
for decoder self-attention, plus a key/value specific to cross-attention.
|
||||||
|
|
||||||
|
If cross_test_params is None, or cross_test_params.packed_qkvo.packed_qkv
is None, this reflects that in decode-phase cross-attention there
is no growth in the key and value tensors.
|
||||||
|
|
||||||
|
attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER
|
||||||
|
in order to configure the kernel invocation for encoder/decoder cross-
|
||||||
|
attention.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* test_rsrcs: TestResources instance; this function relies on the kv_cache
|
||||||
|
and attn (Attention wrapper instance) fields
|
||||||
|
* decoder_test_params: decoder PhaseTestParameters data structure;
|
||||||
|
this function relies on the packed
|
||||||
|
(number_of_tokens x num_heads x head_size)
|
||||||
|
query field
|
||||||
|
* cross_test_params: encoder/decoder PhaseTestParameters data structure;
|
||||||
|
this function relies on the packed
|
||||||
|
(number_of_tokens x num_heads x head_size)
|
||||||
|
key/value fields
|
||||||
|
* attn_metadata: attention metadata for encoder/decoder cross-attention
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
* Attention.forward() applied to packed_{query,key,value}, kv_cache
|
||||||
|
& attn_metadata
|
||||||
|
'''
|
||||||
|
assert decoder_test_params.packed_qkvo.packed_qkv is not None
|
||||||
|
|
||||||
|
attn_type = AttentionType.ENCODER_DECODER
|
||||||
|
attn = test_rsrcs.attn
|
||||||
|
kv_cache = test_rsrcs.kv_cache
|
||||||
|
if cross_test_params is None:
|
||||||
|
key = None
|
||||||
|
value = None
|
||||||
|
else:
|
||||||
|
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
|
||||||
|
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
|
||||||
|
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
|
||||||
|
return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
attn_type=attn_type)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
|
||||||
|
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||||
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
|
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
|
||||||
|
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
|
||||||
|
def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
|
||||||
|
batch_size: int, block_size: int, max_dec_seq_len: int,
|
||||||
|
max_enc_seq_len: int, monkeypatch):
|
||||||
|
|
||||||
|
# Force Attention wrapper backend
|
||||||
|
override_backend_env_variable(monkeypatch, backend_name)
|
||||||
|
|
||||||
|
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||||
|
# to be more than necessary, since exceeding the kv cache size
|
||||||
|
# is not part of this test
|
||||||
|
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
|
||||||
|
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
|
||||||
|
|
||||||
|
# Attention scale factor, attention backend instance, attention wrapper
|
||||||
|
# instance, KV cache init
|
||||||
|
test_rsrcs = _make_test_resources(test_pt)
|
||||||
|
|
||||||
|
# Construct encoder attention test params (only used
|
||||||
|
# during prefill)
|
||||||
|
|
||||||
|
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||||
|
|
||||||
|
# Shared prefill metadata structure
|
||||||
|
|
||||||
|
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||||
|
test_rsrcs.attn_backend,
|
||||||
|
True,
|
||||||
|
None,
|
||||||
|
decoder_test_params=None,
|
||||||
|
encoder_test_params=enc_test_params,
|
||||||
|
cross_test_params=None,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
# PREFILL: encoder attention
|
||||||
|
|
||||||
|
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
|
||||||
|
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
|
||||||
|
|
||||||
|
# - Is encoder attention result correct?
|
||||||
|
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
|
||||||
|
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||||
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
|
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
|
||||||
|
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
|
||||||
|
def test_e2e_enc_dec_attn(
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
backend_name: str,
|
||||||
|
batch_size: int,
|
||||||
|
block_size: int,
|
||||||
|
max_dec_seq_len: int,
|
||||||
|
max_enc_seq_len: int,
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
'''
|
||||||
|
End-to-end encoder/decoder test:
|
||||||
|
|
||||||
|
* Construct fake test vectors for (1) encoder attention,
|
||||||
|
(2) decoder self-attention, and (3) encoder/decoder cross-attention
|
||||||
|
* Construct (1) attention metadata structure with self- and cross-attention
|
||||||
|
attributes for prefill-phase, and (2) an analogous attention metadata
|
||||||
|
structure but for decode-phase
|
||||||
|
* Test attention steps in the following order
|
||||||
|
|
||||||
|
* Encoder attention
|
||||||
|
* Prefill self-attention
|
||||||
|
* Prefill cross-attention
|
||||||
|
* Decode self-attention
|
||||||
|
* Decode cross-attention
|
||||||
|
* Besides being reflective of realistic use-cases, this order would
|
||||||
|
exacerbate any accidental overlap in the self-/cross-attention
|
||||||
|
block tables, which one hopes to avoid
|
||||||
|
|
||||||
|
|
||||||
|
* Validate output correctness against ideal reference attention
|
||||||
|
implementation
|
||||||
|
|
||||||
|
Block tables are constructed such that cross-attention KV cache is in a
|
||||||
|
higher, non-intersecting address-space than self-attention KV cache.
|
||||||
|
|
||||||
|
Self- and cross-attention share the same query tensor but not the K/V
|
||||||
|
tensors. Self-attention K/Vs must have the same seq len as Q while
|
||||||
|
cross-attention K/Vs are allowed to differ in seq len, as is often the case
|
||||||
|
for cross-attention.
|
||||||
|
|
||||||
|
This test utilizes PyTest monkey patching to force the attention backend
|
||||||
|
via an environment variable.
|
||||||
|
|
||||||
|
Note on ROCm/HIP: currently encoder/decoder models are not supported on
|
||||||
|
AMD GPUs, therefore this test simply is skipped if is_hip().
|
||||||
|
|
||||||
|
Note on metadata: there is a single attention metadata structure shared by
|
||||||
|
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
|
||||||
|
and a single one shared by all decode-phase attention operations
|
||||||
|
(decoder & enc/dec cross.) This is intended to reflect the behavior
|
||||||
|
of ModelRunner, which constructs a single attention metadata structure for
|
||||||
|
each prefill or decode run. A realistic scenario would rely on the
|
||||||
|
attention backend to utilize the appropriate attention metadata fields
|
||||||
|
according to the value of attn_metadata.attention_type. Thus, this test is
|
||||||
|
organized so as to confirm that the backend-under-test can handle a
|
||||||
|
shared prefill attention metadata structure & a shared decode attention
|
||||||
|
metadata structure.
|
||||||
|
'''
|
||||||
|
|
||||||
|
# Force Attention wrapper backend
|
||||||
|
override_backend_env_variable(monkeypatch, backend_name)
|
||||||
|
|
||||||
|
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||||
|
# to be more than necessary, since exceeding the kv cache size
|
||||||
|
# is not part of this test
|
||||||
|
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
|
||||||
|
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
|
||||||
|
|
||||||
|
# Attention scale factor, attention backend instance, attention wrapper
|
||||||
|
# instance, KV cache init
|
||||||
|
test_rsrcs = _make_test_resources(test_pt)
|
||||||
|
|
||||||
|
# Construct encoder attention test params (only used
|
||||||
|
# during prefill)
|
||||||
|
|
||||||
|
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||||
|
|
||||||
|
# Construct Decoder self-attention prefill-phase & decode-phase
|
||||||
|
# test params, including query/key/value tensors, decoder self-attention
|
||||||
|
# memory-mapping. cross_block_base_addr is the uppermost address in the
|
||||||
|
# decoder self-attention block-table, i.e. a base address which the
|
||||||
|
# encoder/decoder cross-attention block-table may build downward toward.
|
||||||
|
|
||||||
|
(
|
||||||
|
dec_qkv,
|
||||||
|
prephase_dec_test_params,
|
||||||
|
decphase_dec_test_params,
|
||||||
|
cross_block_base_addr,
|
||||||
|
) = _decoder_attn_setup(test_pt, test_rsrcs)
|
||||||
|
|
||||||
|
# Construct encoder/decoder cross-attention prefill-phase & decode-phase
|
||||||
|
# test params, including key/value tensors, cross-attention memory-mapping
|
||||||
|
|
||||||
|
(
|
||||||
|
prephase_cross_test_params,
|
||||||
|
decphase_cross_test_params,
|
||||||
|
) = _enc_dec_cross_attn_setup_reuses_query(
|
||||||
|
dec_qkv,
|
||||||
|
enc_test_params,
|
||||||
|
prephase_dec_test_params,
|
||||||
|
test_pt,
|
||||||
|
test_rsrcs,
|
||||||
|
block_base_addr=cross_block_base_addr)
|
||||||
|
|
||||||
|
# Shared prefill metadata structure
|
||||||
|
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
|
||||||
|
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||||
|
test_rsrcs.attn_backend,
|
||||||
|
True,
|
||||||
|
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
|
||||||
|
decoder_test_params=prephase_dec_test_params,
|
||||||
|
encoder_test_params=enc_test_params,
|
||||||
|
cross_test_params=prephase_cross_test_params,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
# PREFILL: encoder attention
|
||||||
|
|
||||||
|
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||||
|
enc_test_params,
|
||||||
|
prephase_attn_metadata)
|
||||||
|
|
||||||
|
# - Is encoder attention result correct?
|
||||||
|
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||||
|
|
||||||
|
# PREFILL: decoder self-attention test
|
||||||
|
|
||||||
|
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||||
|
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
|
||||||
|
|
||||||
|
# - Is prefill decoder self-attention correct?
|
||||||
|
assert_actual_matches_ideal(prephase_dec_test_params,
|
||||||
|
prephase_dec_pckd_act_out)
|
||||||
|
|
||||||
|
# PREFILL: encoder/decoder cross-attention test
|
||||||
|
|
||||||
|
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||||
|
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
|
||||||
|
prephase_attn_metadata)
|
||||||
|
|
||||||
|
# - Is prefill encoder/decoder cross-attention correct?
|
||||||
|
assert_actual_matches_ideal(prephase_cross_test_params,
|
||||||
|
prephase_cross_pckd_act_out)
|
||||||
|
|
||||||
|
# DECODE: build decode-phase attention metadata
|
||||||
|
|
||||||
|
decphase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||||
|
test_rsrcs.attn_backend,
|
||||||
|
False,
|
||||||
|
dec_qkv.q_seq_lens,
|
||||||
|
decoder_test_params=decphase_dec_test_params,
|
||||||
|
encoder_test_params=enc_test_params,
|
||||||
|
cross_test_params=decphase_cross_test_params,
|
||||||
|
device=CUDA_DEVICE)
|
||||||
|
|
||||||
|
# DECODE: decoder self-attention test
|
||||||
|
|
||||||
|
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||||
|
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
|
||||||
|
|
||||||
|
# - Is decode-phase decoder self-attention correct?
|
||||||
|
assert_actual_matches_ideal(decphase_dec_test_params,
|
||||||
|
decphase_dec_pckd_act_out)
|
||||||
|
|
||||||
|
# DECODE: encoder/decoder cross-attention test
|
||||||
|
|
||||||
|
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||||
|
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
|
||||||
|
|
||||||
|
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||||
|
assert_actual_matches_ideal(decphase_cross_test_params,
|
||||||
|
decphase_cross_pckd_act_out)
|
||||||
tests/kernels/utils.py
@@ -1,12 +1,211 @@
 """Kernel test utils"""

-import pytest
+import itertools
+import random
+from numbers import Number
+from typing import Any, List, NamedTuple, Optional, Tuple, Union
+
+import pytest
+import torch
+
+from vllm.attention.backends.abstract import (AttentionBackend,
+                                              AttentionMetadata, AttentionType)
+from vllm.attention.backends.xformers import XFormersBackend
+from vllm.utils import make_tensor_with_pad
+
+# String name of register which may be set in order to
+# force auto-selection of attention backend by Attention
+# wrapper
 STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
+
+# Possible string values of STR_BACKEND_ENV_VAR
+# register, corresponding to possible backends
+STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
+STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
+STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
+STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
 STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
 STR_INVALID_VAL: str = "INVALID"
|
|
||||||
|
|
||||||
|
class QKVInputs(NamedTuple):
|
||||||
|
'''
|
||||||
|
Data structure for representing unpacked attention inputs,
|
||||||
|
query/key/values and their sequence lengths.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
|
||||||
|
* {query,key,value}: unpacked (batch_size x padded_seq_len x
|
||||||
|
num_heads x head_size) attention inputs
|
||||||
|
* q_seq_lens: query sequence lengths list
|
||||||
|
* kv_seq_lens: shared key/value sequence lengths list
|
||||||
|
'''
|
||||||
|
|
||||||
|
query: torch.Tensor
|
||||||
|
key: torch.Tensor
|
||||||
|
value: torch.Tensor
|
||||||
|
q_seq_lens: List[int]
|
||||||
|
kv_seq_lens: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
class QKVO(NamedTuple):
|
||||||
|
'''
|
||||||
|
Data structure for representing unpacked attention inputs,
|
||||||
|
alongside unpacked known-correct attention output
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
|
||||||
|
* qkv: unpacked (batch_size x padded_seq_len x
|
||||||
|
num_heads x head_size) attention inputs
|
||||||
|
* ideal_output: unpacked (batch_size x padded_seq_len x
|
||||||
|
num_heads x head_size) known-correct attention output
|
||||||
|
'''
|
||||||
|
|
||||||
|
qkv: QKVInputs
|
||||||
|
ideal_output: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class PackedQKVInputs(NamedTuple):
|
||||||
|
'''
|
||||||
|
Data structure for representing packed attention inputs
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
|
||||||
|
* {query,key,value}: packed (number_of_tokens x num_heads
|
||||||
|
x head_size) attention inputs
|
||||||
|
* q_start_loc_list: list of query start locations within packed tensor
|
||||||
|
* kv_start_loc_list: shared list of key/value start locations within
|
||||||
|
packed tensor
|
||||||
|
* q_seq_lens: query sequence lengths list
|
||||||
|
* kv_seq_lens: shared key/value sequence lengths list
|
||||||
|
'''
|
||||||
|
|
||||||
|
query: torch.Tensor
|
||||||
|
key: torch.Tensor
|
||||||
|
value: torch.Tensor
|
||||||
|
q_start_loc_list: Optional[List[int]]
|
||||||
|
kv_start_loc_list: Optional[List[int]]
|
||||||
|
q_seq_lens: Optional[List[int]]
|
||||||
|
kv_seq_lens: Optional[List[int]]
|
||||||
|
|
||||||
|
|
||||||
|
class PackedQKVO(NamedTuple):
|
||||||
|
'''
|
||||||
|
Data structure for representing packed attention inputs,
|
||||||
|
alongside packed known-correct attention output
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
|
||||||
|
* packed_qkv: packed (number_of_tokens x num_heads
|
||||||
|
x head_size) attention inputs
|
||||||
|
* ideal_output: packed (number_of_tokens x num_heads
|
||||||
|
x head_size) known-correct attention output
|
||||||
|
'''
|
||||||
|
|
||||||
|
packed_qkv: Optional[PackedQKVInputs]
|
||||||
|
ideal_output: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class KVMemoryMap(NamedTuple):
|
||||||
|
'''
|
||||||
|
Data structure for encapsulating KV cache memory mapping.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
|
||||||
|
* block_tables: KV cache block tables
|
||||||
|
* slot_mapping: mapping of sequence offset to physical address
|
||||||
|
'''
|
||||||
|
|
||||||
|
block_tables: torch.Tensor
|
||||||
|
slot_mapping: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class PhaseTestParameters(NamedTuple):
|
||||||
|
'''
|
||||||
|
Data structure for encapsulating the test parameters
|
||||||
|
for a given test "phase" (prefill or decode phase) and attention
|
||||||
|
scenario (encoder, decoder-self, encoder/decoder-cross)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
|
||||||
|
* packed_qkvo: packed (number_of_tokens x num_heads
|
||||||
|
x head_size) attention inputs & known-correct
|
||||||
|
output
|
||||||
|
* kv_mmap: KV cache memory mapping, specific to this test phase &
|
||||||
|
attention scenario
|
||||||
|
'''
|
||||||
|
|
||||||
|
packed_qkvo: PackedQKVO
|
||||||
|
kv_mmap: Optional[KVMemoryMap]
|
||||||
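For orientation, a minimal sketch with dummy shapes (hypothetical values, purely to show how these containers nest when describing one test phase):

    # 8 packed tokens across 2 sequences, 2 heads, head size 4
    dummy_qkv = PackedQKVInputs(query=torch.rand(8, 2, 4),
                                key=torch.rand(8, 2, 4),
                                value=torch.rand(8, 2, 4),
                                q_start_loc_list=[0, 5],
                                kv_start_loc_list=[0, 5],
                                q_seq_lens=[5, 3],
                                kv_seq_lens=[5, 3])
    dummy_phase = PhaseTestParameters(
        packed_qkvo=PackedQKVO(packed_qkv=dummy_qkv,
                               ideal_output=torch.rand(8, 2, 4)),
        kv_mmap=None)  # e.g. encoder attention, which never touches the KV cache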
|
|
||||||
|
|
||||||
|
def maybe_make_int_tensor(
|
||||||
|
_list: Optional[List[int]],
|
||||||
|
device: Union[torch.device, str],
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
'''
|
||||||
|
Convert Python int list to a 1D int torch.Tensor on `device`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* If _list is not None: 1D int torch.Tensor on `device`
|
||||||
|
* None otherwise
|
||||||
|
'''
|
||||||
|
return None if _list is None else torch.tensor(
|
||||||
|
_list, dtype=torch.int, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_make_long_tensor(
|
||||||
|
_list: Optional[List[int]],
|
||||||
|
device: Union[torch.device, str],
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
'''
|
||||||
|
Convert Python int list to a 1D long torch.Tensor on `device`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* If _list is not None: 1D long torch.Tensor on `device`
|
||||||
|
* None otherwise
|
||||||
|
'''
|
||||||
|
return None if _list is None else torch.tensor(
|
||||||
|
_list, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_max(_list: Optional[List]) -> Optional[Number]:
|
||||||
|
'''
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* If _list is not None: max(_list)
|
||||||
|
* None otherwise
|
||||||
|
'''
|
||||||
|
return None if _list is None else max(_list)
|
||||||
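A few illustrative calls (arbitrary values) showing the None-passthrough behaviour of these helpers:

    assert maybe_make_int_tensor(None, "cpu") is None
    assert maybe_make_int_tensor([1, 2, 3], "cpu").dtype == torch.int
    assert maybe_make_long_tensor([0, 16, 32], "cpu").dtype == torch.long
    assert maybe_max(None) is None
    assert maybe_max([5, 9, 2]) == 9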
|
|
||||||
|
|
||||||
|
def make_causal_mask(
|
||||||
|
q_max_seq_len: int,
|
||||||
|
kv_max_seq_len: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
'''
|
||||||
|
Create a q_max_seq_len x kv_max_seq_len causal mask
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* q_max_seq_len: query max seq len
|
||||||
|
* kv_max_seq_len: key/value max seq len
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* 2D tensor, q_max_seq_len x kv_max_seq_len
|
||||||
|
'''
|
||||||
|
|
||||||
|
# Create a matrix with 1 at entry (i, j) whenever j > i (future positions)
|
||||||
|
mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1)
|
||||||
|
# Replace True with float('-inf') and False with 0
|
||||||
|
mask = mask.masked_fill(mask == 1,
|
||||||
|
float('-inf')).masked_fill(mask == 0, 0.0)
|
||||||
|
return mask
|
||||||
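For a 3 x 3 case the resulting additive mask is, as a small worked example:

    # make_causal_mask(3, 3) ->
    # tensor([[0., -inf, -inf],
    #         [0.,   0., -inf],
    #         [0.,   0.,   0.]])
    # i.e. once added to the attention scores, query position i can only
    # attend to key positions j <= i.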
|
|
||||||
|
|
||||||
def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    '''
@@ -20,3 +219,724 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
    * backend_name: attention backend name to force
    '''
    mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)
|
||||||
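A typical (hypothetical) use inside a test, forcing the XFormers backend before any Attention layer is constructed:

    def test_with_forced_backend(monkeypatch):
        override_backend_env_variable(monkeypatch, STR_XFORMERS_ATTN_VAL)
        # Attention layers constructed after this point auto-select the
        # XFORMERS backend, since VLLM_ATTENTION_BACKEND is now set.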
|
|
||||||
|
|
||||||
|
def ref_masked_attention(query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
custom_mask: Optional[torch.Tensor] = None,
|
||||||
|
q_seq_lens: Optional[List] = None,
|
||||||
|
kv_seq_lens: Optional[List] = None) -> torch.Tensor:
|
||||||
|
'''
|
||||||
|
"Golden" masked attention reference. Supports two types of masking:
|
||||||
|
|
||||||
|
* Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
|
||||||
|
padding elements
|
||||||
|
* Custom attention mask, which can force an arbitrary mask tensor, i.e.
|
||||||
|
causal
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* query: batch_size x q_padded_seq_len x num_heads x head_size
|
||||||
|
* key: batch_size x kv_padded_seq_len x num_heads x head_size
|
||||||
|
* value: batch_size x kv_padded_seq_len x num_heads x head_size
|
||||||
|
* scale: Attention scale factor
|
||||||
|
* custom_mask: custom attention mask; good place to inject a causal
|
||||||
|
attention mask
|
||||||
|
* q_seq_lens: list of unpadded query seq_lens for each batch index
|
||||||
|
* kv_seq_lens: list of unpadded key/value seq_lens for each batch index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* Attention result, batch_size x q_padded_seq_len x num_heads x head_size
|
||||||
|
'''
|
||||||
|
|
||||||
|
assert q_seq_lens is not None
|
||||||
|
assert kv_seq_lens is not None
|
||||||
|
|
||||||
|
batch_size = query.shape[0]
|
||||||
|
assert (len(q_seq_lens) == batch_size)
|
||||||
|
assert (len(kv_seq_lens) == batch_size)
|
||||||
|
|
||||||
|
attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()
|
||||||
|
|
||||||
|
# Basic attention mask, derived from seq lens
|
||||||
|
if (q_seq_lens is not None) or (kv_seq_lens is not None):
|
||||||
|
attn_mask = torch.zeros_like(attn_weights)
|
||||||
|
if q_seq_lens is not None:
|
||||||
|
for bdx, plen in enumerate(q_seq_lens):
|
||||||
|
attn_mask[bdx, :, plen:, :] = -torch.inf
|
||||||
|
if kv_seq_lens is not None:
|
||||||
|
for bdx, plen in enumerate(kv_seq_lens):
|
||||||
|
attn_mask[bdx, :, :, plen:] = -torch.inf
|
||||||
|
|
||||||
|
attn_weights = attn_weights + attn_mask.float()
|
||||||
|
|
||||||
|
# Custom attention mask
|
||||||
|
if custom_mask is not None:
|
||||||
|
attn_weights = attn_weights + custom_mask.float()
|
||||||
|
|
||||||
|
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||||
|
out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
|
||||||
|
return out
|
||||||
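As a sanity check, for full-length (unpadded) sequences and no custom mask this reference should agree with torch's built-in scaled dot-product attention; a minimal sketch, assuming PyTorch >= 2.1 for the scale argument:

    batch_size, seq_len, num_heads, head_size = 2, 4, 2, 8
    q = torch.rand(batch_size, seq_len, num_heads, head_size)
    k = torch.rand(batch_size, seq_len, num_heads, head_size)
    v = torch.rand(batch_size, seq_len, num_heads, head_size)
    scale = head_size**-0.5
    ref = ref_masked_attention(q, k, v, scale,
                               q_seq_lens=[seq_len] * batch_size,
                               kv_seq_lens=[seq_len] * batch_size)
    sdpa = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        scale=scale).transpose(1, 2)
    assert torch.allclose(ref, sdpa, atol=1e-5)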
|
|
||||||
|
|
||||||
|
def make_qkv(
|
||||||
|
batch_size: int,
|
||||||
|
max_q_seq_len: int,
|
||||||
|
max_kv_seq_len: Optional[int],
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
device: Union[torch.device, str],
|
||||||
|
force_kv_seq_lens: Optional[List[int]] = None,
|
||||||
|
attn_type: AttentionType = AttentionType.ENCODER_DECODER,
|
||||||
|
force_max_len: bool = False,
|
||||||
|
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
|
||||||
|
'''
|
||||||
|
Construct QKV test tensors for self- and cross-attention.
|
||||||
|
|
||||||
|
Generates three query/key/value triplets:
|
||||||
|
|
||||||
|
* "Baseline" query/key/value (for input to reference attention function)
|
||||||
|
* "Prefill" query/key/value (last sequence offset zero'd out, for use as
|
||||||
|
input to prefill kernel)
|
||||||
|
* "Decode" query/key/value (only the last sequence offset from baseline,
|
||||||
|
for use as input to decode kernel)
|
||||||
|
|
||||||
|
Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
|
||||||
|
seqlens
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* batch_size
|
||||||
|
* max_q_seq_len: max query seq len
|
||||||
|
* max_kv_seq_len: max key/value seq len
|
||||||
|
* num_heads
|
||||||
|
* head_size
|
||||||
|
* force_kv_seq_lens: if not None, overrides kv sequence lengths
* attn_type: encoder, decoder self-, or encoder/decoder cross-attention;
  for cross-attention the query seqlen may differ from the key/value
  seqlen (as is often the case), o/w query/key/value seqlens match at
  each batch index (max_kv_seq_len is unused)
|
||||||
|
* force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
  seqlens are random in [2, max_q_seq_len]. Same for key/value seqlens
  and max_kv_seq_len, unless the key/value seqlens are tied to the query
  seqlens (self-attention) or overridden by force_kv_seq_lens
|
||||||
|
* device: CPU or CUDA device
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
|
||||||
|
* Prefill QKVInputs structure (containing all but the last sequence offset)
|
||||||
|
* Decode QKVInputs structure (containing only the last sequence offset)
|
||||||
|
'''
|
||||||
|
|
||||||
|
if force_max_len:
|
||||||
|
q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
|
||||||
|
else:
|
||||||
|
q_seq_lens = [
|
||||||
|
random.randint(2, max_q_seq_len) for _ in range(batch_size)
|
||||||
|
]
|
||||||
|
kv_seq_lens = None
|
||||||
|
if force_kv_seq_lens is not None:
|
||||||
|
kv_seq_lens = force_kv_seq_lens
|
||||||
|
elif attn_type != AttentionType.ENCODER_DECODER:
|
||||||
|
# K,V seq lens match Q for self-attention
|
||||||
|
kv_seq_lens = q_seq_lens
|
||||||
|
else:
|
||||||
|
# K,V seq lens are distinct from Q seq lens & random
|
||||||
|
assert max_kv_seq_len is not None
|
||||||
|
if force_max_len:
|
||||||
|
kv_seq_lens = [max_kv_seq_len] * batch_size
|
||||||
|
else:
|
||||||
|
kv_seq_lens = [
|
||||||
|
random.randint(2, max_kv_seq_len) for _ in range(batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
query = torch.rand(
|
||||||
|
(batch_size, max_q_seq_len, num_heads, head_size)).to(device)
|
||||||
|
key = torch.rand(
|
||||||
|
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
|
||||||
|
value = torch.rand(
|
||||||
|
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
|
||||||
|
|
||||||
|
prefill_query = torch.zeros(
|
||||||
|
(batch_size, max_q_seq_len, num_heads, head_size)).to(device)
|
||||||
|
prefill_key = torch.zeros(
|
||||||
|
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
|
||||||
|
prefill_value = torch.zeros(
|
||||||
|
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
|
||||||
|
|
||||||
|
decode_query = torch.zeros(
|
||||||
|
(batch_size, 1, num_heads, head_size)).to(device)
|
||||||
|
decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device)
|
||||||
|
decode_value = torch.zeros(
|
||||||
|
(batch_size, 1, num_heads, head_size)).to(device)
|
||||||
|
|
||||||
|
for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
|
||||||
|
kv_seq_lens)):
|
||||||
|
query[bdx, q_seq_len:, :, :] = 0
|
||||||
|
key[bdx, kv_seq_len:, :, :] = 0
|
||||||
|
value[bdx, kv_seq_len:, :, :] = 0
|
||||||
|
|
||||||
|
prefill_query[bdx,
|
||||||
|
0:(q_seq_len - 1), :, :] = query[bdx,
|
||||||
|
0:(q_seq_len - 1), :, :]
|
||||||
|
prefill_key[bdx,
|
||||||
|
0:(kv_seq_len - 1), :, :] = key[bdx,
|
||||||
|
0:(kv_seq_len - 1), :, :]
|
||||||
|
prefill_value[bdx, 0:(kv_seq_len -
|
||||||
|
1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :]
|
||||||
|
|
||||||
|
decode_query[bdx, :, :, :] = query[bdx,
|
||||||
|
(q_seq_len - 1):q_seq_len, :, :]
|
||||||
|
decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :]
|
||||||
|
decode_value[bdx, :, :, :] = value[bdx,
|
||||||
|
(kv_seq_len - 1):kv_seq_len, :, :]
|
||||||
|
|
||||||
|
prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
|
||||||
|
prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]
|
||||||
|
|
||||||
|
decode_q_seq_lens = [1 for _ in q_seq_lens]
|
||||||
|
decode_kv_seq_lens = [1 for _ in kv_seq_lens]
|
||||||
|
|
||||||
|
return (
|
||||||
|
QKVInputs(
|
||||||
|
query, # Overall QKV inputs
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
q_seq_lens,
|
||||||
|
kv_seq_lens),
|
||||||
|
QKVInputs(
|
||||||
|
prefill_query, # Prefill subset of QKV sequences
|
||||||
|
prefill_key,
|
||||||
|
prefill_value,
|
||||||
|
prefill_q_seq_lens,
|
||||||
|
prefill_kv_seq_lens),
|
||||||
|
QKVInputs(
|
||||||
|
decode_query, # Decode subset of KV sequences
|
||||||
|
decode_key,
|
||||||
|
decode_value,
|
||||||
|
decode_q_seq_lens,
|
||||||
|
decode_kv_seq_lens))
|
||||||
|
|
||||||
|
|
||||||
|
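To make the baseline/prefill/decode decomposition concrete, here is a minimal sketch (illustrative only; the tensor names are hypothetical) of the per-sequence slicing that make_qkv applies: the prefill view drops the final offset and the decode view keeps only that offset.

import torch

seq_len, num_heads, head_size = 5, 2, 4
q = torch.rand((1, seq_len, num_heads, head_size))  # one padded sequence

prefill_q = q[:, :seq_len - 1]  # all but the last offset -> prefill kernel
decode_q = q[:, seq_len - 1:]   # only the last offset    -> decode kernel

# Concatenating the two views recovers the baseline tensor
assert torch.equal(torch.cat([prefill_q, decode_q], dim=1), q)
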
def pack_tensor(
        unpacked_tensor: torch.Tensor, seq_lens: List[int],
        device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
    '''
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns

    * packed_tensor: number_of_tokens x num_heads x head_size
    * start_loc_list: start idx of each batch elt in packed_tensor; [0] +
      list(itertools.accumulate(seq_lens))
    '''

    num_tok = sum(seq_lens)
    num_heads = unpacked_tensor.shape[-2]
    head_size = unpacked_tensor.shape[-1]
    start_loc_list = [0] + list(itertools.accumulate(seq_lens))
    packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device)

    for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):

        packed_tensor[start_loc:(
            start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]

    return packed_tensor, start_loc_list

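A small worked example of the start-location bookkeeping described above (plain Python; the values are arbitrary):

import itertools

seq_lens = [3, 5, 2]
start_loc_list = [0] + list(itertools.accumulate(seq_lens))
# start_loc_list == [0, 3, 8, 10]; sequence i occupies
# packed_tensor[start_loc_list[i]:start_loc_list[i + 1], :, :]
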
def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
                                           str]) -> PackedQKVInputs:
    '''
    Individually pack each of Q, K and V, each with dimensions batch_size x
    padded_seq_len x num_heads x head_size, into respective number_of_tokens x
    num_heads x head_size tensors.

    For Q, number_of_tokens = sum(q_seq_lens).

    For K and V, number_of_tokens = sum(kv_seq_lens)

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
      attention inputs
    * device: CPU or CUDA device

    Returns

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs
      derived from unpacked inputs
    '''

    if qkv.query is None:
        packed_query = None
        q_start_loc_list = None
    else:
        packed_query, q_start_loc_list = pack_tensor(qkv.query,
                                                     qkv.q_seq_lens,
                                                     device=device)
    packed_key, kv_start_loc_list = pack_tensor(qkv.key,
                                                qkv.kv_seq_lens,
                                                device=device)
    packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)
    return PackedQKVInputs(
        packed_query, packed_key, packed_value, q_start_loc_list,
        kv_start_loc_list,
        (None if q_start_loc_list is None else qkv.q_seq_lens),
        qkv.kv_seq_lens)

def make_backend(backend_name: str) -> AttentionBackend:
    '''
    Construct the backend instance determined by the backend_name string
    argument.

    "XFORMERS" -> construct xformers backend

    TODO: other backends

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
    inference, but rather for generating compatible metadata structures
    using backend.make_metadata()

    Returns:

    * Backend instance
    '''
    if backend_name == STR_XFORMERS_ATTN_VAL:
        return XFormersBackend()
    raise AssertionError(
        f"Unrecognized backend_name {backend_name} for unit test")

def _make_metadata_tensors(
    seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
    encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
           torch.Tensor, Optional[int]]:
    '''
    Build scalar & tensor values required to build attention metadata structure.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns:

    * seq_lens_tensor: decoder seq_lens list, as tensor
    * context_lens_tensor: context_lens list, as tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: start idx of each sequence
    * encoder_seq_lens_tensor: encoder seq_lens list, as tensor
    * max_encoder_seq_len: max(encoder_seq_lens)
    '''
    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = (None if encoder_seq_lens is None else
                           max(encoder_seq_lens))

    seq_start_loc = None

    return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
            seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)

def make_kv_cache(num_blocks: int,
                  num_heads: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  default_val: float = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * default_val: initialization value for KV cache elements

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
    '''

    kv_cache = torch.rand(
        (2, num_blocks, block_size * num_heads * head_size)).to(device)
    if default_val is not None:
        kv_cache[:, :, :] = default_val
    return kv_cache

def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
    '''
    Compute a number of blocks sufficient to hold num_tokens tokens, given
    block_size. Note that (num_tokens + block_size) // block_size is always at
    least the minimum, and over-provisions by one block when num_tokens is an
    exact multiple of block_size.
    '''
    return (num_tokens + block_size) // block_size


def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
    return maybe_make_long_tensor([], device)


def make_empty_block_tables_tensor(device: Union[torch.device, str]):
    return torch.tensor([], device=device)

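A short sketch comparing the provisioning formula above with an exact ceiling division; the helper name below is hypothetical and only illustrates the arithmetic.

def exact_min_blocks(num_tokens: int, block_size: int) -> int:
    # True minimum: ceiling division
    return (num_tokens + block_size - 1) // block_size

block_size = 16
for num_tokens in (1, 15, 16, 17, 32):
    provisioned = (num_tokens + block_size) // block_size
    print(num_tokens, provisioned, exact_min_blocks(num_tokens, block_size))
# 16 tokens -> 2 provisioned vs. 1 strictly required;
# 32 tokens -> 3 provisioned vs. 2 strictly required
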
def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
                       device: Union[torch.device, str]):
    '''
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:
    * Your goal is to test (1) prefill of N prompts, with prompt-lengths
      {K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
      for all N prompts (N tokens total); the resultant sequence lengths
      after decode would be {K_i + 1 for i \\in [0,N)}
    * The test you want to do requires (1) having the prefill slot mapping
      for all tokens present during prefill, the number of which is
      M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
      decoded tokens

    This function consumes a single 1D slot mapping, which is the
    concatenation of N slot mappings each of length K_i + 1 (corresponding
    to the sequence lengths after decode), with a total length of
    P = \\sum_i{K_i + 1} = M + N

    The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
    from each of the N subsequences in the slot mapping (i.e. omitting the
    decoded token's mapping.)

    The N excised entries are appended to obtain the decode-phase slot mapping

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
      post-decode sequences
    * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
      description above)
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
      all N decoded tokens
    '''

    prefill_slot_mapping = []
    decode_slot_mapping = []

    base_idx = 0
    for seq_len in seq_lens:
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx +
                                                                seq_len - 1)])
        decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1])
        base_idx += seq_len

    return (maybe_make_long_tensor(prefill_slot_mapping, device),
            maybe_make_long_tensor(decode_slot_mapping, device))

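A worked example of the split on plain Python lists (the helper itself returns long tensors); the slot values are arbitrary.

seq_lens = [3, 2]                     # post-decode lengths, K_i + 1
slot_mapping = [10, 11, 12, 20, 21]   # concatenated per-sequence mappings

prefill, decode, base = [], [], 0
for n in seq_lens:
    prefill.extend(slot_mapping[base:base + n - 1])
    decode.append(slot_mapping[base + n - 1])
    base += n
# prefill == [10, 11, 20]  (M = sum(K_i) prefill-token slots)
# decode == [12, 21]       (one decoded-token slot per sequence)
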
def make_block_tables_slot_mapping(
        block_size: int,
        seq_lens: List[int],
        device: Union[torch.device, str],
        block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
    '''
    Construct fake block tables & slot mappings.

    For a sequence with num_tokens tokens the number of provisioned
    KV cache blocks is

    num_blocks = (num_tokens + block_size) // block_size

    Then the provisioned KV cache size in blocks is

    total_cache_blocks = sum(num_blocks for all seqs)

    Then, the blocktable mapping counts downward from

    block_base_addr + total_cache_blocks

    to

    block_base_addr

    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * block_base_addr: the block table base address
    * device: CPU or CUDA device

    Return:

    * block_tables_tensor: block table for sequence
    * slot_mapping_list: slot mapping for sequence
    * max_block_idx: the highest block address within this block table
    '''

    # Provision minimum number of KV cache blocks
    num_blocks_list = [
        _num_tokens_to_min_blocks(num_tokens, block_size)
        for num_tokens in seq_lens
    ]
    max_block_table_len = max(num_blocks_list)
    block_table_pad_tokens = 10

    block_tables = []
    slot_mapping_list = []
    # Compute uppermost address of block table
    total_cache_blocks = sum(num_blocks_list)
    block_base_idx = block_base_addr + total_cache_blocks
    max_block_idx = block_base_idx
    for sdx, num_tokens in enumerate(seq_lens):
        num_blocks = num_blocks_list[sdx]
        block_table = list(
            range(block_base_idx, block_base_idx - num_blocks, -1))
        for idx in range(num_tokens):
            mapping_value = (
                idx % block_size) + block_table[idx // block_size] * block_size
            slot_mapping_list.append(mapping_value)

        block_base_idx -= num_blocks
        block_tables.append(block_table)

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=max_block_table_len + block_table_pad_tokens,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)

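The slot computation above maps token idx of a sequence to offset idx % block_size within block idx // block_size of that sequence's block table. A minimal sketch with an assumed block table:

block_size = 4
block_table = [7, 6, 5]  # blocks assigned to one sequence, counting down

def token_to_slot(idx: int) -> int:
    return block_table[idx // block_size] * block_size + idx % block_size

# Tokens 0-3 land in block 7, tokens 4-5 in block 6
assert [token_to_slot(i) for i in range(6)] == [28, 29, 30, 31, 24, 25]
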
def make_test_metadata(
    attn_backend: AttentionBackend,
    is_prompt: bool,
    seq_lens: Optional[List[int]],
    decoder_test_params: Optional[PhaseTestParameters],
    device: Union[torch.device, str],
    encoder_test_params: Optional[PhaseTestParameters] = None,
    cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
    '''
    Construct fake attention metadata for a given test phase
    (prefill-phase or decode-phase).

    encoder_test_params and cross_test_params arguments allow encoder
    attention and enc/dec cross-attention (respectively) to use distinct
    metadata values from decoder self-attention (decoder_test_params.)

    If encoder_test_params and cross_test_params are None, the attention
    metadata will support a decoder-only scenario.

    Assumptions:

    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both

    Arguments:

    * attn_backend: Backend for sourcing attention kernels
    * is_prompt: prefill if True, o/w decode
    * seq_lens: list of token counts for each sequence
    * decoder_test_params: decoder self-attention test params;
                           this function requires the
                           kv_mmap (memory mapping) field
    * device: CPU or CUDA device
    * encoder_test_params: encoder attention test params;
                           this function requires the encoder query
                           sequence lengths field. If None,
                           encoder query sequence lengths are
                           treated as None
    * cross_test_params: enc/dec cross-attention test params;
                         this function requires the kv_mmap field.
                         If None, KV cache memory map data
                         structures are treated as None

    Return:

    * AttentionMetadata structure
    '''

    # Decoder self-attention memory mapping.
    # decoder_test_params is None signals an encoder-only
    # scenario, so kv_mmap is None
    kv_mmap = (None
               if decoder_test_params is None else decoder_test_params.kv_mmap)

    # This function constructs metadata assuming no chunked prefill,
    # i.e. 100% prefill tokens or 100% decode tokens
    #
    # - If is_prompt, num_prefills_or_decodes is the number of prefills
    #   and num_prefill_or_decode_tokens is the number of prefill tokens
    # - If not is_prompt, num_prefills_or_decodes is the number of decodes
    #   and num_prefill_or_decode_tokens is the number of decode tokens
    #
    # seq_lens is None signals an encoder-only
    # scenario, in which case num_prefills_or_decodes and
    # num_prefill_or_decode_tokens are unused
    num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens))

    num_prefill_or_decode_tokens = (None if seq_lens is None else (
        sum(seq_lens) if is_prompt else len(seq_lens)))

    # For non-prefix-caching scenarios, context_lens
    # is not needed
    context_lens = None

    if encoder_test_params is None:
        encoder_seq_lens = None
        num_encoder_tokens = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract encoder input sequence lengths
        assert encoder_test_params.packed_qkvo.packed_qkv is not None
        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
        num_encoder_tokens = (None if encoder_seq_lens is None else
                              (sum(encoder_seq_lens)))

    if cross_test_params is None:
        cross_kv_mmap = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract *cross-attention* slot_mapping and block table
        #   (kv_mmap)
        cross_kv_mmap = cross_test_params.kv_mmap

    if is_prompt:
        # Prefill-phase scenario

        num_prefills = num_prefills_or_decodes
        num_prefill_tokens = num_prefill_or_decode_tokens
        num_decode_tokens = 0

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            _,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)

        return attn_backend.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
            max_decode_seq_len=0,
            context_lens_tensor=context_lens_tensor,
            block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))

    else:  # not is_prompt
        # Decode-phase scenario

        assert kv_mmap is not None
        assert num_prefill_or_decode_tokens is not None
        assert seq_lens is not None

        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = num_prefill_or_decode_tokens

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            _,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)

        return attn_backend.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=kv_mmap.slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=max(seq_lens),
            context_lens_tensor=context_lens_tensor,
            block_tables=kv_mmap.block_tables,
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))

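A hedged usage sketch for the prefill-phase, decoder-only path: it assumes a decoder_test_params value with a populated kv_mmap field built elsewhere in the test suite, and that the xFormers backend name string resolves to "XFORMERS".

attn_backend = make_backend("XFORMERS")
prefill_attn_metadata = make_test_metadata(
    attn_backend,
    True,                 # is_prompt: build prefill-phase metadata
    [5, 3],               # two decoder sequences, 8 prefill tokens total
    decoder_test_params,  # supplies slot mapping & block tables via kv_mmap
    device="cuda:0",
)
# For an encoder/decoder test, encoder_test_params and cross_test_params
# would also be supplied so the encoder and cross-attention fields are set.
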
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test: torch.Tensor) -> None:
    '''
    Assert that the observed output matches the ideal output
    contained in the test parameters data structure.

    Arguments:

    * test_params: Test parameters including packed ideal output
    * output_under_test: actually observed output value
    '''
    ideal_output = test_params.packed_qkvo.ideal_output
    assert torch.allclose(ideal_output,
                          output_under_test.view_as(ideal_output))

@ -1,11 +1,18 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
                    TypeVar)

import torch


class AttentionType(Enum):
    DECODER = auto()  # Decoder attention between previous layer Q/K/V
    ENCODER = auto()  # Encoder attention between previous layer Q/K/V
    ENCODER_DECODER = auto()  # Attention between dec. Q and enc. K/V


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

@ -128,5 +135,6 @@ class AttentionImpl(ABC, Generic[T]):
        kv_cache: torch.Tensor,
        attn_metadata: T,
        kv_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        raise NotImplementedError
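To illustrate how the new enum is meant to be threaded through (a sketch, not code from this change): callers default to decoder self-attention, while an encoder/decoder model would select the type explicitly per layer, e.g. via a hypothetical helper for a cross-attention layer.

from vllm.attention.backends.abstract import AttentionType

def run_cross_attention(attn_impl, query, key, value, kv_cache, attn_metadata):
    # Decoder queries attend to encoder-derived K/V; backends that do not
    # support this raise NotImplementedError, as the hunks below show.
    return attn_impl.forward(query, key, value, kv_cache, attn_metadata,
                             attn_type=AttentionType.ENCODER_DECODER)
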
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
                                              AttentionMetadata, AttentionType)
from vllm.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
@ -328,6 +328,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: BlocksparseFlashAttentionMetadata,
        kv_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

@ -340,6 +341,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "BlocksparseFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
@ -7,7 +7,7 @@ from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
                                              AttentionMetadata, AttentionType)


class FlashAttentionBackend(AttentionBackend):
@ -257,6 +257,7 @@ class FlashAttentionImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        kv_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

@ -269,6 +270,12 @@ class FlashAttentionImpl(AttentionImpl):
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")

        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
        assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."

@ -14,7 +14,7 @@ import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
                                              AttentionMetadata, AttentionType)


class FlashInferBackend(AttentionBackend):
@ -224,8 +224,14 @@ class FlashInferImpl(AttentionImpl):
        kv_cache: Optional[torch.Tensor],
        attn_metadata: FlashInferMetadata,
        kv_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        assert kv_scale == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")
        num_tokens, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
@ -7,7 +7,7 @@ import torch

from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
                                              AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)

@ -157,6 +157,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
        kv_cache: Optional[torch.Tensor],
        attn_metadata: IpexAttnMetadata,  # type: ignore
        kv_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with IPEX varlen_attention and PagedAttention.

@ -170,6 +171,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
            shape = [num_tokens, num_heads * head_size]
        """
        assert kv_scale == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "IpexAttnBackendImpl")
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
@ -6,7 +6,7 @@ import torch_xla.experimental.custom_kernel  # Required to register custom ops.
import torch_xla.experimental.dynamo_set_buffer_donor

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
                                              AttentionMetadata, AttentionType)


class PallasAttentionBackend(AttentionBackend):
@ -132,6 +132,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
        kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
        attn_metadata: PallasMetadata,
        kv_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

@ -146,6 +147,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        assert kv_scale == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
@ -6,7 +6,7 @@ import torch

import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
                                              AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger
@ -297,6 +297,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: ROCmFlashAttentionMetadata,
        kv_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

@ -309,6 +310,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "ROCmFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
@ -7,7 +7,7 @@ import torch
from torch.nn.functional import scaled_dot_product_attention

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
                                              AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu

@ -145,6 +145,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
        kv_cache: Optional[torch.Tensor],
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        kv_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

@ -158,6 +159,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
            shape = [num_tokens, num_heads * head_size]
        """
        assert kv_scale == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TorchSDPABackendImpl")
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
7
vllm/attention/backends/utils.py
Normal file
@ -0,0 +1,7 @@
"""Attention backend utils"""

# Error string(s) for encoder/decoder
# unsupported attention scenarios

STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
                                 "with encoder/decoder models.")
@ -6,10 +6,11 @@ import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
                                         BlockDiagonalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
                                              AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger
@ -66,11 +67,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
@ -79,8 +75,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    # |-------------------- seq_len ----------------------|
    # |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    # seq_lens stored as a tensor.
    max_query_len: Optional[int]
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
@ -88,26 +85,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
@ -115,6 +141,28 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
        # from xformer API.
        # will not appear in the __repr__ and __init__
        self.attn_bias: Optional[List[AttentionBias]] = None
        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
        self.cross_attn_bias: Optional[List[AttentionBias]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return ((self.encoder_seq_lens is not None)
                and (self.encoder_seq_lens_tensor is not None)
                and (self.max_encoder_seq_len is not None))

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return (self.is_all_encoder_attn_metadata_set
                and (self.cross_slot_mapping is not None)
                and (self.cross_block_tables is not None))

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
@ -122,30 +170,50 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if self._cached_prefill_metadata is not None:
|
if self._cached_prefill_metadata is not None:
|
||||||
|
# Recover cached prefill-phase attention
|
||||||
|
# metadata structure
|
||||||
return self._cached_prefill_metadata
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
assert self.seq_lens is not None
|
assert ((self.seq_lens is not None)
|
||||||
assert self.seq_lens_tensor is not None
|
or (self.encoder_seq_lens is not None))
|
||||||
assert self.query_start_loc is not None
|
assert ((self.seq_lens_tensor is not None)
|
||||||
assert self.context_lens_tensor is not None
|
or (self.encoder_seq_lens_tensor is not None))
|
||||||
assert self.block_tables is not None
|
|
||||||
|
|
||||||
|
# Compute some attn_metadata fields which default to None
|
||||||
|
query_start_loc = (None if self.query_start_loc is None else
|
||||||
|
self.query_start_loc[:self.num_prefills + 1])
|
||||||
|
slot_mapping = (None if self.slot_mapping is None else
|
||||||
|
self.slot_mapping[:self.num_prefill_tokens])
|
||||||
|
seq_lens = (None if self.seq_lens is None else
|
||||||
|
self.seq_lens[:self.num_prefills])
|
||||||
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||||
|
self.seq_lens_tensor[:self.num_prefills])
|
||||||
|
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||||
|
self.context_lens_tensor[:self.num_prefills])
|
||||||
|
block_tables = (None if self.block_tables is None else
|
||||||
|
self.block_tables[:self.num_prefills])
|
||||||
|
|
||||||
|
# Construct & cache prefill-phase attention metadata structure
|
||||||
self._cached_prefill_metadata = XFormersMetadata(
|
self._cached_prefill_metadata = XFormersMetadata(
|
||||||
num_prefills=self.num_prefills,
|
num_prefills=self.num_prefills,
|
||||||
num_prefill_tokens=self.num_prefill_tokens,
|
num_prefill_tokens=self.num_prefill_tokens,
|
||||||
num_decode_tokens=0,
|
num_decode_tokens=0,
|
||||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
slot_mapping=slot_mapping,
|
||||||
seq_lens=self.seq_lens[:self.num_prefills],
|
seq_lens=seq_lens,
|
||||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_query_len=self.max_query_len,
|
max_query_len=self.max_query_len,
|
||||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||||
max_decode_seq_len=0,
|
max_decode_seq_len=0,
|
||||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
query_start_loc=query_start_loc,
|
||||||
seq_start_loc=None,
|
context_lens_tensor=context_lens_tensor,
|
||||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
block_tables=block_tables,
|
||||||
block_tables=self.block_tables[:self.num_prefills],
|
|
||||||
use_cuda_graph=False,
|
use_cuda_graph=False,
|
||||||
)
|
# Begin encoder & cross attn fields below...
|
||||||
|
encoder_seq_lens=self.encoder_seq_lens,
|
||||||
|
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||||
|
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||||
|
cross_slot_mapping=self.cross_slot_mapping,
|
||||||
|
cross_block_tables=self.cross_block_tables)
|
||||||
return self._cached_prefill_metadata
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -154,29 +222,146 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if self._cached_decode_metadata is not None:
|
if self._cached_decode_metadata is not None:
|
||||||
|
# Recover cached decode-phase attention
|
||||||
|
# metadata structure
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
assert self.block_tables is not None
|
assert ((self.seq_lens_tensor is not None)
|
||||||
assert self.seq_lens_tensor is not None
|
or (self.encoder_seq_lens_tensor is not None))
|
||||||
|
|
||||||
|
# Compute some attn_metadata fields which default to None
|
||||||
|
slot_mapping = (None if self.slot_mapping is None else
|
||||||
|
self.slot_mapping[self.num_prefill_tokens:])
|
||||||
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||||
|
self.seq_lens_tensor[self.num_prefills:])
|
||||||
|
block_tables = (None if self.block_tables is None else
|
||||||
|
self.block_tables[self.num_prefills:])
|
||||||
|
|
||||||
|
# Construct & cache decode-phase attention metadata structure
|
||||||
self._cached_decode_metadata = XFormersMetadata(
|
self._cached_decode_metadata = XFormersMetadata(
|
||||||
num_prefills=0,
|
num_prefills=0,
|
||||||
num_prefill_tokens=0,
|
num_prefill_tokens=0,
|
||||||
num_decode_tokens=self.num_decode_tokens,
|
num_decode_tokens=self.num_decode_tokens,
|
||||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
slot_mapping=slot_mapping,
|
||||||
seq_lens=None,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
|
||||||
max_query_len=None,
|
|
||||||
max_prefill_seq_len=0,
|
max_prefill_seq_len=0,
|
||||||
max_decode_seq_len=self.max_decode_seq_len,
|
max_decode_seq_len=self.max_decode_seq_len,
|
||||||
query_start_loc=None,
|
block_tables=block_tables,
|
||||||
seq_start_loc=None,
|
|
||||||
context_lens_tensor=None,
|
|
||||||
block_tables=self.block_tables[self.num_prefills:],
|
|
||||||
use_cuda_graph=self.use_cuda_graph,
|
use_cuda_graph=self.use_cuda_graph,
|
||||||
)
|
# Begin encoder & cross attn fields below...
|
||||||
|
encoder_seq_lens=self.encoder_seq_lens,
|
||||||
|
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||||
|
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||||
|
cross_slot_mapping=self.cross_slot_mapping,
|
||||||
|
cross_block_tables=self.cross_block_tables)
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _get_attn_bias(
|
||||||
|
attn_metadata: XFormersMetadata,
|
||||||
|
attn_type: AttentionType,
|
||||||
|
) -> Optional[AttentionBias]:
|
||||||
|
'''
|
+    Extract appropriate attention bias from attention metadata
+    according to attention type.
+
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+
+    Returns:
+    * Appropriate attention bias value given the attention type
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        return attn_metadata.attn_bias
+    elif attn_type == AttentionType.ENCODER:
+        return attn_metadata.encoder_attn_bias
+    else:
+        # attn_type == AttentionType.ENCODER_DECODER
+        return attn_metadata.cross_attn_bias
+
+
+def _set_attn_bias(
+    attn_metadata: XFormersMetadata,
+    attn_bias: List[Optional[AttentionBias]],
+    attn_type: AttentionType,
+) -> None:
+    '''
+    Update appropriate attention bias field of attention metadata,
+    according to attention type.
+
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention
+    * attn_bias: The desired attention bias value
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        attn_metadata.attn_bias = attn_bias
+    elif attn_type == AttentionType.ENCODER:
+        attn_metadata.encoder_attn_bias = attn_bias
+    elif attn_type == AttentionType.ENCODER_DECODER:
+        attn_metadata.cross_attn_bias = attn_bias
+    else:
+        raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+
+def _get_seq_len_block_table_args(
+    attn_metadata: XFormersMetadata,
+    is_prompt: bool,
+    attn_type: AttentionType,
+) -> tuple:
+    '''
+    The particular choice of sequence-length- and block-table-related
+    attributes which should be extracted from attn_metadata is dependent
+    on the type of attention operation.
+
+    Decoder attn -> select entirely decoder self-attention-related fields
+    Encoder/decoder cross-attn -> select encoder sequence lengths &
+                                  cross-attn block-tables fields
+    Encoder attn -> select encoder sequence lengths fields & no block tables
+
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention op
+    * is_prompt: True if prefill, False otherwise
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+
+    Returns:
+
+    * Appropriate sequence-lengths tensor
+    * Appropriate max sequence-length scalar
+    * Appropriate block tables (or None)
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        # Decoder self-attention
+        # Choose max_seq_len based on whether we are in prompt_run
+        if is_prompt:
+            max_seq_len = attn_metadata.max_prefill_seq_len
+        else:
+            max_seq_len = attn_metadata.max_decode_seq_len
+        return (attn_metadata.seq_lens_tensor, max_seq_len,
+                attn_metadata.block_tables)
+    elif attn_type == AttentionType.ENCODER_DECODER:
+        # Enc/dec cross-attention KVs match encoder sequence length;
+        # cross-attention utilizes special "cross" block tables
+        return (attn_metadata.encoder_seq_lens_tensor,
+                attn_metadata.max_encoder_seq_len,
+                attn_metadata.cross_block_tables)
+    elif attn_type == AttentionType.ENCODER:
+        # No block tables associated with encoder attention
+        return (attn_metadata.encoder_seq_lens_tensor,
+                attn_metadata.max_encoder_seq_len, None)
+    else:
+        raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+
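The three helpers added above (in the xFormers backend, vllm/attention/backends/xformers.py) centralize the per-attention-type bookkeeping so the rest of the backend stays type-agnostic. The snippet below is an illustrative sketch only, not part of the commit: it feeds a SimpleNamespace stand-in for XFormersMetadata (carrying just the fields the selector reads, with placeholder values) through _get_seq_len_block_table_args to show which (seq-lens tensor, max seq len, block tables) triple each AttentionType resolves to. It assumes a vLLM checkout that contains this change, with xformers installed.

# Illustrative sketch; the metadata object is a dummy stand-in, not a real
# XFormersMetadata, and the string values are placeholders.
from types import SimpleNamespace

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.xformers import _get_seq_len_block_table_args

meta = SimpleNamespace(
    seq_lens_tensor="decoder_seq_lens_tensor",
    max_prefill_seq_len=512,
    max_decode_seq_len=1024,
    block_tables="decoder_block_tables",
    encoder_seq_lens_tensor="encoder_seq_lens_tensor",
    max_encoder_seq_len=256,
    cross_block_tables="cross_block_tables",
)

# Decoder self-attention (decode phase): decoder lengths + decoder block tables
assert _get_seq_len_block_table_args(meta, False, AttentionType.DECODER) == (
    "decoder_seq_lens_tensor", 1024, "decoder_block_tables")

# Cross-attention: encoder lengths + the dedicated "cross" block tables
assert _get_seq_len_block_table_args(
    meta, False, AttentionType.ENCODER_DECODER) == (
        "encoder_seq_lens_tensor", 256, "cross_block_tables")

# Encoder attention: encoder lengths, no block tables at all
assert _get_seq_len_block_table_args(meta, False, AttentionType.ENCODER) == (
    "encoder_seq_lens_tensor", 256, None)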
 class XFormersImpl(AttentionImpl[XFormersMetadata]):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
@@ -238,51 +423,144 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
     def forward(
         self,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: Optional[torch.Tensor],
+        value: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor],
         attn_metadata: "XFormersMetadata",
         kv_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
 
+        For decoder-only models: query, key and value must be non-None.
+
+        For encoder/decoder models:
+        * XFormersImpl.forward() may be invoked for both self- and cross-
+          attention layers.
+        * For self-attention: query, key and value must be non-None.
+        * For cross-attention:
+            * Query must be non-None
+            * During prefill, key and value must be non-None; key and value
+              get cached for use during decode.
+            * During decode, key and value may be None, since:
+              (1) key and value tensors were cached during prefill, and
+              (2) cross-attention key and value tensors do not grow during
+                  decode
+
+        A note on how the attn_type (attention type enum) argument impacts
+        attention forward() behavior:
+
+        * DECODER: normal decoder-only behavior;
+          use decoder self-attention block table
+        * ENCODER: no KV caching; pass encoder sequence
+          attributes (encoder_seq_lens/encoder_seq_lens_tensor/
+          max_encoder_seq_len) to kernel, in lieu of decoder
+          sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
+        * ENCODER_DECODER: cross-attention behavior;
+          use cross-attention block table for caching KVs derived
+          from encoder hidden states; since KV sequence lengths
+          will match encoder sequence lengths, pass encoder sequence
+          attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
+          max_encoder_seq_len)
+
         Args:
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
             attn_metadata: Metadata for attention.
+            attn_type: Select attention type, between encoder attention,
+                       decoder self-attention, or encoder/decoder cross-
+                       attention. Defaults to decoder self-attention,
+                       which is the vLLM default generally
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-
-        if kv_cache is not None:
+        # Check that appropriate attention metadata attributes are
+        # selected for the desired attention type
+        if (attn_type == AttentionType.ENCODER
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+        elif (attn_type == AttentionType.ENCODER_DECODER
+              and (not attn_metadata.is_all_cross_attn_metadata_set)):
+            raise AttributeError("Encoder/decoder cross-attention "
+                                 "requires setting cross-attention "
+                                 "metadata attributes.")
+
+        query = query.view(-1, self.num_heads, self.head_size)
+        if key is not None:
+            assert value is not None
+            key = key.view(-1, self.num_kv_heads, self.head_size)
+            value = value.view(-1, self.num_kv_heads, self.head_size)
+        else:
+            assert value is None
+
+        # Self-attention vs. cross-attention will impact
+        # which KV cache memory-mapping & which
+        # seqlen datastructures we utilize
+
+        if (attn_type != AttentionType.ENCODER and kv_cache is not None):
+            # KV-cache during decoder-self- or
+            # encoder-decoder-cross-attention, but not
+            # during encoder attention.
+            #
+            # Even if there are no new key/value pairs to cache,
+            # we still need to break out key_cache and value_cache
+            # i.e. for later use by paged attention
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
 
-            # Reshape the input keys and values and store them in the cache.
-            # If kv_cache is not provided, the new key and value tensors are
-            # not cached. This happens during the initial memory profiling run.
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                attn_metadata.slot_mapping,
-                                                self.kv_cache_dtype, kv_scale)
-
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+            if (key is not None) and (value is not None):
+
+                if attn_type == AttentionType.ENCODER_DECODER:
+                    # Update cross-attention KV cache (prefill-only)
+                    # During cross-attention decode, key & value will be None,
+                    # preventing this IF-statement branch from running
+                    updated_slot_mapping = attn_metadata.cross_slot_mapping
+                else:
+                    # Update self-attention KV cache (prefill/decode)
+                    updated_slot_mapping = attn_metadata.slot_mapping
+
+                # Reshape the input keys and values and store them in the cache.
+                # If kv_cache is not provided, the new key and value tensors are
+                # not cached. This happens during the initial memory
+                # profiling run.
+                PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                    value_cache,
+                                                    updated_slot_mapping,
+                                                    self.kv_cache_dtype,
+                                                    kv_scale)
+
+        if attn_type != AttentionType.ENCODER:
+            # Decoder self-attention supports chunked prefill.
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
+        else:
+            # Encoder attention - chunked prefill is not applicable;
+            # derive token-count from query shape & and treat them
+            # as 100% prefill tokens
+            assert attn_metadata.num_encoder_tokens is not None
+            num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_decode_tokens = 0
+
+        if attn_type == AttentionType.DECODER:
+            # Only enforce this shape-constraint for decoder
+            # self-attention
+            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+            assert value.shape[0] == num_prefill_tokens + num_decode_tokens
 
         output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
         query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
+        if key is not None and value is not None:
+            key = key[:num_prefill_tokens]
+            value = value[:num_prefill_tokens]
 
         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
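The docstring in the hunk above pins down when key and value may be None. As a reading aid, here is a hypothetical cross-attention helper that follows that contract; the function and argument names are invented for illustration, and only Attention, AttentionMetadata and AttentionType are real vLLM symbols. During prefill, key/value carry encoder-derived projections and are written to the cross-attention KV cache; during decode they are simply None and the cached KVs are reused.

# Hypothetical illustration of the calling convention described above; the
# function and argument names are invented, not part of this commit.
from typing import Optional

import torch

from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.abstract import AttentionType


def cross_attention(
    attn: Attention,
    decoder_query: torch.Tensor,
    encoder_key: Optional[torch.Tensor],    # non-None during prefill only
    encoder_value: Optional[torch.Tensor],  # non-None during prefill only
    kv_cache: Optional[torch.Tensor],
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    # Prefill: encoder-derived key/value are written to the cross-attention
    # KV cache (via cross_slot_mapping). Decode: key/value are None and the
    # kernel reads the KVs cached at prefill, which never grow.
    return attn(decoder_query,
                encoder_key,
                encoder_value,
                kv_cache,
                attn_metadata,
                attn_type=AttentionType.ENCODER_DECODER)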
@@ -294,10 +572,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 # block tables are empty if the prompt does not have a cached
                 # prefix.
                 out = self._run_memory_efficient_xformers_forward(
-                    query, key, value, prefill_meta)
+                    query, key, value, prefill_meta, attn_type=attn_type)
                 assert out.shape == output[:num_prefill_tokens].shape
                 output[:num_prefill_tokens] = out
             else:
+
+                assert prefill_meta.query_start_loc is not None
+                assert prefill_meta.max_query_len is not None
+
                 # prefix-enabled attention
                 # TODO(Hai) this triton kernel has regression issue (broke) to
                 # deal with different data types between KV and FP8 KV cache,
@@ -320,13 +602,20 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 output[:num_prefill_tokens] = out
 
         if decode_meta := attn_metadata.decode_metadata:
+
+            (
+                seq_lens_arg,
+                max_seq_len_arg,
+                block_tables_arg,
+            ) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
+
             output[num_prefill_tokens:] = PagedAttention.forward_decode(
                 decode_query,
                 key_cache,
                 value_cache,
-                decode_meta.block_tables,
-                decode_meta.seq_lens_tensor,
-                decode_meta.max_decode_seq_len,
+                block_tables_arg,
+                seq_lens_arg,
+                max_seq_len_arg,
                 self.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
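A rough sizing intuition for the decode path above, with made-up numbers: the decoder block table grows as the decoded sequence grows, while the "cross" block table is fixed by the encoder length captured at prefill.

# Back-of-the-envelope sketch with made-up numbers; block_size and the
# sequence lengths are illustrative, not taken from this commit.
from math import ceil

block_size = 16
decoder_seq_len = 200   # grows by one token per decode step
encoder_seq_len = 57    # fixed once prefill has run

# Decoder self-attention walks a block table sized by the decoder length.
print(ceil(decoder_seq_len / block_size))  # -> 13 blocks

# Cross-attention walks the "cross" block table sized by the encoder length.
print(ceil(encoder_seq_len / block_size))  # -> 4 blocks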
@@ -343,6 +632,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: XFormersMetadata,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Attention for 1D query of multiple prompts. Multiple prompt
         tokens are flattened in to `query` input.
@@ -356,8 +646,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
             key: shape = [num_prefill_tokens, num_kv_heads, head_size]
             value: shape = [num_prefill_tokens, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
+            attn_type: Select attention type, between encoder attention,
+                       decoder self-attention, or encoder/decoder cross-
+                       attention. Defaults to decoder self-attention,
+                       which is the vLLM default generally
         """
-        assert attn_metadata.seq_lens is not None
         original_query = query
         if self.num_kv_heads != self.num_heads:
             # GQA/MQA requires the shape [B, M, G, H, K].
@@ -375,18 +669,39 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         # Set attention bias if not provided. This typically happens at
         # the very attention layer of every iteration.
         # FIXME(woosuk): This is a hack.
-        if attn_metadata.attn_bias is None:
+        attn_bias = _get_attn_bias(attn_metadata, attn_type)
+        if attn_bias is None:
             if self.alibi_slopes is None:
-                attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                    attn_metadata.seq_lens)
+                if (attn_type == AttentionType.ENCODER_DECODER):
+                    assert attn_metadata.seq_lens is not None
+                    assert attn_metadata.encoder_seq_lens is not None
+
+                    # Default enc/dec cross-attention mask is non-causal
+                    attn_bias = BlockDiagonalMask.from_seqlens(
+                        attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
+                elif attn_type == AttentionType.ENCODER:
+                    assert attn_metadata.encoder_seq_lens is not None
+
+                    # Default encoder self-attention mask is non-causal
+                    attn_bias = BlockDiagonalMask.from_seqlens(
+                        attn_metadata.encoder_seq_lens)
+                else:
+                    assert attn_metadata.seq_lens is not None
+
+                    # Default decoder self-attention mask is causal
+                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                        attn_metadata.seq_lens)
                 if self.sliding_window is not None:
                     attn_bias = attn_bias.make_local_attention(
                         self.sliding_window)
-                attn_metadata.attn_bias = [attn_bias]
+                attn_bias = [attn_bias]
             else:
-                attn_metadata.attn_bias = _make_alibi_bias(
-                    self.alibi_slopes, self.num_kv_heads, query.dtype,
-                    attn_metadata.seq_lens)
+                assert attn_metadata.seq_lens is not None
+                attn_bias = _make_alibi_bias(self.alibi_slopes,
+                                             self.num_kv_heads, query.dtype,
+                                             attn_metadata.seq_lens)
+
+            _set_attn_bias(attn_metadata, attn_bias, attn_type)
 
         # No alibi slopes.
         # TODO(woosuk): Too many view operations. Let's try to reduce
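The branch above picks between a causal and a non-causal block-diagonal bias. The sketch below (made-up sequence lengths; requires the xformers package) shows the three default masks side by side; note that for cross-attention the query lengths come from the decoder while the key/value lengths come from the encoder.

# Illustrative sketch with made-up lengths; requires xformers to be installed.
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         BlockDiagonalMask)

decoder_seq_lens = [3, 5]  # per-sequence decoder (query) token counts
encoder_seq_lens = [4, 2]  # per-sequence encoder token counts

# Decoder self-attention: causal mask, block-diagonal over decoder tokens.
decoder_bias = BlockDiagonalCausalMask.from_seqlens(decoder_seq_lens)

# Encoder self-attention: non-causal mask over encoder tokens.
encoder_bias = BlockDiagonalMask.from_seqlens(encoder_seq_lens)

# Cross-attention: non-causal; query lengths and KV lengths differ.
cross_bias = BlockDiagonalMask.from_seqlens(decoder_seq_lens,
                                            encoder_seq_lens)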
@@ -400,7 +715,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 query,
                 key,
                 value,
-                attn_bias=attn_metadata.attn_bias[0],
+                attn_bias=attn_bias[0],
                 p=0.0,
                 scale=self.scale)
             return out.view_as(original_query)
@@ -409,6 +724,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         # FIXME(woosuk): Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
+        assert attn_metadata.seq_lens is not None
         output = torch.empty_like(original_query)
         start = 0
         for i, seq_len in enumerate(attn_metadata.seq_lens):
@@ -417,7 +733,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 query[None, start:end],
                 key[None, start:end],
                 value[None, start:end],
-                attn_bias=attn_metadata.attn_bias[i],
+                attn_bias=attn_bias[i],
                 p=0.0,
                 scale=self.scale)
             # TODO(woosuk): Unnecessary copy. Optimize.
vllm/attention/layer.py
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
 import torch
 import torch.nn as nn
 
-from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import (
@@ -90,9 +90,16 @@ class Attention(nn.Module):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
-        return self.impl.forward(query, key, value, kv_cache, attn_metadata,
-                                 self._kv_scale)
+        return self.impl.forward(query,
+                                 key,
+                                 value,
+                                 kv_cache,
+                                 attn_metadata,
+                                 self._kv_scale,
+                                 attn_type=attn_type)
 
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
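With attn_type threaded through the Attention wrapper, a single layer class can serve every flavor of attention in an encoder/decoder model. The fragment below is a hypothetical illustration: ToyDecoderLayer, its attributes, and the kv_caches dict are invented; only Attention, AttentionMetadata and AttentionType come from vLLM, and the q/k/v projections that a real model would apply are omitted for brevity.

# Hypothetical illustration; ToyDecoderLayer and its wiring are invented.
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn

from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.abstract import AttentionType


class ToyDecoderLayer(nn.Module):
    """Decoder layer of an imaginary encoder/decoder model."""

    def __init__(self, self_attn: Attention, cross_attn: Attention):
        super().__init__()
        self.self_attn = self_attn
        self.cross_attn = cross_attn

    def forward(
        self,
        hidden: torch.Tensor,
        encoder_kv: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
        kv_caches: Dict[str, Optional[torch.Tensor]],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Decoder self-attention: causal, decoder block tables/slot mapping.
        hidden = self.self_attn(hidden, hidden, hidden, kv_caches["self"],
                                attn_metadata,
                                attn_type=AttentionType.DECODER)
        # Cross-attention: non-causal, "cross" block tables; encoder_kv is
        # (None, None) during decode because the KVs were cached at prefill.
        enc_k, enc_v = encoder_kv
        hidden = self.cross_attn(hidden, enc_k, enc_v, kv_caches["cross"],
                                 attn_metadata,
                                 attn_type=AttentionType.ENCODER_DECODER)
        return hidden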
|
|||||||
Loading…
x
Reference in New Issue
Block a user