[Kernel] Move attn_type to Attention.__init__() (#11690)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang 2025-01-07 00:11:28 +08:00 committed by GitHub
parent 32c9eff2ff
commit e20c92bb61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 159 additions and 201 deletions

View File

@ -13,8 +13,7 @@ import pytest
import torch
from tests.kernels.utils import *
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
AttentionType)
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
max_dec_seq_len: int
max_enc_seq_len: int
num_blocks: int
attn_type: AttentionType
class TestResources(NamedTuple):
@ -96,7 +96,6 @@ class TestResources(NamedTuple):
'''
scale: float
attn_backend: AttentionBackend
attn: Attention
kv_cache: torch.Tensor
@ -129,16 +128,17 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
'''
scale = float(1.0 / (test_pt.head_size**0.5))
attn_backend = make_backend(test_pt.backend_name)
attn = Attention(
test_pt.num_heads,
test_pt.head_size,
scale=scale,
prefix=f"{test_pt.attn_type}",
attn_type=test_pt.attn_type,
)
if test_pt.num_blocks is None or test_pt.num_heads is None:
# Caller does not require a KV cache
return TestResources(
scale, attn_backend, attn,
scale, attn,
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))
# Construct KV cache
@ -148,7 +148,7 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
return TestResources(scale, attn_backend, attn, kv_cache)
return TestResources(scale, attn, kv_cache)
def _encoder_attn_setup(
@ -193,6 +193,7 @@ def _encoder_attn_setup(
_,
max_q_seq_len,
_,
_,
) = test_pt
scale = test_rsrcs.scale
@ -301,6 +302,7 @@ def _decoder_attn_setup(
max_q_seq_len,
_,
_,
_,
) = test_pt
scale = test_rsrcs.scale
@ -488,6 +490,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
max_decoder_seq_len,
max_encoder_seq_len,
_,
_,
) = test_pt
scale = test_rsrcs.scale
@ -622,7 +625,6 @@ def _run_encoder_attention_test(
& attn_metadata
'''
assert attn_metadata.num_decode_tokens == 0
attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
with set_forward_context(attn_metadata, vllm_config):
@ -635,14 +637,11 @@ def _run_encoder_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)
return attn.forward(
reshaped_query, packed_qkv.key, packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device), attn_metadata)
def _run_decoder_self_attention_test(
@ -675,7 +674,6 @@ def _run_decoder_self_attention_test(
* Attention.forward() applied to packed_{query,key,value}, kv_cache
& attn_metadata
'''
attn_type = AttentionType.DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
@ -690,12 +688,8 @@ def _run_decoder_self_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)
return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
kv_cache, attn_metadata)
def _run_encoder_decoder_cross_attention_test(
@ -742,7 +736,6 @@ def _run_encoder_decoder_cross_attention_test(
'''
assert decoder_test_params.packed_qkvo.packed_qkv is not None
attn_type = AttentionType.ENCODER_DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
if cross_test_params is None:
@ -762,12 +755,8 @@ def _run_encoder_decoder_cross_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
key,
value,
kv_cache,
attn_metadata,
attn_type=attn_type)
return attn.forward(reshaped_query, key, value, kv_cache,
attn_metadata)
@pytest.fixture(autouse=True)
@ -839,7 +828,7 @@ def test_encoder_only(
# is not part of this test
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
max_enc_seq_len, 4096, AttentionType.ENCODER)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
@ -855,7 +844,7 @@ def test_encoder_only(
# Shared prefill metadata structure
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
True,
None,
decoder_test_params=None,
@ -961,20 +950,29 @@ def test_e2e_enc_dec_attn(
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096, AttentionType.ENCODER)
enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096,
AttentionType.ENCODER_DECODER)
dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096, AttentionType.DECODER)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt)
enc_test_rsrcs = _make_test_resources(enc_test_pt)
enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt)
dec_test_rsrcs = _make_test_resources(dec_test_pt)
# Construct encoder attention test params (only used
# during prefill)
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs)
# Construct Decoder self-attention prefill-phase & decode-phase
# test params, including query/key/value tensors, decoder self-attention
@ -987,7 +985,7 @@ def test_e2e_enc_dec_attn(
prephase_dec_test_params,
decphase_dec_test_params,
cross_block_base_addr,
) = _decoder_attn_setup(test_pt, test_rsrcs)
) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs)
# Construct encoder/decoder cross-attention prefill-phase
# & decode-phase test params, including key/value tensors,
@ -1000,14 +998,14 @@ def test_e2e_enc_dec_attn(
dec_qkv,
enc_test_params,
prephase_dec_test_params,
test_pt,
test_rsrcs,
enc_dec_test_pt,
enc_dec_test_rsrcs,
block_base_addr=cross_block_base_addr)
# Shared prefill metadata structure
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
True,
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
decoder_test_params=prephase_dec_test_params,
@ -1017,10 +1015,10 @@ def test_e2e_enc_dec_attn(
# PREFILL: encoder attention
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=enc_test_pt,
vllm_config=vllm_config)
# - Is encoder attention result correct?
@ -1030,10 +1028,10 @@ def test_e2e_enc_dec_attn(
# PREFILL: decoder self-attention test
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs,
dec_test_rsrcs,
prephase_dec_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=dec_test_pt,
vllm_config=vllm_config)
# - Is prefill decoder self-attention correct?
@ -1044,11 +1042,11 @@ def test_e2e_enc_dec_attn(
# PREFILL: encoder/decoder cross-attention test
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs,
enc_dec_test_rsrcs,
prephase_dec_test_params,
prephase_cross_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=enc_dec_test_pt,
vllm_config=vllm_config)
# - Is prefill encoder/decoder cross-attention correct?
@ -1059,7 +1057,7 @@ def test_e2e_enc_dec_attn(
# DECODE: build decode-phase attention metadata
decphase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
False,
dec_qkv.q_seq_lens,
decoder_test_params=decphase_dec_test_params,
@ -1070,10 +1068,10 @@ def test_e2e_enc_dec_attn(
# DECODE: decoder self-attention test
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs,
dec_test_rsrcs,
decphase_dec_test_params,
decphase_attn_metadata,
test_pt=test_pt,
test_pt=dec_test_pt,
vllm_config=vllm_config)
# - Is decode-phase decoder self-attention correct?
@ -1084,11 +1082,11 @@ def test_e2e_enc_dec_attn(
# DECODE: encoder/decoder cross-attention test
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs,
enc_dec_test_rsrcs,
decphase_dec_test_params,
None,
decphase_attn_metadata,
test_pt=test_pt,
test_pt=enc_dec_test_pt,
vllm_config=vllm_config)
# - Is decode-phase encoder/decoder cross-attention correct?

View File

@ -13,6 +13,7 @@ from torch._prims_common import TensorLikeType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
@ -790,7 +791,7 @@ def make_block_tables_slot_mapping(
def make_test_metadata(
attn_backend: AttentionBackend,
attn_backend: _Backend,
is_prompt: bool,
seq_lens: Optional[List[int]],
decoder_test_params: Optional[PhaseTestParameters],
@ -815,7 +816,7 @@ def make_test_metadata(
Arguments:
* attn_backend: Backend for sourcing attention kernels
* attn_backend_name: Backend for sourcing attention kernels
* is_prompt: prefill if True, o/w decode
* seq_lens: list of token counts for each sequence
* decoder_test_params: decoder self-attention test params;
@ -882,6 +883,8 @@ def make_test_metadata(
# (kv_mmap)
cross_kv_mmap = cross_test_params.kv_mmap
attn_backend_obj = make_backend(attn_backend.name)
if is_prompt:
# Prefill-phase scenario
@ -902,8 +905,7 @@ def make_test_metadata(
context_lens,
encoder_seq_lens,
device=device)
return attn_backend.make_metadata(
return attn_backend_obj.make_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
@ -952,7 +954,7 @@ def make_test_metadata(
encoder_seq_lens,
device=device)
return attn_backend.make_metadata(
return attn_backend_obj.make_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,

View File

@ -233,6 +233,7 @@ class AttentionImpl(ABC, Generic[T]):
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
raise NotImplementedError
@ -246,7 +247,6 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError

View File

@ -300,6 +300,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
@ -350,6 +351,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
active_head_range=self.blocksparse_params.active_head_range,
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")
def forward(
self,
query: torch.Tensor,
@ -359,7 +366,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@ -375,12 +381,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)

View File

@ -600,6 +600,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@ -627,6 +628,7 @@ class FlashAttentionImpl(AttentionImpl):
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type
def forward(
self,
@ -637,7 +639,6 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -659,6 +660,7 @@ class FlashAttentionImpl(AttentionImpl):
assert output is not None, "Output tensor must be provided."
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "

View File

@ -748,6 +748,7 @@ class FlashInferImpl(AttentionImpl):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@ -764,6 +765,12 @@ class FlashInferImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
def forward(
self,
query: torch.Tensor,
@ -773,18 +780,10 @@ class FlashInferImpl(AttentionImpl):
attn_metadata: FlashInferMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# TODO: directly write to output tensor
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
num_heads: int = self.num_heads
head_size: int = self.head_size
num_kv_heads: int = self.num_kv_heads

View File

@ -102,6 +102,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
attn_type: str = AttentionType.DECODER,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
@ -143,6 +144,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl")
def forward(
self,
query: torch.Tensor,
@ -152,7 +159,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@ -166,11 +172,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl")
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape

View File

@ -115,6 +115,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@ -146,6 +147,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
raise NotImplementedError(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")
def split_kv_cache(
self,
@ -172,7 +178,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
@ -189,11 +194,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)

View File

@ -100,6 +100,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@ -141,6 +142,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
# megacore mode will be None.
self.megacore_mode = "batch"
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
def forward(
self,
query: torch.Tensor,
@ -150,7 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata: PallasMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
@ -168,11 +174,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)

View File

@ -338,6 +338,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@ -397,6 +398,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.attn_func = _sdpa_attention
logger.debug("Using naive attention in ROCmBackend")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl")
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens, n_kv_heads, head_dim = x.shape
@ -414,7 +421,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata: ROCmFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@ -432,12 +438,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"""
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)

View File

@ -390,6 +390,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@ -421,6 +422,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
self.attn_type = attn_type
def forward(
self,
@ -431,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@ -448,6 +449,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "

View File

@ -379,6 +379,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@ -405,6 +406,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
self.attn_type = attn_type
def forward(
self,
query: torch.Tensor,
@ -414,7 +417,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata: "XFormersMetadata",
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@ -468,7 +470,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
attn_type = self.attn_type
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
if (attn_type == AttentionType.ENCODER

View File

@ -41,6 +41,7 @@ class Attention(nn.Module):
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
) -> None:
super().__init__()
if per_layer_sliding_window is not None:
@ -96,7 +97,7 @@ class Attention(nn.Module):
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap)
blocksparse_params, logits_soft_cap, attn_type)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
@ -119,6 +120,7 @@ class Attention(nn.Module):
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.attn_type = attn_type
def forward(
self,
@ -127,18 +129,12 @@ class Attention(nn.Module):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
if self.use_direct_call:
return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type)
return self.impl.forward(query, key, value, kv_cache,
attn_metadata, self._k_scale,
self._v_scale)
elif self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
@ -152,13 +148,11 @@ class Attention(nn.Module):
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, kv_cache, attn_type,
self.layer_name)
query, key, value, output, kv_cache, self.layer_name)
return output.view(-1, hidden_size)
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, attn_type,
self.layer_name)
kv_cache, self.layer_name)
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
@ -237,20 +231,13 @@ def unified_attention(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type)
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
self._k_scale, self._v_scale)
def unified_attention_fake(
@ -258,7 +245,6 @@ def unified_attention_fake(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
@ -279,7 +265,6 @@ def unified_attention_with_output(
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
@ -292,7 +277,6 @@ def unified_attention_with_output(
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type,
output=output)
@ -302,7 +286,6 @@ def unified_attention_with_output_fake(
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> None:
return

View File

@ -71,12 +71,8 @@ class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
def forward(
self,
positions: torch.Tensor,
attn_type: AttentionType,
) -> torch.Tensor:
"""`input_ids' shape is expected to be [bsz x seqlen]."""
assert attn_type != AttentionType.ENCODER_DECODER
return super().forward(positions + self.offset)
@ -180,7 +176,8 @@ class BartEncoderAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
@ -189,12 +186,7 @@ class BartEncoderAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.out_proj(attn_output)
return output
@ -264,7 +256,8 @@ class BartDecoderSelfAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
prefix=f"{prefix}.attn",
attn_type=AttentionType.DECODER)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
@ -273,12 +266,7 @@ class BartDecoderSelfAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.DECODER)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.out_proj(attn_output)
return output
@ -348,7 +336,8 @@ class BartCrossAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER_DECODER)
def forward(
self,
@ -372,12 +361,7 @@ class BartCrossAttention(nn.Module):
_, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_DECODER)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.out_proj(attn_output)
return output
@ -644,10 +628,7 @@ class BartEncoder(nn.Module):
# retrieve input_ids and inputs_embeds
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(
positions,
AttentionType.ENCODER,
)
embed_pos = self.embed_positions(positions)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
@ -734,10 +715,7 @@ class BartDecoder(nn.Module):
inputs_embeds = self.embed_tokens(decoder_input_ids)
# embed positions
embed_pos = self.embed_positions(
decoder_positions,
AttentionType.DECODER,
)
embed_pos = self.embed_positions(decoder_positions)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos

View File

@ -238,7 +238,8 @@ class BertSelfAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER_ONLY)
def forward(
self,
@ -248,12 +249,7 @@ class BertSelfAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_ONLY)
output = self.attn(q, k, v, kv_cache, attn_metadata)
return output

View File

@ -770,6 +770,7 @@ class MllamaTextCrossAttention(nn.Module):
self.scaling,
self.num_local_key_value_heads,
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER_DECODER,
)
def forward(
@ -805,13 +806,9 @@ class MllamaTextCrossAttention(nn.Module):
kv_range_for_decode,
attn_metadata)
else:
output = self.attn(q.view(-1,
self.num_local_heads * self.head_dim),
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_DECODER)
output = self.attn(
q.view(-1, self.num_local_heads * self.head_dim), k, v,
kv_cache, attn_metadata)
out, _ = self.o_proj(output)
return out

View File

@ -107,7 +107,8 @@ class Qwen2Attention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None,
prefix: str = "") -> None:
prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@ -160,7 +161,8 @@ class Qwen2Attention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
prefix=f"{prefix}.attn",
attn_type=attn_type)
def forward(
self,
@ -168,17 +170,11 @@ class Qwen2Attention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=attn_type)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
@ -197,6 +193,16 @@ class Qwen2DecoderLayer(nn.Module):
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@ -207,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
@ -220,15 +227,6 @@ class Qwen2DecoderLayer(nn.Module):
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if getattr(config, "is_causal", True):
self._attn_type = AttentionType.DECODER
else:
self._attn_type = AttentionType.ENCODER_ONLY
def forward(
self,
positions: torch.Tensor,
@ -249,7 +247,6 @@ class Qwen2DecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
attn_type=self._attn_type,
)
# Fully Connected

View File

@ -89,6 +89,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@ -119,6 +120,12 @@ class FlashAttentionImpl(AttentionImpl):
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
def forward(
self,
query: torch.Tensor,
@ -128,7 +135,6 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -142,12 +148,6 @@ class FlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")