Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 07:57:45 +08:00)
[Kernel] Move attn_type to Attention.__init__() (#11690)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
commit e20c92bb61 (parent 32c9eff2ff)
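
This commit moves the attn_type argument from Attention.forward() to Attention.__init__(), so the attention variant (decoder self-attention, encoder self-attention, encoder-only, or encoder/decoder cross-attention) is fixed once per layer instead of being passed on every call. A minimal sketch of the calling convention before and after the change (argument lists abbreviated; names follow the diff below):

# Before: the attention type was threaded through every forward call.
attn = Attention(num_heads, head_size, scale, prefix=f"{prefix}.attn")
out = attn(q, k, v, kv_cache, attn_metadata, attn_type=AttentionType.ENCODER)

# After: the type is bound at construction and stored on the layer.
attn = Attention(num_heads, head_size, scale, prefix=f"{prefix}.attn",
                 attn_type=AttentionType.ENCODER)
out = attn(q, k, v, kv_cache, attn_metadata)

Backend implementations receive the same value through their constructors and either store it (FlashAttention, Torch SDPA, xFormers) or reject unsupported types up front.
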
@@ -13,8 +13,7 @@ import pytest
 import torch

 from tests.kernels.utils import *
-from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
-AttentionType)
+from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
 from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
 global_force_attn_backend_context_manager)
@@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
 max_dec_seq_len: int
 max_enc_seq_len: int
 num_blocks: int
+attn_type: AttentionType


 class TestResources(NamedTuple):
@@ -96,7 +96,6 @@ class TestResources(NamedTuple):
 '''

 scale: float
-attn_backend: AttentionBackend
 attn: Attention
 kv_cache: torch.Tensor

@@ -129,16 +128,17 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
 '''

 scale = float(1.0 / (test_pt.head_size**0.5))
-attn_backend = make_backend(test_pt.backend_name)
 attn = Attention(
 test_pt.num_heads,
 test_pt.head_size,
 scale=scale,
+prefix=f"{test_pt.attn_type}",
+attn_type=test_pt.attn_type,
 )
 if test_pt.num_blocks is None or test_pt.num_heads is None:
 # Caller does not require a KV cache
 return TestResources(
-scale, attn_backend, attn,
+scale, attn,
 torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

 # Construct KV cache
@@ -148,7 +148,7 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
 test_pt.block_size,
 device=CUDA_DEVICE,
 backend=test_pt.backend_name)
-return TestResources(scale, attn_backend, attn, kv_cache)
+return TestResources(scale, attn, kv_cache)


 def _encoder_attn_setup(
@@ -193,6 +193,7 @@ def _encoder_attn_setup(
 _,
 max_q_seq_len,
 _,
+_,
 ) = test_pt

 scale = test_rsrcs.scale
@@ -301,6 +302,7 @@ def _decoder_attn_setup(
 max_q_seq_len,
 _,
 _,
+_,
 ) = test_pt

 scale = test_rsrcs.scale
@@ -488,6 +490,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
 max_decoder_seq_len,
 max_encoder_seq_len,
 _,
+_,
 ) = test_pt

 scale = test_rsrcs.scale
@@ -622,7 +625,6 @@ def _run_encoder_attention_test(
 & attn_metadata
 '''
 assert attn_metadata.num_decode_tokens == 0
-attn_type = AttentionType.ENCODER
 packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
 assert packed_qkv is not None
 with set_forward_context(attn_metadata, vllm_config):
@@ -635,14 +637,11 @@ def _run_encoder_attention_test(
 # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
 reshaped_query = packed_qkv.query.view(
 -1, test_pt.num_heads * test_pt.head_size)
-return attn.forward(reshaped_query,
-packed_qkv.key,
-packed_qkv.value,
-torch.tensor([],
-dtype=torch.float32,
-device=packed_qkv.query.device),
-attn_metadata,
-attn_type=attn_type)
+return attn.forward(
+reshaped_query, packed_qkv.key, packed_qkv.value,
+torch.tensor([],
+dtype=torch.float32,
+device=packed_qkv.query.device), attn_metadata)


 def _run_decoder_self_attention_test(
@@ -675,7 +674,6 @@ def _run_decoder_self_attention_test(
 * Attention.forward() applied to packed_{query,key,value}, kv_cache
 & attn_metadata
 '''
-attn_type = AttentionType.DECODER
 attn = test_rsrcs.attn
 kv_cache = test_rsrcs.kv_cache
 packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
@@ -690,12 +688,8 @@ def _run_decoder_self_attention_test(
 # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
 reshaped_query = packed_qkv.query.view(
 -1, test_pt.num_heads * test_pt.head_size)
-return attn.forward(reshaped_query,
-packed_qkv.key,
-packed_qkv.value,
-kv_cache,
-attn_metadata,
-attn_type=attn_type)
+return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
+kv_cache, attn_metadata)


 def _run_encoder_decoder_cross_attention_test(
@@ -742,7 +736,6 @@ def _run_encoder_decoder_cross_attention_test(
 '''
 assert decoder_test_params.packed_qkvo.packed_qkv is not None

-attn_type = AttentionType.ENCODER_DECODER
 attn = test_rsrcs.attn
 kv_cache = test_rsrcs.kv_cache
 if cross_test_params is None:
@@ -762,12 +755,8 @@ def _run_encoder_decoder_cross_attention_test(
 # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
 reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
 -1, test_pt.num_heads * test_pt.head_size)
-return attn.forward(reshaped_query,
-key,
-value,
-kv_cache,
-attn_metadata,
-attn_type=attn_type)
+return attn.forward(reshaped_query, key, value, kv_cache,
+attn_metadata)


 @pytest.fixture(autouse=True)
@@ -839,7 +828,7 @@ def test_encoder_only(
 # is not part of this test
 test_pt = TestPoint(num_heads, head_size, attn_backend.name,
 batch_size, block_size, max_dec_seq_len,
-max_enc_seq_len, 4096)
+max_enc_seq_len, 4096, AttentionType.ENCODER)

 # Attention scale factor, attention backend instance, attention wrapper
 # instance, KV cache init
@@ -855,7 +844,7 @@ def test_encoder_only(
 # Shared prefill metadata structure

 prephase_attn_metadata: AttentionMetadata = make_test_metadata(
-test_rsrcs.attn_backend,
+attn_backend,
 True,
 None,
 decoder_test_params=None,
@@ -961,20 +950,29 @@ def test_e2e_enc_dec_attn(
 # Note: KV cache size of 4096 is arbitrary & chosen intentionally
 # to be more than necessary, since exceeding the kv cache size
 # is not part of this test
-test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
 batch_size, block_size, max_dec_seq_len,
-max_enc_seq_len, 4096)
+max_enc_seq_len, 4096, AttentionType.ENCODER)
+enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+batch_size, block_size, max_dec_seq_len,
+max_enc_seq_len, 4096,
+AttentionType.ENCODER_DECODER)
+dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+batch_size, block_size, max_dec_seq_len,
+max_enc_seq_len, 4096, AttentionType.DECODER)

 # Attention scale factor, attention backend instance, attention wrapper
 # instance, KV cache init
 vllm_config = VllmConfig()
 with set_current_vllm_config(vllm_config):
-test_rsrcs = _make_test_resources(test_pt)
+enc_test_rsrcs = _make_test_resources(enc_test_pt)
+enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt)
+dec_test_rsrcs = _make_test_resources(dec_test_pt)

 # Construct encoder attention test params (only used
 # during prefill)

-enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
+enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs)

 # Construct Decoder self-attention prefill-phase & decode-phase
 # test params, including query/key/value tensors, decoder self-attention
@@ -987,7 +985,7 @@ def test_e2e_enc_dec_attn(
 prephase_dec_test_params,
 decphase_dec_test_params,
 cross_block_base_addr,
-) = _decoder_attn_setup(test_pt, test_rsrcs)
+) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs)

 # Construct encoder/decoder cross-attention prefill-phase
 # & decode-phase test params, including key/value tensors,
@@ -1000,14 +998,14 @@ def test_e2e_enc_dec_attn(
 dec_qkv,
 enc_test_params,
 prephase_dec_test_params,
-test_pt,
-test_rsrcs,
+enc_dec_test_pt,
+enc_dec_test_rsrcs,
 block_base_addr=cross_block_base_addr)

 # Shared prefill metadata structure
 assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
 prephase_attn_metadata: AttentionMetadata = make_test_metadata(
-test_rsrcs.attn_backend,
+attn_backend,
 True,
 prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
 decoder_test_params=prephase_dec_test_params,
@@ -1017,10 +1015,10 @@ def test_e2e_enc_dec_attn(

 # PREFILL: encoder attention

-enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
+enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn,
 enc_test_params,
 prephase_attn_metadata,
-test_pt=test_pt,
+test_pt=enc_test_pt,
 vllm_config=vllm_config)

 # - Is encoder attention result correct?
@@ -1030,10 +1028,10 @@ def test_e2e_enc_dec_attn(
 # PREFILL: decoder self-attention test

 prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
-test_rsrcs,
+dec_test_rsrcs,
 prephase_dec_test_params,
 prephase_attn_metadata,
-test_pt=test_pt,
+test_pt=dec_test_pt,
 vllm_config=vllm_config)

 # - Is prefill decoder self-attention correct?
@@ -1044,11 +1042,11 @@ def test_e2e_enc_dec_attn(
 # PREFILL: encoder/decoder cross-attention test

 prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
-test_rsrcs,
+enc_dec_test_rsrcs,
 prephase_dec_test_params,
 prephase_cross_test_params,
 prephase_attn_metadata,
-test_pt=test_pt,
+test_pt=enc_dec_test_pt,
 vllm_config=vllm_config)

 # - Is prefill encoder/decoder cross-attention correct?
@@ -1059,7 +1057,7 @@ def test_e2e_enc_dec_attn(
 # DECODE: build decode-phase attention metadata

 decphase_attn_metadata: AttentionMetadata = make_test_metadata(
-test_rsrcs.attn_backend,
+attn_backend,
 False,
 dec_qkv.q_seq_lens,
 decoder_test_params=decphase_dec_test_params,
@@ -1070,10 +1068,10 @@ def test_e2e_enc_dec_attn(
 # DECODE: decoder self-attention test

 decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
-test_rsrcs,
+dec_test_rsrcs,
 decphase_dec_test_params,
 decphase_attn_metadata,
-test_pt=test_pt,
+test_pt=dec_test_pt,
 vllm_config=vllm_config)

 # - Is decode-phase decoder self-attention correct?
@@ -1084,11 +1082,11 @@ def test_e2e_enc_dec_attn(
 # DECODE: encoder/decoder cross-attention test

 decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
-test_rsrcs,
+enc_dec_test_rsrcs,
 decphase_dec_test_params,
 None,
 decphase_attn_metadata,
-test_pt=test_pt,
+test_pt=enc_dec_test_pt,
 vllm_config=vllm_config)

 # - Is decode-phase encoder/decoder cross-attention correct?

@@ -13,6 +13,7 @@ from torch._prims_common import TensorLikeType

 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.platforms.interface import _Backend
 from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
 STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

@@ -790,7 +791,7 @@ def make_block_tables_slot_mapping(


 def make_test_metadata(
-attn_backend: AttentionBackend,
+attn_backend: _Backend,
 is_prompt: bool,
 seq_lens: Optional[List[int]],
 decoder_test_params: Optional[PhaseTestParameters],
@@ -815,7 +816,7 @@ def make_test_metadata(

 Arguments:

-* attn_backend: Backend for sourcing attention kernels
+* attn_backend_name: Backend for sourcing attention kernels
 * is_prompt: prefill if True, o/w decode
 * seq_lens: list of token counts for each sequence
 * decoder_test_params: decoder self-attention test params;
@@ -882,6 +883,8 @@ def make_test_metadata(
 # (kv_mmap)
 cross_kv_mmap = cross_test_params.kv_mmap

+attn_backend_obj = make_backend(attn_backend.name)
+
 if is_prompt:
 # Prefill-phase scenario

@@ -902,8 +905,7 @@ def make_test_metadata(
 context_lens,
 encoder_seq_lens,
 device=device)
-
-return attn_backend.make_metadata(
+return attn_backend_obj.make_metadata(
 num_prefills=num_prefills,
 slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
 multi_modal_placeholder_index_maps=None,
@@ -952,7 +954,7 @@ def make_test_metadata(
 encoder_seq_lens,
 device=device)

-return attn_backend.make_metadata(
+return attn_backend_obj.make_metadata(
 num_prefills=num_prefills,
 slot_mapping=kv_mmap.slot_mapping,
 multi_modal_placeholder_index_maps=None,

@@ -233,6 +233,7 @@ class AttentionImpl(ABC, Generic[T]):
 kv_cache_dtype: str = "auto",
 blocksparse_params: Optional[Dict[str, Any]] = None,
 logits_soft_cap: Optional[float] = None,
+attn_type: str = AttentionType.DECODER,
 ) -> None:
 raise NotImplementedError

@@ -246,7 +247,6 @@ class AttentionImpl(ABC, Generic[T]):
 attn_metadata: T,
 k_scale: float = 1.0,
 v_scale: float = 1.0,
-attn_type: str = AttentionType.DECODER,
 output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 raise NotImplementedError
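
With attn_type now part of the AttentionImpl constructor, backends that only support decoder self-attention validate the requested type once at construction time instead of on every forward() call, as the backend hunks below show. A hedged sketch of that pattern (MyBackendImpl is a hypothetical name used only for illustration; it assumes the vllm package's AttentionType constants):

from vllm.attention import AttentionType

class MyBackendImpl:
    # Hypothetical backend mirroring the pattern added in the following hunks.
    def __init__(self, num_heads: int, head_size: int,
                 attn_type: str = AttentionType.DECODER) -> None:
        if attn_type != AttentionType.DECODER:
            # Unsupported variants now fail at layer construction, not per token.
            raise NotImplementedError(
                "Encoder self-attention and encoder/decoder cross-attention "
                "are not implemented for MyBackendImpl")
        self.num_heads = num_heads
        self.head_size = head_size
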

@@ -300,6 +300,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
 kv_cache_dtype: str,
 blocksparse_params: Optional[Dict[str, Any]] = None,
 logits_soft_cap: Optional[float] = None,
+attn_type: str = AttentionType.DECODER,
 ) -> None:
 assert blocksparse_params is not None
 assert alibi_slopes is None, ValueError(
@@ -350,6 +351,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
 active_head_range=self.blocksparse_params.active_head_range,
 )

+if attn_type != AttentionType.DECODER:
+raise NotImplementedError("Encoder self-attention and "
+"encoder/decoder cross-attention "
+"are not implemented for "
+"BlocksparseFlashAttentionImpl")
+
 def forward(
 self,
 query: torch.Tensor,
@@ -359,7 +366,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
 attn_metadata: BlocksparseFlashAttentionMetadata,
 k_scale: float = 1.0,
 v_scale: float = 1.0,
-attn_type: str = AttentionType.DECODER,
 output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 """Forward pass with FlashAttention and PagedAttention.
@@ -375,12 +381,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
 Returns:
 shape = [num_tokens, num_heads * head_size]
 """
-if attn_type != AttentionType.DECODER:
-raise NotImplementedError("Encoder self-attention and "
-"encoder/decoder cross-attention "
-"are not implemented for "
-"BlocksparseFlashAttentionImpl")
-
 num_tokens, hidden_size = query.shape
 # Reshape the query, key, and value tensors.
 query = query.view(-1, self.num_heads, self.head_size)

@@ -600,6 +600,7 @@ class FlashAttentionImpl(AttentionImpl):
 kv_cache_dtype: str,
 blocksparse_params: Optional[Dict[str, Any]] = None,
 logits_soft_cap: Optional[float] = None,
+attn_type: str = AttentionType.DECODER,
 ) -> None:
 if blocksparse_params is not None:
 raise ValueError(
@@ -627,6 +628,7 @@ class FlashAttentionImpl(AttentionImpl):
 raise ValueError(
 f"Head size {head_size} is not supported by FlashAttention. "
 f"Supported head sizes are: {support_head_sizes}.")
+self.attn_type = attn_type

 def forward(
 self,
@@ -637,7 +639,6 @@ class FlashAttentionImpl(AttentionImpl):
 attn_metadata: FlashAttentionMetadata,
 k_scale: float = 1.0,
 v_scale: float = 1.0,
-attn_type: str = AttentionType.DECODER,
 output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 """Forward pass with FlashAttention.
@@ -659,6 +660,7 @@ class FlashAttentionImpl(AttentionImpl):

 assert output is not None, "Output tensor must be provided."

+attn_type = self.attn_type
 if (attn_type == AttentionType.ENCODER
 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
 raise AttributeError("Encoder attention requires setting "

@@ -748,6 +748,7 @@ class FlashInferImpl(AttentionImpl):
 kv_cache_dtype: str,
 blocksparse_params: Optional[Dict[str, Any]] = None,
 logits_soft_cap: Optional[float] = None,
+attn_type: str = AttentionType.DECODER,
 ) -> None:
 self.num_heads = num_heads
 self.head_size = head_size
@@ -764,6 +765,12 @@ class FlashInferImpl(AttentionImpl):
 assert self.num_heads % self.num_kv_heads == 0
 self.num_queries_per_kv = self.num_heads // self.num_kv_heads

+if attn_type != AttentionType.DECODER:
+raise NotImplementedError("Encoder self-attention and "
+"encoder/decoder cross-attention "
+"are not implemented for "
+"FlashInferImpl")
+
 def forward(
 self,
 query: torch.Tensor,
@@ -773,18 +780,10 @@ class FlashInferImpl(AttentionImpl):
 attn_metadata: FlashInferMetadata,
 k_scale: float = 1.0,
 v_scale: float = 1.0,
-attn_type: str = AttentionType.DECODER,
 output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:

 # TODO: directly write to output tensor
-
-if attn_type != AttentionType.DECODER:
-raise NotImplementedError("Encoder self-attention and "
-"encoder/decoder cross-attention "
-"are not implemented for "
-"FlashInferImpl")
-
 num_heads: int = self.num_heads
 head_size: int = self.head_size
 num_kv_heads: int = self.num_kv_heads

@@ -102,6 +102,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
 kv_cache_dtype: str,
 blocksparse_params: Optional[Dict[str, Any]] = None,
 max_seq_len: int = 4096,
+attn_type: str = AttentionType.DECODER,
 ) -> None:
 super(AttentionImpl, self).__init__()
 self.kv_cache_dtype = kv_cache_dtype
@@ -143,6 +144,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
 f"Head size {head_size} is not supported by PagedAttention. "
 f"Supported head sizes are: {suppored_head_sizes}.")

+if attn_type != AttentionType.DECODER:
+raise NotImplementedError("Encoder self-attention and "
+"encoder/decoder cross-attention "
+"are not implemented for "
+"HPUAttentionImpl")
+
 def forward(
 self,
 query: torch.Tensor,
@@ -152,7 +159,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
 attn_metadata: HPUAttentionMetadata,
 k_scale: float = 1.0,
 v_scale: float = 1.0,
-attn_type: str = AttentionType.DECODER,
 output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 """Forward pass with xFormers and PagedAttention.
@@ -166,11 +172,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
 Returns:
 shape = [num_tokens, num_heads * head_size]
 """
-if attn_type != AttentionType.DECODER:
-raise NotImplementedError("Encoder self-attention and "
-"encoder/decoder cross-attention "
-"are not implemented for "
-"HPUAttentionImpl")
 batch_size, seq_len, hidden_size = query.shape
 _, seq_len_kv, _ = key.shape


@@ -115,6 +115,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
 kv_cache_dtype: str,
 blocksparse_params: Optional[Dict[str, Any]] = None,
 logits_soft_cap: Optional[float] = None,
+attn_type: str = AttentionType.DECODER,
 ) -> None:
 if blocksparse_params is not None:
 raise ValueError(
@@ -146,6 +147,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
 raise NotImplementedError(
 "IPEX backend does not support FP8 KV cache. "
 "Please use xFormers backend instead.")
+if attn_type != AttentionType.DECODER:
+raise NotImplementedError("Encoder self-attention and "
+"encoder/decoder cross-attention "
+"are not implemented for "
+"IpexAttnBackendImpl")

 def split_kv_cache(
 self,
@@ -172,7 +178,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
 attn_metadata: IpexAttnMetadata, # type: ignore
 k_scale: float = 1.0,
 v_scale: float = 1.0,
-attn_type: str = AttentionType.DECODER,
 output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 """Forward pass with IPEX varlen_attention and PagedAttention.
@@ -189,11 +194,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
 shape = [num_tokens, num_heads * head_size]
 """
 assert k_scale == 1.0 and v_scale == 1.0
-if attn_type != AttentionType.DECODER:
-raise NotImplementedError("Encoder self-attention and "
-"encoder/decoder cross-attention "
-"are not implemented for "
-"IpexAttnBackendImpl")
 num_tokens, hidden_size = query.shape
 # Reshape the query, key, and value tensors.
 query = query.view(-1, self.num_heads, self.head_size)

@@ -100,6 +100,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
 kv_cache_dtype: str,
 blocksparse_params: Optional[Dict[str, Any]] = None,
 logits_soft_cap: Optional[float] = None,
+attn_type: str = AttentionType.DECODER,
 ) -> None:
 self.num_heads = num_heads
 self.head_size = head_size
@@ -141,6 +142,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
 # megacore mode will be None.
 self.megacore_mode = "batch"

+if attn_type != AttentionType.DECODER:
+raise NotImplementedError("Encoder self-attention and "
+"encoder/decoder cross-attention "
+"are not implemented for "
+"PallasAttentionBackendImpl")
+
 def forward(
 self,
 query: torch.Tensor,
@@ -150,7 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
 attn_metadata: PallasMetadata,
 k_scale: float = 1.0,
 v_scale: float = 1.0,
-attn_type: str = AttentionType.DECODER,
 output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 """Forward pass with Pallas attention.
@@ -168,11 +174,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
 shape = [batch_size, seq_len, num_heads * head_size]
 """
 assert k_scale == 1.0 and v_scale == 1.0
-if attn_type != AttentionType.DECODER:
-raise NotImplementedError("Encoder self-attention and "
-"encoder/decoder cross-attention "
-"are not implemented for "
-"PallasAttentionBackendImpl")
 batch_size, seq_len, hidden_size = query.shape
 query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
 key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)

@@ -338,6 +338,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 kv_cache_dtype: str,
 blocksparse_params: Optional[Dict[str, Any]] = None,
 logits_soft_cap: Optional[float] = None,
+attn_type: str = AttentionType.DECODER,
 ) -> None:
 if blocksparse_params is not None:
 raise ValueError(
@@ -397,6 +398,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 self.attn_func = _sdpa_attention
 logger.debug("Using naive attention in ROCmBackend")

+if attn_type != AttentionType.DECODER:
+raise NotImplementedError("Encoder self-attention and "
+"encoder/decoder cross-attention "
+"are not implemented for "
+"ROCmFlashAttentionImpl")
+
 def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
 """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
 tokens, n_kv_heads, head_dim = x.shape
@@ -414,7 +421,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 attn_metadata: ROCmFlashAttentionMetadata,
 k_scale: float = 1.0,
 v_scale: float = 1.0,
-attn_type: str = AttentionType.DECODER,
 output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 """Forward pass with FlashAttention and PagedAttention.
@@ -432,12 +438,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 """
 # Reminder: Please update docs/source/features/compatibility_matrix.md
 # If the feature combo become valid
-if attn_type != AttentionType.DECODER:
-raise NotImplementedError("Encoder self-attention and "
-"encoder/decoder cross-attention "
-"are not implemented for "
-"ROCmFlashAttentionImpl")
-
 num_tokens, hidden_size = query.shape
 # Reshape the query, key, and value tensors.
 query = query.view(-1, self.num_heads, self.head_size)

@@ -390,6 +390,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
 kv_cache_dtype: str,
 blocksparse_params: Optional[Dict[str, Any]] = None,
 logits_soft_cap: Optional[float] = None,
+attn_type: str = AttentionType.DECODER,
 ) -> None:
 if blocksparse_params is not None:
 raise ValueError(
@@ -421,6 +422,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
 raise NotImplementedError(
 "Torch SDPA backend does not support FP8 KV cache. "
 "Please use xFormers backend instead.")
+self.attn_type = attn_type

 def forward(
 self,
@@ -431,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
 attn_metadata: TorchSDPAMetadata, # type: ignore
 k_scale: float = 1.0,
 v_scale: float = 1.0,
-attn_type: str = AttentionType.DECODER,
 output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 """Forward pass with torch SDPA and PagedAttention.
@@ -448,6 +449,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
 shape = [num_tokens, num_heads * head_size]
 """
 assert k_scale == 1.0 and v_scale == 1.0
+attn_type = self.attn_type
 if (attn_type == AttentionType.ENCODER
 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
 raise AttributeError("Encoder attention requires setting "

@@ -379,6 +379,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
 kv_cache_dtype: str,
 blocksparse_params: Optional[Dict[str, Any]] = None,
 logits_soft_cap: Optional[float] = None,
+attn_type: str = AttentionType.DECODER,
 ) -> None:
 if blocksparse_params is not None:
 raise ValueError(
@@ -405,6 +406,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
 f"Head size {head_size} is not supported by PagedAttention. "
 f"Supported head sizes are: {suppored_head_sizes}.")

+self.attn_type = attn_type
+
 def forward(
 self,
 query: torch.Tensor,
@@ -414,7 +417,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
 attn_metadata: "XFormersMetadata",
 k_scale: float = 1.0,
 v_scale: float = 1.0,
-attn_type: str = AttentionType.DECODER,
 output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 """Forward pass with xFormers and PagedAttention.
@@ -468,7 +470,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
 Returns:
 shape = [num_tokens, num_heads * head_size]
 """
-
+attn_type = self.attn_type
 # Check that appropriate attention metadata attributes are
 # selected for the desired attention type
 if (attn_type == AttentionType.ENCODER

@@ -41,6 +41,7 @@ class Attention(nn.Module):
 logits_soft_cap: Optional[float] = None,
 per_layer_sliding_window: Optional[int] = None,
 prefix: str = "",
+attn_type: str = AttentionType.DECODER,
 ) -> None:
 super().__init__()
 if per_layer_sliding_window is not None:
@@ -96,7 +97,7 @@ class Attention(nn.Module):
 impl_cls = attn_backend.get_impl_cls()
 self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
 alibi_slopes, sliding_window, kv_cache_dtype,
-blocksparse_params, logits_soft_cap)
+blocksparse_params, logits_soft_cap, attn_type)
 self.num_heads = num_heads
 self.head_size = head_size
 self.num_kv_heads = num_kv_heads
@@ -119,6 +120,7 @@ class Attention(nn.Module):
 raise ValueError(f"Duplicate layer name: {prefix}")
 compilation_config.static_forward_context[prefix] = self
 self.layer_name = prefix
+self.attn_type = attn_type

 def forward(
 self,
@@ -127,18 +129,12 @@ class Attention(nn.Module):
 value: torch.Tensor,
 kv_cache: torch.Tensor,
 attn_metadata: AttentionMetadata,
-attn_type: str = AttentionType.DECODER,
 ) -> torch.Tensor:

 if self.use_direct_call:
-return self.impl.forward(query,
-key,
-value,
-kv_cache,
-attn_metadata,
-self._k_scale,
-self._v_scale,
-attn_type=attn_type)
+return self.impl.forward(query, key, value, kv_cache,
+attn_metadata, self._k_scale,
+self._v_scale)
 elif self.use_output:
 output = torch.empty_like(query)
 hidden_size = query.size(-1)
@@ -152,13 +148,11 @@ class Attention(nn.Module):
 if value is not None:
 value = value.view(-1, self.num_kv_heads, self.head_size)
 torch.ops.vllm.unified_attention_with_output(
-query, key, value, output, kv_cache, attn_type,
-self.layer_name)
+query, key, value, output, kv_cache, self.layer_name)
 return output.view(-1, hidden_size)
 else:
 return torch.ops.vllm.unified_attention(query, key, value,
-kv_cache, attn_type,
-self.layer_name)
+kv_cache, self.layer_name)

 def extra_repr(self) -> str:
 s = f"head_size={self.impl.head_size}" # type: ignore
@@ -237,20 +231,13 @@ def unified_attention(
 key: torch.Tensor,
 value: torch.Tensor,
 kv_cache: torch.Tensor,
-attn_type: str,
 layer_name: str,
 ) -> torch.Tensor:
 forward_context: ForwardContext = get_forward_context()
 attn_metadata = forward_context.dynamic_forward_context
 self = forward_context.static_forward_context[layer_name]
-return self.impl.forward(query,
-key,
-value,
-kv_cache,
-attn_metadata,
-self._k_scale,
-self._v_scale,
-attn_type=attn_type)
+return self.impl.forward(query, key, value, kv_cache, attn_metadata,
+self._k_scale, self._v_scale)


 def unified_attention_fake(
@@ -258,7 +245,6 @@ def unified_attention_fake(
 key: torch.Tensor,
 value: torch.Tensor,
 kv_cache: torch.Tensor,
-attn_type: str,
 layer_name: str,
 ) -> torch.Tensor:
 return torch.empty_like(query).contiguous()
@@ -279,7 +265,6 @@ def unified_attention_with_output(
 value: torch.Tensor,
 output: torch.Tensor,
 kv_cache: torch.Tensor,
-attn_type: str,
 layer_name: str,
 ) -> None:
 forward_context: ForwardContext = get_forward_context()
@@ -292,7 +277,6 @@ def unified_attention_with_output(
 attn_metadata,
 self._k_scale,
 self._v_scale,
-attn_type=attn_type,
 output=output)


@@ -302,7 +286,6 @@ def unified_attention_with_output_fake(
 value: torch.Tensor,
 output: torch.Tensor,
 kv_cache: torch.Tensor,
-attn_type: str,
 layer_name: str,
 ) -> None:
 return
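
Since the layer now stores self.attn_type and hands it to the backend at construction, the attn_type argument also disappears from the unified attention custom ops; after this change they are invoked with only the tensors and the layer name, e.g. (illustrative call taken from the hunk above, tensor setup omitted):

out = torch.ops.vllm.unified_attention(query, key, value, kv_cache, self.layer_name)
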

@@ -71,12 +71,8 @@ class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
 def forward(
 self,
 positions: torch.Tensor,
-attn_type: AttentionType,
 ) -> torch.Tensor:
 """`input_ids' shape is expected to be [bsz x seqlen]."""
-
-assert attn_type != AttentionType.ENCODER_DECODER
-
 return super().forward(positions + self.offset)


@@ -180,7 +176,8 @@ class BartEncoderAttention(nn.Module):
 num_kv_heads=self.num_kv_heads,
 cache_config=cache_config,
 quant_config=quant_config,
-prefix=f"{prefix}.attn")
+prefix=f"{prefix}.attn",
+attn_type=AttentionType.ENCODER)

 def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
 attn_metadata: AttentionMetadata) -> torch.Tensor:
@@ -189,12 +186,7 @@ class BartEncoderAttention(nn.Module):
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

-attn_output = self.attn(q,
-k,
-v,
-kv_cache,
-attn_metadata,
-attn_type=AttentionType.ENCODER)
+attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

 output, _ = self.out_proj(attn_output)
 return output
@@ -264,7 +256,8 @@ class BartDecoderSelfAttention(nn.Module):
 num_kv_heads=self.num_kv_heads,
 cache_config=cache_config,
 quant_config=quant_config,
-prefix=f"{prefix}.attn")
+prefix=f"{prefix}.attn",
+attn_type=AttentionType.DECODER)

 def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
 attn_metadata: AttentionMetadata) -> torch.Tensor:
@@ -273,12 +266,7 @@ class BartDecoderSelfAttention(nn.Module):
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

-attn_output = self.attn(q,
-k,
-v,
-kv_cache,
-attn_metadata,
-attn_type=AttentionType.DECODER)
+attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

 output, _ = self.out_proj(attn_output)
 return output
@@ -348,7 +336,8 @@ class BartCrossAttention(nn.Module):
 num_kv_heads=self.num_kv_heads,
 cache_config=cache_config,
 quant_config=quant_config,
-prefix=f"{prefix}.attn")
+prefix=f"{prefix}.attn",
+attn_type=AttentionType.ENCODER_DECODER)

 def forward(
 self,
@@ -372,12 +361,7 @@ class BartCrossAttention(nn.Module):
 _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
 dim=-1)

-attn_output = self.attn(q,
-k,
-v,
-kv_cache,
-attn_metadata,
-attn_type=AttentionType.ENCODER_DECODER)
+attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

 output, _ = self.out_proj(attn_output)
 return output
@@ -644,10 +628,7 @@ class BartEncoder(nn.Module):
 # retrieve input_ids and inputs_embeds
 inputs_embeds = self.embed_tokens(input_ids)

-embed_pos = self.embed_positions(
-positions,
-AttentionType.ENCODER,
-)
+embed_pos = self.embed_positions(positions)
 embed_pos = embed_pos.to(inputs_embeds.device)

 hidden_states = inputs_embeds + embed_pos
@@ -734,10 +715,7 @@ class BartDecoder(nn.Module):
 inputs_embeds = self.embed_tokens(decoder_input_ids)

 # embed positions
-embed_pos = self.embed_positions(
-decoder_positions,
-AttentionType.DECODER,
-)
+embed_pos = self.embed_positions(decoder_positions)
 embed_pos = embed_pos.to(inputs_embeds.device)

 hidden_states = inputs_embeds + embed_pos

@@ -238,7 +238,8 @@ class BertSelfAttention(nn.Module):
 num_kv_heads=self.num_kv_heads,
 cache_config=cache_config,
 quant_config=quant_config,
-prefix=f"{prefix}.attn")
+prefix=f"{prefix}.attn",
+attn_type=AttentionType.ENCODER_ONLY)

 def forward(
 self,
@@ -248,12 +249,7 @@ class BertSelfAttention(nn.Module):
 ) -> torch.Tensor:
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-output = self.attn(q,
-k,
-v,
-kv_cache,
-attn_metadata,
-attn_type=AttentionType.ENCODER_ONLY)
+output = self.attn(q, k, v, kv_cache, attn_metadata)
 return output

|
|||||||
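For reference, the roles used across these hunks are AttentionType constants: roughly, DECODER is causal self-attention with a KV cache, ENCODER_ONLY is bidirectional self-attention in encoder-only models such as BERT, and ENCODER_DECODER is cross-attention from decoder to encoder states. A small sketch of assigning a role per module; the role_for helper and its keys are hypothetical, only the constants come from vllm.

from vllm.attention import AttentionType


def role_for(module: str) -> str:
    # Hypothetical lookup; mirrors how the hunks above assign roles.
    roles = {
        "bert_self_attn": AttentionType.ENCODER_ONLY,        # bidirectional
        "bart_decoder_self_attn": AttentionType.DECODER,     # causal, KV-cached
        "bart_cross_attn": AttentionType.ENCODER_DECODER,    # decoder -> encoder
    }
    return roles[module]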
@@ -770,6 +770,7 @@ class MllamaTextCrossAttention(nn.Module):
             self.scaling,
             self.num_local_key_value_heads,
             prefix=f"{prefix}.attn",
+            attn_type=AttentionType.ENCODER_DECODER,
         )

     def forward(
@@ -805,13 +806,9 @@ class MllamaTextCrossAttention(nn.Module):
                                              kv_range_for_decode,
                                              attn_metadata)
         else:
-            output = self.attn(q.view(-1,
-                                      self.num_local_heads * self.head_dim),
-                               k,
-                               v,
-                               kv_cache,
-                               attn_metadata,
-                               attn_type=AttentionType.ENCODER_DECODER)
+            output = self.attn(
+                q.view(-1, self.num_local_heads * self.head_dim), k, v,
+                kv_cache, attn_metadata)
         out, _ = self.o_proj(output)
         return out

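The Mllama cross-attention call above flattens the per-head query layout before handing it to attn(). A quick shape check of that reshape, with made-up sizes (illustrative only):

import torch

num_tokens, num_local_heads, head_dim = 4, 8, 64
q = torch.randn(num_tokens, num_local_heads, head_dim)
# Same reshape as q.view(-1, self.num_local_heads * self.head_dim) above.
q_flat = q.view(-1, num_local_heads * head_dim)
assert q_flat.shape == (num_tokens, num_local_heads * head_dim)  # (4, 512)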
@@ -107,7 +107,8 @@ class Qwen2Attention(nn.Module):
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  rope_scaling: Optional[Tuple] = None,
-                 prefix: str = "") -> None:
+                 prefix: str = "",
+                 attn_type: str = AttentionType.DECODER) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -160,7 +161,8 @@ class Qwen2Attention(nn.Module):
                               num_kv_heads=self.num_kv_heads,
                               cache_config=cache_config,
                               quant_config=quant_config,
-                              prefix=f"{prefix}.attn")
+                              prefix=f"{prefix}.attn",
+                              attn_type=attn_type)

     def forward(
         self,
@@ -168,17 +170,11 @@ class Qwen2Attention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        attn_type: str = AttentionType.DECODER,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q,
-                                k,
-                                v,
-                                kv_cache,
-                                attn_metadata,
-                                attn_type=attn_type)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output

@@ -197,6 +193,16 @@ class Qwen2DecoderLayer(nn.Module):
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
+
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -207,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
             quant_config=quant_config,
             rope_scaling=rope_scaling,
             prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
         )
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
@@ -220,15 +227,6 @@ class Qwen2DecoderLayer(nn.Module):
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)

-        # By default, Qwen2 uses causal attention as it is a decoder-only model.
-        # You can override the HF config with `is_causal=False` to enable
-        # bidirectional attention, which is used in some embedding models
-        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
-        if getattr(config, "is_causal", True):
-            self._attn_type = AttentionType.DECODER
-        else:
-            self._attn_type = AttentionType.ENCODER_ONLY
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -249,7 +247,6 @@ class Qwen2DecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
-            attn_type=self._attn_type,
         )

         # Fully Connected
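The Qwen2 change moves the is_causal decision from the forward path into the decoder layer's constructor. A standalone sketch of that selection logic, using a stand-in config object instead of a real HF Qwen2Config (select_attn_type is illustrative, not a vllm helper):

from types import SimpleNamespace

from vllm.attention import AttentionType


def select_attn_type(config) -> str:
    # Mirrors the new Qwen2DecoderLayer logic: causal unless is_causal=False.
    if getattr(config, "is_causal", True):
        return AttentionType.DECODER
    return AttentionType.ENCODER_ONLY


causal_cfg = SimpleNamespace()                    # no is_causal -> decoder-style
embedding_cfg = SimpleNamespace(is_causal=False)  # e.g. gte-Qwen2-style embedding use

assert select_attn_type(causal_cfg) == AttentionType.DECODER
assert select_attn_type(embedding_cfg) == AttentionType.ENCODER_ONLY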
@@ -89,6 +89,7 @@ class FlashAttentionImpl(AttentionImpl):
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
@@ -119,6 +120,12 @@ class FlashAttentionImpl(AttentionImpl):
                 f"Head size {head_size} is not supported by FlashAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")

+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "FlashAttentionImpl")
+
     def forward(
         self,
         query: torch.Tensor,
@@ -128,7 +135,6 @@ class FlashAttentionImpl(AttentionImpl):
         attn_metadata: FlashAttentionMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
-        attn_type: AttentionType = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -142,12 +148,6 @@ class FlashAttentionImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashAttentionImpl")
-
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
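With the attn_type check moved into FlashAttentionImpl.__init__, an unsupported role now fails when the backend is constructed rather than on the first forward() call. A minimal sketch of that fail-fast behavior; build_flash_attn_impl is a stand-in, not the real constructor.

from vllm.attention import AttentionType


def build_flash_attn_impl(attn_type: str) -> None:
    # Stand-in for the validation added to FlashAttentionImpl.__init__.
    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "FlashAttentionImpl")


build_flash_attn_impl(AttentionType.DECODER)  # ok
try:
    build_flash_attn_impl(AttentionType.ENCODER_DECODER)
except NotImplementedError:
    pass  # surfaces at construction time, before any forward pass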