mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 18:35:41 +08:00
[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)
This commit is contained in:
parent
d522034c85
commit
a78dd3303e
@ -7,12 +7,18 @@ from typing import List, Optional, Tuple
|
|||||||
import pytest
|
import pytest
|
||||||
from transformers import AutoModelForSeq2SeqLM
|
from transformers import AutoModelForSeq2SeqLM
|
||||||
|
|
||||||
|
from vllm.attention.selector import (_Backend,
|
||||||
|
global_force_attn_backend_context_manager)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import SampleLogprobs
|
from vllm.sequence import SampleLogprobs
|
||||||
|
|
||||||
from ..conftest import DecoderPromptType
|
from ..conftest import DecoderPromptType
|
||||||
from ..models.utils import check_logprobs_close
|
from ..models.utils import check_logprobs_close
|
||||||
|
|
||||||
|
LIST_ENC_DEC_SUPPORTED_BACKENDS = [
|
||||||
|
_Backend.XFORMERS, _Backend.FLASH_ATTN, None
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def vllm_to_hf_output(
|
def vllm_to_hf_output(
|
||||||
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
|
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
|
||||||
@ -29,7 +35,8 @@ def vllm_to_hf_output(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
|
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
|
||||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
@pytest.mark.parametrize("dtype", ["float"])
|
||||||
|
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||||
@pytest.mark.parametrize("max_tokens", [128])
|
@pytest.mark.parametrize("max_tokens", [128])
|
||||||
@pytest.mark.parametrize("num_logprobs", [5])
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
|
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
|
||||||
@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
|
|||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
decoder_prompt_type: DecoderPromptType,
|
decoder_prompt_type: DecoderPromptType,
|
||||||
enforce_eager: bool,
|
enforce_eager: bool,
|
||||||
|
attn_backend: _Backend,
|
||||||
) -> None:
|
) -> None:
|
||||||
'''
|
'''
|
||||||
End-to-End (E2E) test for the encoder-decoder framework.
|
End-to-End (E2E) test for the encoder-decoder framework.
|
||||||
@ -56,43 +64,49 @@ def test_encoder_decoder_e2e(
|
|||||||
implementations to ensure that both implementations produce consistent
|
implementations to ensure that both implementations produce consistent
|
||||||
and correct results.
|
and correct results.
|
||||||
'''
|
'''
|
||||||
test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]
|
with global_force_attn_backend_context_manager(attn_backend):
|
||||||
|
if attn_backend == _Backend.FLASH_ATTN:
|
||||||
|
# Flash Attention works only with bfloat16 data-type
|
||||||
|
dtype = 'bfloat16'
|
||||||
|
test_case_prompts = example_encoder_decoder_prompts[
|
||||||
|
decoder_prompt_type]
|
||||||
|
|
||||||
# Configuration settings for HF baseline
|
# Configuration settings for HF baseline
|
||||||
hf_kwargs = {
|
hf_kwargs = {
|
||||||
"top_k": None,
|
"top_k": None,
|
||||||
"num_beams": 1,
|
"num_beams": 1,
|
||||||
"repetition_penalty": 1.0,
|
"repetition_penalty": 1.0,
|
||||||
"top_p": 1.0,
|
"top_p": 1.0,
|
||||||
"length_penalty": 1.0,
|
"length_penalty": 1.0,
|
||||||
"early_stopping": False,
|
"early_stopping": False,
|
||||||
"no_repeat_ngram_size": None,
|
"no_repeat_ngram_size": None,
|
||||||
"min_length": 0
|
"min_length": 0
|
||||||
}
|
}
|
||||||
|
|
||||||
with hf_runner(model, dtype=dtype,
|
with hf_runner(model, dtype=dtype,
|
||||||
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
|
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
|
||||||
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
|
hf_outputs = (
|
||||||
test_case_prompts,
|
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
|
||||||
max_tokens,
|
test_case_prompts,
|
||||||
num_logprobs,
|
max_tokens,
|
||||||
**hf_kwargs,
|
num_logprobs,
|
||||||
))
|
**hf_kwargs,
|
||||||
with vllm_runner(model, dtype=dtype,
|
))
|
||||||
enforce_eager=enforce_eager) as vllm_model:
|
with vllm_runner(model, dtype=dtype,
|
||||||
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
|
enforce_eager=enforce_eager) as vllm_model:
|
||||||
test_case_prompts, max_tokens, num_logprobs)
|
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
|
||||||
|
test_case_prompts, max_tokens, num_logprobs)
|
||||||
|
|
||||||
hf_skip_tokens = (1
|
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
|
||||||
if decoder_prompt_type == DecoderPromptType.NONE else 0)
|
else 0)
|
||||||
|
|
||||||
check_logprobs_close(
|
check_logprobs_close(
|
||||||
outputs_0_lst=hf_outputs,
|
outputs_0_lst=hf_outputs,
|
||||||
outputs_1_lst=[
|
outputs_1_lst=[
|
||||||
vllm_to_hf_output(vllm_output, decoder_prompt_type)
|
vllm_to_hf_output(vllm_output, decoder_prompt_type)
|
||||||
for vllm_output in vllm_outputs
|
for vllm_output in vllm_outputs
|
||||||
],
|
],
|
||||||
name_0="hf",
|
name_0="hf",
|
||||||
name_1="vllm",
|
name_1="vllm",
|
||||||
num_outputs_0_skip_tokens=hf_skip_tokens,
|
num_outputs_0_skip_tokens=hf_skip_tokens,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -16,13 +16,13 @@ from tests.kernels.utils import *
|
|||||||
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
|
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
|
||||||
AttentionType)
|
AttentionType)
|
||||||
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
||||||
from vllm.attention.selector import (_Backend,
|
from vllm.attention.selector import (_Backend, get_attn_backend,
|
||||||
global_force_attn_backend_context_manager)
|
global_force_attn_backend_context_manager)
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
# List of supported backends for encoder/decoder models
|
# List of supported backends for encoder/decoder models
|
||||||
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
|
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
|
||||||
|
|
||||||
HEAD_SIZES = [64, 256]
|
HEAD_SIZES = [64, 256]
|
||||||
|
|
||||||
NUM_HEADS = [1, 16]
|
NUM_HEADS = [1, 16]
|
||||||
@ -145,7 +145,8 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
|
|||||||
test_pt.num_heads,
|
test_pt.num_heads,
|
||||||
test_pt.head_size,
|
test_pt.head_size,
|
||||||
test_pt.block_size,
|
test_pt.block_size,
|
||||||
device=CUDA_DEVICE)
|
device=CUDA_DEVICE,
|
||||||
|
backend=test_pt.backend_name)
|
||||||
return TestResources(scale, attn_backend, attn, kv_cache)
|
return TestResources(scale, attn_backend, attn, kv_cache)
|
||||||
|
|
||||||
|
|
||||||
@ -592,6 +593,7 @@ def _run_encoder_attention_test(
|
|||||||
attn: Attention,
|
attn: Attention,
|
||||||
encoder_test_params: PhaseTestParameters,
|
encoder_test_params: PhaseTestParameters,
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
|
test_pt: TestPoint,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
'''
|
'''
|
||||||
Run encoder attention.
|
Run encoder attention.
|
||||||
@ -610,6 +612,8 @@ def _run_encoder_attention_test(
|
|||||||
(number_of_tokens x num_heads x head_size)
|
(number_of_tokens x num_heads x head_size)
|
||||||
query/key/value fields
|
query/key/value fields
|
||||||
* attn_metadata: attention metadata for encoder/decoder-self attention
|
* attn_metadata: attention metadata for encoder/decoder-self attention
|
||||||
|
* test_pt: The TestPoint object containing test details like number of
|
||||||
|
model heads, head size, name of the backend being used etc.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
* Attention.forward() applied to packed {query,key,value} and
|
* Attention.forward() applied to packed {query,key,value} and
|
||||||
@ -619,20 +623,31 @@ def _run_encoder_attention_test(
|
|||||||
attn_type = AttentionType.ENCODER
|
attn_type = AttentionType.ENCODER
|
||||||
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
|
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
|
||||||
assert packed_qkv is not None
|
assert packed_qkv is not None
|
||||||
return attn.forward(packed_qkv.query,
|
with set_forward_context(attn_metadata):
|
||||||
packed_qkv.key,
|
# In the test setup the shape of the query is
|
||||||
packed_qkv.value,
|
# [batch_size, seq_len, num_heads, head_size]. However
|
||||||
torch.tensor([],
|
# the attention backend expects the shape to be
|
||||||
dtype=torch.float32,
|
# [num_tokens, hidden_size]. Hence reshape the query before
|
||||||
device=packed_qkv.query.device),
|
# invoking the forward method.
|
||||||
attn_metadata,
|
# TODO - Update the way we construct the query so that it
|
||||||
attn_type=attn_type)
|
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||||
|
reshaped_query = packed_qkv.query.view(
|
||||||
|
-1, test_pt.num_heads * test_pt.head_size)
|
||||||
|
return attn.forward(reshaped_query,
|
||||||
|
packed_qkv.key,
|
||||||
|
packed_qkv.value,
|
||||||
|
torch.tensor([],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=packed_qkv.query.device),
|
||||||
|
attn_metadata,
|
||||||
|
attn_type=attn_type)
|
||||||
|
|
||||||
|
|
||||||
def _run_decoder_self_attention_test(
|
def _run_decoder_self_attention_test(
|
||||||
test_rsrcs: TestResources,
|
test_rsrcs: TestResources,
|
||||||
decoder_test_params: PhaseTestParameters,
|
decoder_test_params: PhaseTestParameters,
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
|
test_pt: TestPoint,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
'''
|
'''
|
||||||
Run decoder self-attention test.
|
Run decoder self-attention test.
|
||||||
@ -650,6 +665,8 @@ def _run_decoder_self_attention_test(
|
|||||||
query/key/value fields
|
query/key/value fields
|
||||||
* attn_metadata: attention metadata for decoder-self attention
|
* attn_metadata: attention metadata for decoder-self attention
|
||||||
(contains KV cache memory-mapping)
|
(contains KV cache memory-mapping)
|
||||||
|
* test_pt: The TestPoint object containing test details like number of
|
||||||
|
model heads, head size, name of the backend being used etc.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
* Attention.forward() applied to packed_{query,key,value}, kv_cache
|
* Attention.forward() applied to packed_{query,key,value}, kv_cache
|
||||||
@ -660,12 +677,22 @@ def _run_decoder_self_attention_test(
|
|||||||
kv_cache = test_rsrcs.kv_cache
|
kv_cache = test_rsrcs.kv_cache
|
||||||
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
|
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
|
||||||
assert packed_qkv is not None
|
assert packed_qkv is not None
|
||||||
return attn.forward(packed_qkv.query,
|
with set_forward_context(attn_metadata):
|
||||||
packed_qkv.key,
|
# In the test setup the shape of the query is
|
||||||
packed_qkv.value,
|
# [batch_size, seq_len, num_heads, head_size]. However
|
||||||
kv_cache,
|
# the attention backend expects the shape to be
|
||||||
attn_metadata,
|
# [num_tokens, hidden_size]. Hence reshape the query before
|
||||||
attn_type=attn_type)
|
# invoking the forward method.
|
||||||
|
# TODO - Update the way we construct the query so that it
|
||||||
|
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||||
|
reshaped_query = packed_qkv.query.view(
|
||||||
|
-1, test_pt.num_heads * test_pt.head_size)
|
||||||
|
return attn.forward(reshaped_query,
|
||||||
|
packed_qkv.key,
|
||||||
|
packed_qkv.value,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
attn_type=attn_type)
|
||||||
|
|
||||||
|
|
||||||
def _run_encoder_decoder_cross_attention_test(
|
def _run_encoder_decoder_cross_attention_test(
|
||||||
@ -673,6 +700,7 @@ def _run_encoder_decoder_cross_attention_test(
|
|||||||
decoder_test_params: PhaseTestParameters,
|
decoder_test_params: PhaseTestParameters,
|
||||||
cross_test_params: Optional[PhaseTestParameters],
|
cross_test_params: Optional[PhaseTestParameters],
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
|
test_pt: TestPoint,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
'''
|
'''
|
||||||
Run encoder/decoder cross-attention test.
|
Run encoder/decoder cross-attention test.
|
||||||
@ -701,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test(
|
|||||||
(number_of_tokens x num_heads x head_size)
|
(number_of_tokens x num_heads x head_size)
|
||||||
key/value fields
|
key/value fields
|
||||||
* attn_metadata: attention metadata for encoder/decoder-self attention
|
* attn_metadata: attention metadata for encoder/decoder-self attention
|
||||||
|
* test_pt: The TestPoint object containing test details like number of
|
||||||
|
model heads, head size, name of the backend being used etc.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
* Attention.forward() applied to packed_{query,key,value}, kv_cache
|
* Attention.forward() applied to packed_{query,key,value}, kv_cache
|
||||||
@ -718,12 +748,37 @@ def _run_encoder_decoder_cross_attention_test(
|
|||||||
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
|
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
|
||||||
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
|
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
|
||||||
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
|
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
|
||||||
return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query,
|
with set_forward_context(attn_metadata):
|
||||||
key,
|
# In the test setup the shape of the query is
|
||||||
value,
|
# [batch_size, seq_len, num_heads, head_size]. However
|
||||||
kv_cache,
|
# the attention backend expects the shape to be
|
||||||
attn_metadata,
|
# [num_tokens, hidden_size]. Hence reshape the query before
|
||||||
attn_type=attn_type)
|
# invoking the forward method.
|
||||||
|
# TODO - Update the way we construct the query so that it
|
||||||
|
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||||
|
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
|
||||||
|
-1, test_pt.num_heads * test_pt.head_size)
|
||||||
|
return attn.forward(reshaped_query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
attn_type=attn_type)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def set_reset_environment(attn_backend):
|
||||||
|
# Set the default torch datatype to bfloat16 to enable
|
||||||
|
# testing of the Flash Attention backend. Also clear the
|
||||||
|
# cached value of the backend.
|
||||||
|
default_dtype = torch.get_default_dtype()
|
||||||
|
if attn_backend.name == 'FLASH_ATTN':
|
||||||
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
|
get_attn_backend.cache_clear()
|
||||||
|
yield
|
||||||
|
# Reset the torch datatype to what it was before the test
|
||||||
|
# so as not to impact the remaining tests.
|
||||||
|
torch.set_default_dtype(default_dtype)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||||
@ -773,10 +828,8 @@ def test_encoder_only(
|
|||||||
* max_dec_seq_len: max length of decoder input sequences
|
* max_dec_seq_len: max length of decoder input sequences
|
||||||
* max_enc_seq_len: max length of encoder input sequences
|
* max_enc_seq_len: max length of encoder input sequences
|
||||||
'''
|
'''
|
||||||
|
|
||||||
# Force Attention wrapper backend
|
# Force Attention wrapper backend
|
||||||
with global_force_attn_backend_context_manager(attn_backend):
|
with global_force_attn_backend_context_manager(attn_backend):
|
||||||
|
|
||||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||||
# to be more than necessary, since exceeding the kv cache size
|
# to be more than necessary, since exceeding the kv cache size
|
||||||
# is not part of this test
|
# is not part of this test
|
||||||
@ -807,10 +860,14 @@ def test_encoder_only(
|
|||||||
# PREFILL: encoder attention
|
# PREFILL: encoder attention
|
||||||
|
|
||||||
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
|
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
|
||||||
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
|
test_rsrcs.attn,
|
||||||
|
enc_test_params,
|
||||||
|
prephase_attn_metadata,
|
||||||
|
test_pt=test_pt))
|
||||||
|
|
||||||
# - Is encoder attention result correct?
|
# - Is encoder attention result correct?
|
||||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
||||||
|
attn_backend.name)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||||
@ -892,10 +949,8 @@ def test_e2e_enc_dec_attn(
|
|||||||
* max_dec_seq_len: max length of decoder input sequences
|
* max_dec_seq_len: max length of decoder input sequences
|
||||||
* max_enc_seq_len: max length of encoder input sequences
|
* max_enc_seq_len: max length of encoder input sequences
|
||||||
'''
|
'''
|
||||||
|
|
||||||
# Force Attention wrapper backend
|
# Force Attention wrapper backend
|
||||||
with global_force_attn_backend_context_manager(attn_backend):
|
with global_force_attn_backend_context_manager(attn_backend):
|
||||||
|
|
||||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||||
# to be more than necessary, since exceeding the kv cache size
|
# to be more than necessary, since exceeding the kv cache size
|
||||||
# is not part of this test
|
# is not part of this test
|
||||||
@ -955,29 +1010,39 @@ def test_e2e_enc_dec_attn(
|
|||||||
|
|
||||||
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||||
enc_test_params,
|
enc_test_params,
|
||||||
prephase_attn_metadata)
|
prephase_attn_metadata,
|
||||||
|
test_pt=test_pt)
|
||||||
|
|
||||||
# - Is encoder attention result correct?
|
# - Is encoder attention result correct?
|
||||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
||||||
|
attn_backend.name)
|
||||||
|
|
||||||
# PREFILL: decoder self-attention test
|
# PREFILL: decoder self-attention test
|
||||||
|
|
||||||
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||||
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
|
test_rsrcs,
|
||||||
|
prephase_dec_test_params,
|
||||||
|
prephase_attn_metadata,
|
||||||
|
test_pt=test_pt)
|
||||||
|
|
||||||
# - Is prefill decoder self-attention correct?
|
# - Is prefill decoder self-attention correct?
|
||||||
assert_actual_matches_ideal(prephase_dec_test_params,
|
assert_actual_matches_ideal(prephase_dec_test_params,
|
||||||
prephase_dec_pckd_act_out)
|
prephase_dec_pckd_act_out,
|
||||||
|
attn_backend.name)
|
||||||
|
|
||||||
# PREFILL: encoder/decoder cross-attention test
|
# PREFILL: encoder/decoder cross-attention test
|
||||||
|
|
||||||
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||||
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
|
test_rsrcs,
|
||||||
prephase_attn_metadata)
|
prephase_dec_test_params,
|
||||||
|
prephase_cross_test_params,
|
||||||
|
prephase_attn_metadata,
|
||||||
|
test_pt=test_pt)
|
||||||
|
|
||||||
# - Is prefill encoder/decoder cross-attention correct?
|
# - Is prefill encoder/decoder cross-attention correct?
|
||||||
assert_actual_matches_ideal(prephase_cross_test_params,
|
assert_actual_matches_ideal(prephase_cross_test_params,
|
||||||
prephase_cross_pckd_act_out)
|
prephase_cross_pckd_act_out,
|
||||||
|
attn_backend.name)
|
||||||
|
|
||||||
# DECODE: build decode-phase attention metadata
|
# DECODE: build decode-phase attention metadata
|
||||||
|
|
||||||
@ -993,17 +1058,26 @@ def test_e2e_enc_dec_attn(
|
|||||||
# DECODE: decoder self-attention test
|
# DECODE: decoder self-attention test
|
||||||
|
|
||||||
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||||
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
|
test_rsrcs,
|
||||||
|
decphase_dec_test_params,
|
||||||
|
decphase_attn_metadata,
|
||||||
|
test_pt=test_pt)
|
||||||
|
|
||||||
# - Is decode-phase decoder self-attention correct?
|
# - Is decode-phase decoder self-attention correct?
|
||||||
assert_actual_matches_ideal(decphase_dec_test_params,
|
assert_actual_matches_ideal(decphase_dec_test_params,
|
||||||
decphase_dec_pckd_act_out)
|
decphase_dec_pckd_act_out,
|
||||||
|
attn_backend.name)
|
||||||
|
|
||||||
# DECODE: encoder/decoder cross-attention test
|
# DECODE: encoder/decoder cross-attention test
|
||||||
|
|
||||||
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||||
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
|
test_rsrcs,
|
||||||
|
decphase_dec_test_params,
|
||||||
|
None,
|
||||||
|
decphase_attn_metadata,
|
||||||
|
test_pt=test_pt)
|
||||||
|
|
||||||
# - Is decode-phase encoder/decoder cross-attention correct?
|
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||||
assert_actual_matches_ideal(decphase_cross_test_params,
|
assert_actual_matches_ideal(decphase_cross_test_params,
|
||||||
decphase_cross_pckd_act_out)
|
decphase_cross_pckd_act_out,
|
||||||
|
attn_backend.name)
|
||||||
|
|||||||
@ -13,8 +13,8 @@ from torch._prims_common import TensorLikeType
|
|||||||
|
|
||||||
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
|
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
|
||||||
make_tensor_with_pad)
|
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
|
||||||
|
|
||||||
# For now, disable "test_aot_dispatch_dynamic" since there are some
|
# For now, disable "test_aot_dispatch_dynamic" since there are some
|
||||||
# bugs related to this test in PyTorch 2.4.
|
# bugs related to this test in PyTorch 2.4.
|
||||||
@ -525,17 +525,22 @@ def make_backend(backend_name: str) -> AttentionBackend:
|
|||||||
if backend_name == STR_XFORMERS_ATTN_VAL:
|
if backend_name == STR_XFORMERS_ATTN_VAL:
|
||||||
# NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
|
# NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
|
||||||
from vllm.attention.backends.xformers import XFormersBackend
|
from vllm.attention.backends.xformers import XFormersBackend
|
||||||
|
|
||||||
return XFormersBackend()
|
return XFormersBackend()
|
||||||
|
elif backend_name == STR_FLASH_ATTN_VAL:
|
||||||
|
from vllm.attention.backends.flash_attn import FlashAttentionBackend
|
||||||
|
return FlashAttentionBackend()
|
||||||
|
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
f"Unrecognized backend_name {backend_name} for unit test")
|
f"Unrecognized backend_name {backend_name} for unit test")
|
||||||
|
|
||||||
|
|
||||||
def _make_metadata_tensors(
|
def _make_metadata_tensors(
|
||||||
seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
|
seq_lens: Optional[List[int]],
|
||||||
encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
|
context_lens: Optional[List[int]],
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
|
encoder_seq_lens: Optional[List[int]],
|
||||||
torch.Tensor, Optional[int]]:
|
device: Union[torch.device, str],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
|
||||||
|
torch.Tensor, torch.Tensor, Optional[int]]:
|
||||||
'''
|
'''
|
||||||
Build scalar & tensor values required to build attention metadata structure.
|
Build scalar & tensor values required to build attention metadata structure.
|
||||||
|
|
||||||
@ -553,6 +558,8 @@ def _make_metadata_tensors(
|
|||||||
* max_context_len: max(context_lens)
|
* max_context_len: max(context_lens)
|
||||||
* max_seq_len: max(seq_lens)
|
* max_seq_len: max(seq_lens)
|
||||||
* seq_start_loc: start idx of each sequence
|
* seq_start_loc: start idx of each sequence
|
||||||
|
* encoder_seq_lens_tensor: encoder seq_lens list, as tensor
|
||||||
|
* encoder_seq_start_loc: start idx of each encoder sequence
|
||||||
* max_encoder_seq_len: max(encoder_seq_lens)
|
* max_encoder_seq_len: max(encoder_seq_lens)
|
||||||
'''
|
'''
|
||||||
seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
|
seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
|
||||||
@ -566,8 +573,26 @@ def _make_metadata_tensors(
|
|||||||
|
|
||||||
seq_start_loc = None
|
seq_start_loc = None
|
||||||
|
|
||||||
|
if seq_lens_tensor is not None:
|
||||||
|
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=seq_lens_tensor.device)
|
||||||
|
torch.cumsum(seq_lens_tensor,
|
||||||
|
dim=0,
|
||||||
|
dtype=seq_start_loc.dtype,
|
||||||
|
out=seq_start_loc[1:])
|
||||||
|
|
||||||
|
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=encoder_seq_lens_tensor.device)
|
||||||
|
torch.cumsum(encoder_seq_lens_tensor,
|
||||||
|
dim=0,
|
||||||
|
dtype=encoder_seq_start_loc.dtype,
|
||||||
|
out=encoder_seq_start_loc[1:])
|
||||||
|
|
||||||
return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
|
return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
|
||||||
seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)
|
seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc,
|
||||||
|
max_encoder_seq_len)
|
||||||
|
|
||||||
|
|
||||||
def make_kv_cache(num_blocks: int,
|
def make_kv_cache(num_blocks: int,
|
||||||
@ -575,6 +600,7 @@ def make_kv_cache(num_blocks: int,
|
|||||||
head_size: int,
|
head_size: int,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
device: Union[torch.device, str],
|
device: Union[torch.device, str],
|
||||||
|
backend: str,
|
||||||
default_val: float = 0.0) -> torch.Tensor:
|
default_val: float = 0.0) -> torch.Tensor:
|
||||||
'''
|
'''
|
||||||
Create a fake KV cache.
|
Create a fake KV cache.
|
||||||
@ -591,10 +617,20 @@ def make_kv_cache(num_blocks: int,
|
|||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
|
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
|
||||||
|
* for backend 'XFORMERS'
|
||||||
|
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
|
||||||
|
* for backend 'FLASH_ATTN'
|
||||||
'''
|
'''
|
||||||
|
if backend == 'XFORMERS':
|
||||||
kv_cache = torch.rand(
|
kv_cache = torch.rand(
|
||||||
(2, num_blocks, block_size * num_heads * head_size)).to(device)
|
(2, num_blocks, block_size * num_heads * head_size)).to(device)
|
||||||
|
elif backend == 'FLASH_ATTN':
|
||||||
|
kv_cache = torch.rand(
|
||||||
|
(2, num_blocks, block_size, num_heads, head_size)).to(device)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
|
||||||
|
f"'FLASH_ATTN'.")
|
||||||
if default_val is not None:
|
if default_val is not None:
|
||||||
kv_cache[:, :, :] = default_val
|
kv_cache[:, :, :] = default_val
|
||||||
return kv_cache
|
return kv_cache
|
||||||
@ -858,8 +894,9 @@ def make_test_metadata(
|
|||||||
context_lens_tensor,
|
context_lens_tensor,
|
||||||
_,
|
_,
|
||||||
_,
|
_,
|
||||||
_,
|
seq_start_loc,
|
||||||
encoder_seq_lens_tensor,
|
encoder_seq_lens_tensor,
|
||||||
|
encoder_seq_start_loc,
|
||||||
max_encoder_seq_len,
|
max_encoder_seq_len,
|
||||||
) = _make_metadata_tensors(seq_lens,
|
) = _make_metadata_tensors(seq_lens,
|
||||||
context_lens,
|
context_lens,
|
||||||
@ -874,6 +911,7 @@ def make_test_metadata(
|
|||||||
num_decode_tokens=num_decode_tokens,
|
num_decode_tokens=num_decode_tokens,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
|
seq_start_loc=seq_start_loc,
|
||||||
max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
|
max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
|
||||||
max_decode_seq_len=0,
|
max_decode_seq_len=0,
|
||||||
context_lens_tensor=context_lens_tensor,
|
context_lens_tensor=context_lens_tensor,
|
||||||
@ -882,6 +920,7 @@ def make_test_metadata(
|
|||||||
num_encoder_tokens=num_encoder_tokens,
|
num_encoder_tokens=num_encoder_tokens,
|
||||||
encoder_seq_lens=encoder_seq_lens,
|
encoder_seq_lens=encoder_seq_lens,
|
||||||
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
||||||
|
encoder_seq_start_loc=encoder_seq_start_loc,
|
||||||
max_encoder_seq_len=max_encoder_seq_len,
|
max_encoder_seq_len=max_encoder_seq_len,
|
||||||
cross_slot_mapping=(None if cross_kv_mmap is None else
|
cross_slot_mapping=(None if cross_kv_mmap is None else
|
||||||
cross_kv_mmap.slot_mapping),
|
cross_kv_mmap.slot_mapping),
|
||||||
@ -904,8 +943,9 @@ def make_test_metadata(
|
|||||||
context_lens_tensor,
|
context_lens_tensor,
|
||||||
_,
|
_,
|
||||||
_,
|
_,
|
||||||
_,
|
seq_start_loc,
|
||||||
encoder_seq_lens_tensor,
|
encoder_seq_lens_tensor,
|
||||||
|
encoder_seq_start_loc,
|
||||||
max_encoder_seq_len,
|
max_encoder_seq_len,
|
||||||
) = _make_metadata_tensors(seq_lens,
|
) = _make_metadata_tensors(seq_lens,
|
||||||
context_lens,
|
context_lens,
|
||||||
@ -920,14 +960,17 @@ def make_test_metadata(
|
|||||||
num_decode_tokens=num_decode_tokens,
|
num_decode_tokens=num_decode_tokens,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
|
seq_start_loc=seq_start_loc,
|
||||||
max_prefill_seq_len=0,
|
max_prefill_seq_len=0,
|
||||||
max_decode_seq_len=max(seq_lens),
|
max_decode_seq_len=max(seq_lens),
|
||||||
|
max_decode_query_len=1,
|
||||||
context_lens_tensor=context_lens_tensor,
|
context_lens_tensor=context_lens_tensor,
|
||||||
block_tables=kv_mmap.block_tables,
|
block_tables=kv_mmap.block_tables,
|
||||||
use_cuda_graph=False,
|
use_cuda_graph=False,
|
||||||
num_encoder_tokens=num_encoder_tokens,
|
num_encoder_tokens=num_encoder_tokens,
|
||||||
encoder_seq_lens=encoder_seq_lens,
|
encoder_seq_lens=encoder_seq_lens,
|
||||||
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
||||||
|
encoder_seq_start_loc=encoder_seq_start_loc,
|
||||||
max_encoder_seq_len=max_encoder_seq_len,
|
max_encoder_seq_len=max_encoder_seq_len,
|
||||||
cross_slot_mapping=(None if cross_kv_mmap is None else
|
cross_slot_mapping=(None if cross_kv_mmap is None else
|
||||||
cross_kv_mmap.slot_mapping),
|
cross_kv_mmap.slot_mapping),
|
||||||
@ -936,7 +979,8 @@ def make_test_metadata(
|
|||||||
|
|
||||||
|
|
||||||
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
|
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
|
||||||
output_under_test: torch.Tensor) -> None:
|
output_under_test: torch.Tensor,
|
||||||
|
backend: str) -> None:
|
||||||
'''
|
'''
|
||||||
Assert that observed output matches the ideal output
|
Assert that observed output matches the ideal output
|
||||||
contained in the test parameters data structure.
|
contained in the test parameters data structure.
|
||||||
@ -947,8 +991,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
|
|||||||
* output_under_test: actually observed output value
|
* output_under_test: actually observed output value
|
||||||
'''
|
'''
|
||||||
ideal_output = test_params.packed_qkvo.ideal_output
|
ideal_output = test_params.packed_qkvo.ideal_output
|
||||||
torch.testing.assert_close(ideal_output,
|
if backend == 'XFORMERS':
|
||||||
output_under_test.view_as(ideal_output))
|
torch.testing.assert_close(ideal_output,
|
||||||
|
output_under_test.view_as(ideal_output))
|
||||||
|
|
||||||
|
elif backend == 'FLASH_ATTN':
|
||||||
|
# For FlashAttention override the accuracy thresholds to non default
|
||||||
|
# values since we notice a higher difference between the ideal and
|
||||||
|
# actual output.
|
||||||
|
torch.testing.assert_close(ideal_output,
|
||||||
|
output_under_test.view_as(ideal_output),
|
||||||
|
atol=0.01,
|
||||||
|
rtol=0.016)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
|
||||||
|
f"'FLASH_ATTN'.")
|
||||||
|
|
||||||
|
|
||||||
# Copied/modified from torch._refs.__init__.py
|
# Copied/modified from torch._refs.__init__.py
|
||||||
|
|||||||
@ -85,7 +85,7 @@ def run_test(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("dtype", ["float"])
|
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
|
||||||
@pytest.mark.parametrize("max_tokens", [64])
|
@pytest.mark.parametrize("max_tokens", [64])
|
||||||
@pytest.mark.parametrize("num_logprobs", [5])
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
|
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
|
||||||
|
|||||||
@ -10,10 +10,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
AttentionType)
|
AttentionType)
|
||||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
from vllm.attention.backends.utils import (
|
||||||
compute_slot_mapping,
|
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
|
||||||
compute_slot_mapping_start_idx,
|
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
||||||
is_block_tables_empty)
|
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
||||||
|
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.multimodal import MultiModalPlaceholderMap
|
from vllm.multimodal import MultiModalPlaceholderMap
|
||||||
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
||||||
@ -73,7 +74,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
src_key_cache = src_kv_cache[0]
|
src_key_cache = src_kv_cache[0]
|
||||||
dst_key_cache = dst_kv_cache[0]
|
dst_key_cache = dst_kv_cache[0]
|
||||||
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||||
|
|
||||||
src_value_cache = src_kv_cache[1]
|
src_value_cache = src_kv_cache[1]
|
||||||
dst_value_cache = dst_kv_cache[1]
|
dst_value_cache = dst_kv_cache[1]
|
||||||
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||||
@ -85,6 +85,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
) -> None:
|
) -> None:
|
||||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||||
|
|
||||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||||
|
|
||||||
|
|
||||||
@ -111,26 +112,12 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|||||||
# |-------------------- seq_len ---------------------|
|
# |-------------------- seq_len ---------------------|
|
||||||
# |-- query_len ---|
|
# |-- query_len ---|
|
||||||
|
|
||||||
# Maximum query length in the batch.
|
|
||||||
max_query_len: Optional[int]
|
|
||||||
|
|
||||||
# Max number of query tokens among requests in the batch.
|
|
||||||
max_decode_query_len: Optional[int]
|
|
||||||
|
|
||||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||||
# requests only.
|
# requests only.
|
||||||
max_prefill_seq_len: int
|
max_prefill_seq_len: int
|
||||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||||
# requests only.
|
# requests only.
|
||||||
max_decode_seq_len: int
|
max_decode_seq_len: int
|
||||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
|
||||||
# the batch, used to index into subquery. E.g., if the subquery length
|
|
||||||
# is [4, 6], it is [0, 4, 10].
|
|
||||||
query_start_loc: Optional[torch.Tensor]
|
|
||||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
|
||||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
|
||||||
# [4, 6], it is [0, 4, 10].
|
|
||||||
seq_start_loc: Optional[torch.Tensor]
|
|
||||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||||
# so far).
|
# so far).
|
||||||
context_lens_tensor: Optional[torch.Tensor]
|
context_lens_tensor: Optional[torch.Tensor]
|
||||||
@ -146,11 +133,62 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|||||||
# Whether or not cuda graph is enabled.
|
# Whether or not cuda graph is enabled.
|
||||||
# Cuda-graph is currently enabled for decoding only.
|
# Cuda-graph is currently enabled for decoding only.
|
||||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||||
|
|
||||||
use_cuda_graph: bool
|
use_cuda_graph: bool
|
||||||
|
|
||||||
|
# Maximum query length in the batch.
|
||||||
|
max_query_len: Optional[int] = None
|
||||||
|
|
||||||
|
# Max number of query tokens among requests in the batch.
|
||||||
|
max_decode_query_len: Optional[int] = None
|
||||||
|
|
||||||
|
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||||
|
# the batch, used to index into subquery. E.g., if the subquery length
|
||||||
|
# is [4, 6], it is [0, 4, 10].
|
||||||
|
query_start_loc: Optional[torch.Tensor] = None
|
||||||
|
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||||
|
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||||
|
# [4, 6], it is [0, 4, 10].
|
||||||
|
seq_start_loc: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
|
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
|
||||||
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
|
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
|
||||||
|
|
||||||
|
# Begin encoder attn & enc/dec cross-attn fields...
|
||||||
|
|
||||||
|
# Encoder sequence lengths representation
|
||||||
|
encoder_seq_lens: Optional[List[int]] = None
|
||||||
|
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||||
|
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||||
|
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||||
|
# [4, 6], it is [0, 4, 10].
|
||||||
|
encoder_seq_start_loc: Optional[torch.Tensor] = None
|
||||||
|
# Maximum sequence length among encoder sequences
|
||||||
|
max_encoder_seq_len: Optional[int] = None
|
||||||
|
# Number of tokens input to encoder
|
||||||
|
num_encoder_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
# Cross-attention memory-mapping data structures: slot mapping
|
||||||
|
# and block tables
|
||||||
|
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||||
|
cross_block_tables: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_all_encoder_attn_metadata_set(self):
|
||||||
|
'''
|
||||||
|
All attention metadata required for encoder attention is set.
|
||||||
|
'''
|
||||||
|
return is_all_encoder_attn_metadata_set(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_all_cross_attn_metadata_set(self):
|
||||||
|
'''
|
||||||
|
All attention metadata required for enc/dec cross-attention is set.
|
||||||
|
|
||||||
|
Superset of encoder attention required metadata.
|
||||||
|
'''
|
||||||
|
return is_all_cross_attn_metadata_set(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
|
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
|
||||||
if self.num_prefills == 0:
|
if self.num_prefills == 0:
|
||||||
@ -159,32 +197,52 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|||||||
if self._cached_prefill_metadata is not None:
|
if self._cached_prefill_metadata is not None:
|
||||||
return self._cached_prefill_metadata
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
assert self.seq_lens is not None
|
assert ((self.seq_lens is not None)
|
||||||
assert self.seq_lens_tensor is not None
|
or (self.encoder_seq_lens is not None))
|
||||||
assert self.query_start_loc is not None
|
assert ((self.seq_lens_tensor is not None)
|
||||||
assert self.context_lens_tensor is not None
|
or (self.encoder_seq_lens_tensor is not None))
|
||||||
assert self.block_tables is not None
|
|
||||||
assert self.seq_start_loc is not None
|
# Compute some attn_metadata fields which default to None
|
||||||
|
query_start_loc = (None if self.query_start_loc is None else
|
||||||
|
self.query_start_loc[:self.num_prefills + 1])
|
||||||
|
slot_mapping = (None if self.slot_mapping is None else
|
||||||
|
self.slot_mapping[:self.num_prefill_tokens])
|
||||||
|
seq_lens = (None if self.seq_lens is None else
|
||||||
|
self.seq_lens[:self.num_prefills])
|
||||||
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||||
|
self.seq_lens_tensor[:self.num_prefills])
|
||||||
|
seq_start_loc = (None if self.seq_start_loc is None else
|
||||||
|
self.seq_start_loc[:self.num_prefills + 1])
|
||||||
|
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||||
|
self.context_lens_tensor[:self.num_prefills])
|
||||||
|
block_tables = (None if self.block_tables is None else
|
||||||
|
self.block_tables[:self.num_prefills])
|
||||||
|
|
||||||
self._cached_prefill_metadata = FlashAttentionMetadata(
|
self._cached_prefill_metadata = FlashAttentionMetadata(
|
||||||
num_prefills=self.num_prefills,
|
num_prefills=self.num_prefills,
|
||||||
num_prefill_tokens=self.num_prefill_tokens,
|
num_prefill_tokens=self.num_prefill_tokens,
|
||||||
num_decode_tokens=0,
|
num_decode_tokens=0,
|
||||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
slot_mapping=slot_mapping,
|
||||||
multi_modal_placeholder_index_maps=self.
|
multi_modal_placeholder_index_maps=self.
|
||||||
multi_modal_placeholder_index_maps,
|
multi_modal_placeholder_index_maps,
|
||||||
seq_lens=self.seq_lens[:self.num_prefills],
|
seq_lens=seq_lens,
|
||||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_query_len=self.max_query_len,
|
max_query_len=self.max_query_len,
|
||||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||||
max_decode_query_len=0,
|
max_decode_query_len=0,
|
||||||
max_decode_seq_len=0,
|
max_decode_seq_len=0,
|
||||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
query_start_loc=query_start_loc,
|
||||||
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
seq_start_loc=seq_start_loc,
|
||||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
context_lens_tensor=context_lens_tensor,
|
||||||
block_tables=self.block_tables[:self.num_prefills],
|
block_tables=block_tables,
|
||||||
use_cuda_graph=False,
|
use_cuda_graph=False,
|
||||||
)
|
# Begin encoder & cross attn fields below...
|
||||||
|
encoder_seq_lens=self.encoder_seq_lens,
|
||||||
|
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||||
|
encoder_seq_start_loc=self.encoder_seq_start_loc,
|
||||||
|
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||||
|
cross_slot_mapping=self.cross_slot_mapping,
|
||||||
|
cross_block_tables=self.cross_block_tables)
|
||||||
return self._cached_prefill_metadata
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -194,17 +252,25 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|||||||
|
|
||||||
if self._cached_decode_metadata is not None:
|
if self._cached_decode_metadata is not None:
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
assert self.block_tables is not None
|
assert ((self.seq_lens_tensor is not None)
|
||||||
assert self.seq_lens_tensor is not None
|
or (self.encoder_seq_lens_tensor is not None))
|
||||||
|
|
||||||
|
# Compute some attn_metadata fields which default to None
|
||||||
|
slot_mapping = (None if self.slot_mapping is None else
|
||||||
|
self.slot_mapping[self.num_prefill_tokens:])
|
||||||
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||||
|
self.seq_lens_tensor[self.num_prefills:])
|
||||||
|
block_tables = (None if self.block_tables is None else
|
||||||
|
self.block_tables[self.num_prefills:])
|
||||||
|
|
||||||
self._cached_decode_metadata = FlashAttentionMetadata(
|
self._cached_decode_metadata = FlashAttentionMetadata(
|
||||||
num_prefills=0,
|
num_prefills=0,
|
||||||
num_prefill_tokens=0,
|
num_prefill_tokens=0,
|
||||||
num_decode_tokens=self.num_decode_tokens,
|
num_decode_tokens=self.num_decode_tokens,
|
||||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
slot_mapping=slot_mapping,
|
||||||
multi_modal_placeholder_index_maps=None,
|
multi_modal_placeholder_index_maps=None,
|
||||||
seq_lens=None,
|
seq_lens=None,
|
||||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_decode_query_len=self.max_decode_query_len,
|
max_decode_query_len=self.max_decode_query_len,
|
||||||
max_query_len=self.max_query_len,
|
max_query_len=self.max_query_len,
|
||||||
max_prefill_seq_len=0,
|
max_prefill_seq_len=0,
|
||||||
@ -214,9 +280,15 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|||||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||||
if self.seq_start_loc is not None else None,
|
if self.seq_start_loc is not None else None,
|
||||||
context_lens_tensor=None,
|
context_lens_tensor=None,
|
||||||
block_tables=self.block_tables[self.num_prefills:],
|
block_tables=block_tables,
|
||||||
use_cuda_graph=self.use_cuda_graph,
|
use_cuda_graph=self.use_cuda_graph,
|
||||||
)
|
# Begin encoder & cross attn fields below...
|
||||||
|
encoder_seq_lens=self.encoder_seq_lens,
|
||||||
|
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||||
|
encoder_seq_start_loc=self.encoder_seq_start_loc,
|
||||||
|
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||||
|
cross_slot_mapping=self.cross_slot_mapping,
|
||||||
|
cross_block_tables=self.cross_block_tables)
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
def advance_step(self,
|
def advance_step(self,
|
||||||
@ -586,16 +658,20 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
if attn_type != AttentionType.DECODER:
|
|
||||||
raise NotImplementedError("Encoder self-attention and "
|
|
||||||
"encoder/decoder cross-attention "
|
|
||||||
"are not implemented for "
|
|
||||||
"FlashAttentionImpl")
|
|
||||||
|
|
||||||
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
||||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||||
"key/v_scale is not supported in FlashAttention.")
|
"key/v_scale is not supported in FlashAttention.")
|
||||||
|
|
||||||
|
if (attn_type == AttentionType.ENCODER
|
||||||
|
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||||
|
raise AttributeError("Encoder attention requires setting "
|
||||||
|
"encoder metadata attributes.")
|
||||||
|
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||||
|
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||||
|
raise AttributeError("Encoder/decoder cross-attention "
|
||||||
|
"requires setting cross-attention "
|
||||||
|
"metadata attributes.")
|
||||||
|
|
||||||
output = torch.ops.vllm.unified_flash_attention(
|
output = torch.ops.vllm.unified_flash_attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@ -608,6 +684,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
self.scale,
|
self.scale,
|
||||||
|
attn_type.value,
|
||||||
self.sliding_window,
|
self.sliding_window,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
self.logits_soft_cap,
|
self.logits_soft_cap,
|
||||||
@ -616,6 +693,89 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _get_query_key_seq_metadata(
|
||||||
|
attn_metadata,
|
||||||
|
is_prompt: bool,
|
||||||
|
attn_type: AttentionType,
|
||||||
|
) -> tuple:
|
||||||
|
"""
|
||||||
|
Returns sequence metadata for key and query based on the specified
|
||||||
|
attention type and whether input is a prompt.
|
||||||
|
|
||||||
|
This function computes the starting locations and maximum sequence lengths
|
||||||
|
for key and query sequences for different attention types.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attn_metadata: The attention metadata object
|
||||||
|
is_prompt (bool): A flag indicating if the input is a prompt
|
||||||
|
attn_type (AttentionType): The type of attention being used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: A tuple containing four values:
|
||||||
|
- Starting location for the query sequence.
|
||||||
|
- Maximum sequence length for the query sequence.
|
||||||
|
- Starting location for the key sequence.
|
||||||
|
- Maximum sequence length for the key sequence.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AttributeError: If an invalid attention type is provided.
|
||||||
|
"""
|
||||||
|
if attn_type == AttentionType.DECODER:
|
||||||
|
# Decoder self-attention
|
||||||
|
# Choose max_seq_len based on whether we are in prompt_run
|
||||||
|
if is_prompt:
|
||||||
|
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||||
|
else:
|
||||||
|
max_seq_len = attn_metadata.max_decode_seq_len
|
||||||
|
return (attn_metadata.seq_start_loc, max_seq_len,
|
||||||
|
attn_metadata.seq_start_loc, max_seq_len)
|
||||||
|
|
||||||
|
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||||
|
# This is cross attention where the key
|
||||||
|
# is the precomputed encoder attention and query
|
||||||
|
# is the input sequence.
|
||||||
|
# Choose query max length based on whether it is prompt
|
||||||
|
# or not.
|
||||||
|
if is_prompt:
|
||||||
|
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||||
|
else:
|
||||||
|
max_seq_len = attn_metadata.max_decode_seq_len
|
||||||
|
return (attn_metadata.seq_start_loc, max_seq_len,
|
||||||
|
attn_metadata.encoder_seq_start_loc,
|
||||||
|
attn_metadata.max_encoder_seq_len)
|
||||||
|
elif attn_type == AttentionType.ENCODER:
|
||||||
|
# For encoder attention both the query and the key are same i.e the
|
||||||
|
# encoder sequence.
|
||||||
|
return (attn_metadata.encoder_seq_start_loc,
|
||||||
|
attn_metadata.max_encoder_seq_len,
|
||||||
|
attn_metadata.encoder_seq_start_loc,
|
||||||
|
attn_metadata.max_encoder_seq_len)
|
||||||
|
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||||
|
assert is_prompt, "Should not have decode for encoder only model."
|
||||||
|
return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len,
|
||||||
|
attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len)
|
||||||
|
else:
|
||||||
|
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_causal_option(attn_type: AttentionType) -> bool:
|
||||||
|
"""
|
||||||
|
Determine whether the given attention type is suitable for causal
|
||||||
|
attention mechanisms.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attn_type (AttentionType): The type of attention being evaluated
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Returns `True` if the attention type is suitable for causal
|
||||||
|
attention (i.e., not encoder, encoder-only, or encoder-decoder),
|
||||||
|
otherwise returns `False`.
|
||||||
|
"""
|
||||||
|
return not (attn_type == AttentionType.ENCODER
|
||||||
|
or attn_type == AttentionType.ENCODER_ONLY
|
||||||
|
or attn_type == AttentionType.ENCODER_DECODER)
|
||||||
|
|
||||||
|
|
||||||
def unified_flash_attention(
|
def unified_flash_attention(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
@ -628,60 +788,76 @@ def unified_flash_attention(
|
|||||||
k_scale: float,
|
k_scale: float,
|
||||||
v_scale: float,
|
v_scale: float,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
|
attn_type_int_val: int,
|
||||||
window_size: Optional[List[int]] = None,
|
window_size: Optional[List[int]] = None,
|
||||||
alibi_slopes: Optional[torch.Tensor] = None,
|
alibi_slopes: Optional[torch.Tensor] = None,
|
||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# Convert integer attn_type to enum
|
||||||
|
try:
|
||||||
|
attn_type = AttentionType(attn_type_int_val)
|
||||||
|
except ValueError as err:
|
||||||
|
raise AttributeError(
|
||||||
|
f"Invalid attention type {str(attn_type_int_val)}") from err
|
||||||
|
|
||||||
current_metadata = get_forward_context()
|
current_metadata = get_forward_context()
|
||||||
assert current_metadata is not None
|
assert current_metadata is not None
|
||||||
assert isinstance(current_metadata, FlashAttentionMetadata)
|
assert isinstance(current_metadata, FlashAttentionMetadata)
|
||||||
attn_metadata: FlashAttentionMetadata = current_metadata
|
attn_metadata: FlashAttentionMetadata = current_metadata
|
||||||
|
|
||||||
num_tokens, hidden_size = query.shape
|
num_tokens, hidden_size = query.shape
|
||||||
|
|
||||||
# Reshape the query, key, and value tensors.
|
# Reshape the query, key, and value tensors.
|
||||||
query = query.view(-1, num_heads, head_size)
|
query = query.view(-1, num_heads, head_size)
|
||||||
key = key.view(-1, num_kv_heads, head_size)
|
if (key is not None) and (value is not None):
|
||||||
value = value.view(-1, num_kv_heads, head_size)
|
key = key.view(-1, num_kv_heads, head_size)
|
||||||
|
value = value.view(-1, num_kv_heads, head_size)
|
||||||
|
|
||||||
if kv_cache.numel() > 0:
|
if kv_cache.numel() > 0:
|
||||||
key_cache = kv_cache[0]
|
key_cache = kv_cache[0]
|
||||||
value_cache = kv_cache[1]
|
value_cache = kv_cache[1]
|
||||||
|
# We skip updating the KV cache under two conditions:
|
||||||
|
# a. When the Attention Type is ENCODER. In this phase, we compute
|
||||||
|
# only the encoder attention without updating the cache.
|
||||||
|
# b. When both Key and Value are None. This occurs during
|
||||||
|
# cross-attention computation in the decoding phase, where the KV
|
||||||
|
# cache is already populated with the cross-attention tensor.
|
||||||
|
# Thus, we skip cache updates during this time.
|
||||||
|
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
|
||||||
|
value is not None):
|
||||||
|
if attn_type == AttentionType.ENCODER_DECODER:
|
||||||
|
# Update cross-attention KV cache (prefill-only)
|
||||||
|
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||||
|
else:
|
||||||
|
# Update self-attention KV cache (prefill/decode)
|
||||||
|
updated_slot_mapping = attn_metadata.slot_mapping
|
||||||
|
|
||||||
# Reshape the input keys and values and store them in the cache.
|
# Reshape the input keys and values and store them in the cache.
|
||||||
# If kv_cache is not provided, the new key and value tensors are
|
# If kv_cache is not provided, the new key and value tensors are
|
||||||
# not cached. This happens during the initial memory profiling run.
|
# not cached. This happens during the initial memory profiling run.
|
||||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
attn_metadata.slot_mapping.flatten(),
|
updated_slot_mapping.flatten(), # type: ignore[union-attr]
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_query_tokens) = \
|
||||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
||||||
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
|
decode_query = query[num_prefill_query_tokens:]
|
||||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
|
||||||
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
|
|
||||||
|
|
||||||
# Query for decode. KV is not needed because it is already cached.
|
|
||||||
decode_query = query[num_prefill_tokens:]
|
|
||||||
# QKV for prefill.
|
# QKV for prefill.
|
||||||
query = query[:num_prefill_tokens]
|
query = query[:num_prefill_query_tokens]
|
||||||
key = key[:num_prefill_tokens]
|
assert query.shape[0] == num_prefill_query_tokens
|
||||||
value = value[:num_prefill_tokens]
|
assert decode_query.shape[0] == num_decode_query_tokens
|
||||||
|
|
||||||
assert query.shape[0] == num_prefill_tokens
|
|
||||||
assert decode_query.shape[0] == num_decode_tokens
|
|
||||||
|
|
||||||
prefill_output: Optional[torch.Tensor] = None
|
prefill_output: Optional[torch.Tensor] = None
|
||||||
decode_output: Optional[torch.Tensor] = None
|
decode_output: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
if prefill_meta := attn_metadata.prefill_metadata:
|
if prefill_meta := attn_metadata.prefill_metadata:
|
||||||
# Prompt run.
|
# Prompt run.
|
||||||
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||||
@ -689,22 +865,30 @@ def unified_flash_attention(
|
|||||||
# normal attention
|
# normal attention
|
||||||
# When block_tables are not filled, it means q and k are the
|
# When block_tables are not filled, it means q and k are the
|
||||||
# prompt, and they have the same length.
|
# prompt, and they have the same length.
|
||||||
|
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
|
||||||
|
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
|
||||||
|
|
||||||
|
key = key[:num_prefill_kv_tokens]
|
||||||
|
value = value[:num_prefill_kv_tokens]
|
||||||
|
|
||||||
prefill_output = flash_attn_varlen_func(
|
prefill_output = flash_attn_varlen_func(
|
||||||
q=query,
|
q=query,
|
||||||
k=key,
|
k=key,
|
||||||
v=value,
|
v=value,
|
||||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
cu_seqlens_q=q_seq_start_loc,
|
||||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
cu_seqlens_k=k_seq_start_loc,
|
||||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
max_seqlen_q=q_seq_len,
|
||||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
max_seqlen_k=k_seq_len,
|
||||||
softmax_scale=softmax_scale,
|
softmax_scale=softmax_scale,
|
||||||
causal=True,
|
causal=_get_causal_option(attn_type),
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
alibi_slopes=alibi_slopes,
|
alibi_slopes=alibi_slopes,
|
||||||
softcap=logits_soft_cap,
|
softcap=logits_soft_cap,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# prefix-enabled attention
|
# prefix-enabled attention
|
||||||
|
assert attn_type == AttentionType.DECODER, (
|
||||||
|
"Only decoder-only models support prefix caching")
|
||||||
assert prefill_meta.seq_lens is not None
|
assert prefill_meta.seq_lens is not None
|
||||||
max_seq_len = max(prefill_meta.seq_lens)
|
max_seq_len = max(prefill_meta.seq_lens)
|
||||||
prefill_output = flash_attn_varlen_func( # noqa
|
prefill_output = flash_attn_varlen_func( # noqa
|
||||||
@ -729,6 +913,8 @@ def unified_flash_attention(
|
|||||||
# because different queries might have different lengths.
|
# because different queries might have different lengths.
|
||||||
assert decode_meta.max_decode_query_len is not None
|
assert decode_meta.max_decode_query_len is not None
|
||||||
if decode_meta.max_decode_query_len > 1:
|
if decode_meta.max_decode_query_len > 1:
|
||||||
|
assert attn_type == AttentionType.DECODER, (
|
||||||
|
"Only decoder-only models support max_decode_query_len > 1")
|
||||||
decode_output = flash_attn_varlen_func(
|
decode_output = flash_attn_varlen_func(
|
||||||
q=decode_query,
|
q=decode_query,
|
||||||
k=key_cache,
|
k=key_cache,
|
||||||
@ -746,12 +932,17 @@ def unified_flash_attention(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use flash_attn_with_kvcache for normal decoding.
|
# Use flash_attn_with_kvcache for normal decoding.
|
||||||
|
(
|
||||||
|
seq_lens_arg,
|
||||||
|
_,
|
||||||
|
block_tables_arg,
|
||||||
|
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
|
||||||
decode_output = flash_attn_with_kvcache(
|
decode_output = flash_attn_with_kvcache(
|
||||||
q=decode_query.unsqueeze(1),
|
q=decode_query.unsqueeze(1),
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
v_cache=value_cache,
|
v_cache=value_cache,
|
||||||
block_table=decode_meta.block_tables,
|
block_table=block_tables_arg,
|
||||||
cache_seqlens=decode_meta.seq_lens_tensor,
|
cache_seqlens=seq_lens_arg,
|
||||||
softmax_scale=softmax_scale,
|
softmax_scale=softmax_scale,
|
||||||
causal=True,
|
causal=True,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
@ -761,10 +952,10 @@ def unified_flash_attention(
|
|||||||
|
|
||||||
if prefill_output is None:
|
if prefill_output is None:
|
||||||
assert decode_output is not None
|
assert decode_output is not None
|
||||||
return decode_output.view(num_decode_tokens, hidden_size)
|
return decode_output.view(num_decode_query_tokens, hidden_size)
|
||||||
if decode_output is None:
|
if decode_output is None:
|
||||||
assert prefill_output is not None
|
assert prefill_output is not None
|
||||||
return prefill_output.view(num_prefill_tokens, hidden_size)
|
return prefill_output.view(num_prefill_query_tokens, hidden_size)
|
||||||
|
|
||||||
# Chunked prefill does not work with speculative decoding.
|
# Chunked prefill does not work with speculative decoding.
|
||||||
# Therefore, the query length for decode should be 1 in chunked prefill.
|
# Therefore, the query length for decode should be 1 in chunked prefill.
|
||||||
@ -786,6 +977,7 @@ def unified_flash_attention_fake(
|
|||||||
k_scale: float,
|
k_scale: float,
|
||||||
v_scale: float,
|
v_scale: float,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
|
attn_type_int_val: int,
|
||||||
window_size: Optional[List[int]] = None,
|
window_size: Optional[List[int]] = None,
|
||||||
alibi_slopes: Optional[torch.Tensor] = None,
|
alibi_slopes: Optional[torch.Tensor] = None,
|
||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
|
|||||||
@ -1,13 +1,14 @@
|
|||||||
"""Attention backend utils"""
|
"""Attention backend utils"""
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
|
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
|
||||||
AttentionState)
|
AttentionState)
|
||||||
|
from vllm.attention.backends.abstract import AttentionType
|
||||||
from vllm.multimodal import MultiModalPlaceholderMap
|
from vllm.multimodal import MultiModalPlaceholderMap
|
||||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||||
|
|
||||||
@ -336,11 +337,13 @@ class CommonAttentionState(AttentionState):
|
|||||||
use_cuda_graph=True,
|
use_cuda_graph=True,
|
||||||
)
|
)
|
||||||
if is_encoder_decoder_model:
|
if is_encoder_decoder_model:
|
||||||
# The encoder decoder model works only with XFormers backend.
|
# The encoder decoder model works only with XFormers and
|
||||||
# Assert the same.
|
# Flash Attention backend. Assert the same.
|
||||||
assert self.runner.attn_backend.get_name() == "XFORMERS", \
|
assert self.runner.attn_backend.get_name() in\
|
||||||
f"Expected attn_backend name to be 'XFORMERS', but "\
|
["XFORMERS", "FLASH_ATTN"], \
|
||||||
f" got '{self.runner.attn_backend.get_name()}'"
|
f"Expected attn_backend name to be either 'XFORMERS' or " \
|
||||||
|
f"'FLASH_ATTN', but "\
|
||||||
|
f"got '{self.runner.attn_backend.get_name()}'"
|
||||||
self._update_captured_metadata_for_enc_dec_model(
|
self._update_captured_metadata_for_enc_dec_model(
|
||||||
batch_size=batch_size, attn_metadata=attn_metadata)
|
batch_size=batch_size, attn_metadata=attn_metadata)
|
||||||
|
|
||||||
@ -356,11 +359,13 @@ class CommonAttentionState(AttentionState):
|
|||||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||||
}
|
}
|
||||||
if is_encoder_decoder_model:
|
if is_encoder_decoder_model:
|
||||||
# The encoder decoder model works only with XFormers backend.
|
# The encoder decoder model works only with XFormers and
|
||||||
# Assert the same.
|
# Flash Attention backend. Assert the same.
|
||||||
assert self.runner.attn_backend.get_name() == "XFORMERS", \
|
assert self.runner.attn_backend.get_name() in\
|
||||||
f"Expected attn_backend name to be 'XFORMERS', but "\
|
["XFORMERS", "FLASH_ATTN"], \
|
||||||
f" got '{self.runner.attn_backend.get_name()}'"
|
f"Expected attn_backend name to be either 'XFORMERS' or "\
|
||||||
|
f"'FLASH_ATTN', but "\
|
||||||
|
f"got '{self.runner.attn_backend.get_name()}'"
|
||||||
self._add_additonal_input_buffers_for_enc_dec_model(
|
self._add_additonal_input_buffers_for_enc_dec_model(
|
||||||
attn_metadata=attn_metadata, input_buffers=input_buffers)
|
attn_metadata=attn_metadata, input_buffers=input_buffers)
|
||||||
return input_buffers
|
return input_buffers
|
||||||
@ -375,11 +380,13 @@ class CommonAttentionState(AttentionState):
|
|||||||
input_buffers["block_tables"].copy_(
|
input_buffers["block_tables"].copy_(
|
||||||
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
||||||
if is_encoder_decoder_model:
|
if is_encoder_decoder_model:
|
||||||
# The encoder decoder model works only with XFormers backend.
|
# The encoder decoder model works only with XFormers and
|
||||||
# Assert the same.
|
# Flash Attention backend. Assert the same.
|
||||||
assert self.runner.attn_backend.get_name() == "XFORMERS", \
|
assert self.runner.attn_backend.get_name() in\
|
||||||
f"Expected attn_backend name to be 'XFORMERS', but "\
|
["XFORMERS", "FLASH_ATTN"], \
|
||||||
f" got '{self.runner.attn_backend.get_name()}'"
|
f"Expected attn_backend name to be either 'XFORMERS' or "\
|
||||||
|
f"'FLASH_ATTN', but "\
|
||||||
|
f"got '{self.runner.attn_backend.get_name()}'"
|
||||||
self._prepare_input_buffers_for_enc_dec_model(
|
self._prepare_input_buffers_for_enc_dec_model(
|
||||||
attn_metadata, input_buffers)
|
attn_metadata, input_buffers)
|
||||||
|
|
||||||
@ -411,6 +418,7 @@ class CommonAttentionState(AttentionState):
|
|||||||
attn_metadata.encoder_seq_lens_tensor = torch.full(
|
attn_metadata.encoder_seq_lens_tensor = torch.full(
|
||||||
(batch_size, ), 1, dtype=torch.int).cuda()
|
(batch_size, ), 1, dtype=torch.int).cuda()
|
||||||
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
|
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
|
||||||
|
attn_metadata.num_encoder_tokens = 0
|
||||||
|
|
||||||
def _add_additonal_input_buffers_for_enc_dec_model(
|
def _add_additonal_input_buffers_for_enc_dec_model(
|
||||||
self, attn_metadata, input_buffers: Dict[str, Any]):
|
self, attn_metadata, input_buffers: Dict[str, Any]):
|
||||||
@ -453,3 +461,122 @@ class CommonAttentionState(AttentionState):
|
|||||||
input_buffers["cross_block_tables"].copy_(
|
input_buffers["cross_block_tables"].copy_(
|
||||||
attn_metadata.decode_metadata.cross_block_tables,
|
attn_metadata.decode_metadata.cross_block_tables,
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
|
def is_all_encoder_attn_metadata_set(attn_metadata):
    '''
    Tell whether attn_metadata carries every field encoder attention needs.

    The fields checked, in order, are encoder_seq_lens,
    encoder_seq_lens_tensor and max_encoder_seq_len. The check stops at the
    first field that is None, so later fields are not read in that case.

    Returns True only when none of the three fields is None.
    '''
    if attn_metadata.encoder_seq_lens is None:
        return False
    if attn_metadata.encoder_seq_lens_tensor is None:
        return False
    return attn_metadata.max_encoder_seq_len is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_all_cross_attn_metadata_set(attn_metadata):
    '''
    Tell whether attn_metadata carries every field that encoder/decoder
    cross-attention needs.

    Cross-attention needs everything encoder attention needs, plus
    cross_slot_mapping and cross_block_tables. The encoder part is taken
    from the is_all_encoder_attn_metadata_set attribute of attn_metadata
    itself; when that attribute is falsy it is returned as-is and the
    cross-attention fields are not read.
    '''
    encoder_fields_ready = attn_metadata.is_all_encoder_attn_metadata_set
    if not encoder_fields_ready:
        return encoder_fields_ready
    if attn_metadata.cross_slot_mapping is None:
        return False
    return attn_metadata.cross_block_tables is not None
|
||||||
|
|
||||||
|
|
||||||
|
def get_seq_len_block_table_args(
    attn_metadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    '''
    Pick the sequence-length and block-table fields of attn_metadata that
    apply to the requested kind of attention.

    * Decoder self-attention: decoder sequence lengths, the prefill or
      decode max sequence length (chosen by is_prompt), and the decoder
      block tables.
    * Encoder/decoder cross-attention: encoder sequence lengths, max
      encoder sequence length, and the cross-attention block tables.
    * Encoder attention: encoder sequence lengths, max encoder sequence
      length, and None in place of block tables.

    Arguments:

    * attn_metadata: attention metadata object the fields are read from
    * is_prompt: True for prefill, False otherwise; only consulted for
      decoder self-attention
    * attn_type: encoder attention, decoder self-attention, or
      encoder/decoder cross-attention

    Returns:

    * Tuple of (sequence-lengths tensor, max sequence-length scalar,
      block tables or None)

    Raises:

    * AttributeError: attn_type is not one of the three kinds listed above
    '''
    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: the max length depends on the phase.
        longest_seq = (attn_metadata.max_prefill_seq_len
                       if is_prompt else attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_lens_tensor, longest_seq,
                attn_metadata.block_tables)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross-attention KVs follow the encoder sequence lengths and are
        # addressed through the dedicated "cross" block tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)

    if attn_type == AttentionType.ENCODER:
        # Encoder attention has no block tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: AttentionType,
) -> Tuple[int, int, int]:
    """
    Work out how many prefill query tokens, prefill key/value tokens and
    decode query tokens an attention op of the given type has to handle.

    Args:
        attn_metadata: Attention metadata object; num_prefill_tokens,
            num_decode_tokens and num_encoder_tokens are read from it
            depending on attn_type.
        attn_type (AttentionType): The type of attention being used.

    Returns:
        Tuple[int, int, int]: In this order:
            - The number of prefill query tokens.
            - The number of prefill key/value tokens.
            - The number of decode query tokens.

    Raises:
        AssertionError: If attn_type is ENCODER or ENCODER_DECODER and
            `attn_metadata.num_encoder_tokens` is `None`.
    """
    if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during the prefill phase, and
        # the same input serves as both query and key, so there are no
        # decode tokens.
        assert attn_metadata.num_encoder_tokens is not None
        encoder_tokens = attn_metadata.num_encoder_tokens
        return (encoder_tokens, encoder_tokens, 0)

    if attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.num_encoder_tokens is not None
        # Queries come from the decoder; keys/values come from the encoder
        # (cross-attention).
        return (attn_metadata.num_prefill_tokens,
                attn_metadata.num_encoder_tokens,
                attn_metadata.num_decode_tokens)

    # Remaining types (AttentionType.DECODER or AttentionType.ENCODER_ONLY):
    # query and key/value prefill counts are the same.
    return (attn_metadata.num_prefill_tokens,
            attn_metadata.num_prefill_tokens,
            attn_metadata.num_decode_tokens)
|
||||||
|
|||||||
@ -11,8 +11,10 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
|
|||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata, AttentionType)
|
AttentionMetadata, AttentionType)
|
||||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
from vllm.attention.backends.utils import (
|
||||||
CommonMetadataBuilder)
|
CommonAttentionState, CommonMetadataBuilder,
|
||||||
|
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
|
||||||
|
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
|
||||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||||
PagedAttentionMetadata)
|
PagedAttentionMetadata)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -135,6 +137,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
# Encoder sequence lengths representation
|
# Encoder sequence lengths representation
|
||||||
encoder_seq_lens: Optional[List[int]] = None
|
encoder_seq_lens: Optional[List[int]] = None
|
||||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||||
|
# FIXME: The following field is here for the flash-attn backend.
|
||||||
|
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||||
|
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||||
|
# [4, 6], it is [0, 4, 10].
|
||||||
|
encoder_seq_start_loc: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Maximum sequence length among encoder sequences
|
# Maximum sequence length among encoder sequences
|
||||||
max_encoder_seq_len: Optional[int] = None
|
max_encoder_seq_len: Optional[int] = None
|
||||||
@ -162,9 +169,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
'''
|
'''
|
||||||
All attention metadata required for encoder attention is set.
|
All attention metadata required for encoder attention is set.
|
||||||
'''
|
'''
|
||||||
return ((self.encoder_seq_lens is not None)
|
return is_all_encoder_attn_metadata_set(self)
|
||||||
and (self.encoder_seq_lens_tensor is not None)
|
|
||||||
and (self.max_encoder_seq_len is not None))
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_all_cross_attn_metadata_set(self):
|
def is_all_cross_attn_metadata_set(self):
|
||||||
@ -173,9 +178,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
|
|
||||||
Superset of encoder attention required metadata.
|
Superset of encoder attention required metadata.
|
||||||
'''
|
'''
|
||||||
return (self.is_all_encoder_attn_metadata_set
|
return is_all_cross_attn_metadata_set(self)
|
||||||
and (self.cross_slot_mapping is not None)
|
|
||||||
and (self.cross_block_tables is not None))
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prefill_metadata(self) -> Optional["XFormersMetadata"]:
|
def prefill_metadata(self) -> Optional["XFormersMetadata"]:
|
||||||
@ -329,64 +332,6 @@ def _set_attn_bias(
|
|||||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||||
|
|
||||||
|
|
||||||
def _get_seq_len_block_table_args(
|
|
||||||
attn_metadata: XFormersMetadata,
|
|
||||||
is_prompt: bool,
|
|
||||||
attn_type: AttentionType,
|
|
||||||
) -> tuple:
|
|
||||||
'''
|
|
||||||
The particular choice of sequence-length- and block-table-related
|
|
||||||
attributes which should be extracted from attn_metadata is dependent
|
|
||||||
on the type of attention operation.
|
|
||||||
|
|
||||||
Decoder attn -> select entirely decoder self-attention-related fields
|
|
||||||
Encoder/decoder cross-attn -> select encoder sequence lengths &
|
|
||||||
cross-attn block-tables fields
|
|
||||||
Encoder attn -> select encoder sequence lengths fields & no block tables
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
|
|
||||||
* attn_metadata: Attention metadata structure associated with attention op
|
|
||||||
* is_prompt: True if prefill, False otherwise
|
|
||||||
* attn_type: encoder attention, decoder self-attention,
|
|
||||||
encoder/decoder cross-attention
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
* Appropriate sequence-lengths tensor
|
|
||||||
* Appropriate max sequence-length scalar
|
|
||||||
* Appropriate block tables (or None)
|
|
||||||
'''
|
|
||||||
|
|
||||||
if attn_type == AttentionType.DECODER:
|
|
||||||
# Decoder self-attention
|
|
||||||
# Choose max_seq_len based on whether we are in prompt_run
|
|
||||||
if is_prompt:
|
|
||||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
|
||||||
else:
|
|
||||||
max_seq_len = attn_metadata.max_decode_seq_len
|
|
||||||
return (attn_metadata.seq_lens_tensor, max_seq_len,
|
|
||||||
attn_metadata.block_tables)
|
|
||||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
|
||||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
|
||||||
# cross-attention utilizes special "cross" block tables
|
|
||||||
return (attn_metadata.encoder_seq_lens_tensor,
|
|
||||||
attn_metadata.max_encoder_seq_len,
|
|
||||||
attn_metadata.cross_block_tables)
|
|
||||||
elif attn_type == AttentionType.ENCODER:
|
|
||||||
# No block tables associated with encoder attention
|
|
||||||
return (attn_metadata.encoder_seq_lens_tensor,
|
|
||||||
attn_metadata.max_encoder_seq_len, None)
|
|
||||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
|
||||||
assert is_prompt, "Should not have decode for encoder only model."
|
|
||||||
|
|
||||||
# No block tables associated with encoder attention
|
|
||||||
return (attn_metadata.seq_lens_tensor,
|
|
||||||
attn_metadata.max_prefill_seq_len, None)
|
|
||||||
else:
|
|
||||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
|
||||||
|
|
||||||
|
|
||||||
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
|
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
|
||||||
|
|
||||||
_metadata_cls = XFormersMetadata
|
_metadata_cls = XFormersMetadata
|
||||||
@ -574,45 +519,21 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
updated_slot_mapping,
|
updated_slot_mapping,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
k_scale, v_scale)
|
k_scale, v_scale)
|
||||||
|
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||||
if attn_type == AttentionType.ENCODER:
|
num_decode_query_tokens) = \
|
||||||
# Encoder attention - chunked prefill is not applicable;
|
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
||||||
# derive token-count from query shape & and treat them
|
|
||||||
# as 100% prefill tokens
|
|
||||||
assert attn_metadata.num_encoder_tokens is not None
|
|
||||||
num_prefill_tokens = attn_metadata.num_encoder_tokens
|
|
||||||
num_encoder_tokens = attn_metadata.num_encoder_tokens
|
|
||||||
num_decode_tokens = 0
|
|
||||||
elif attn_type == AttentionType.DECODER:
|
|
||||||
# Decoder self-attention supports chunked prefill.
|
|
||||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
|
||||||
num_encoder_tokens = attn_metadata.num_prefill_tokens
|
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
|
||||||
# Only enforce this shape-constraint for decoder
|
|
||||||
# self-attention
|
|
||||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
|
||||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
|
||||||
else: # attn_type == AttentionType.ENCODER_DECODER
|
|
||||||
# Encoder/decoder cross-attention requires no chunked
|
|
||||||
# prefill (100% prefill or 100% decode tokens, no mix)
|
|
||||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
|
||||||
if attn_metadata.num_encoder_tokens is not None:
|
|
||||||
num_encoder_tokens = attn_metadata.num_encoder_tokens
|
|
||||||
else:
|
|
||||||
num_encoder_tokens = attn_metadata.num_prefill_tokens
|
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
|
||||||
|
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
# Query for decode. KV is not needed because it is already cached.
|
# Query for decode. KV is not needed because it is already cached.
|
||||||
decode_query = query[num_prefill_tokens:]
|
decode_query = query[num_prefill_query_tokens:]
|
||||||
# QKV for prefill.
|
# QKV for prefill.
|
||||||
query = query[:num_prefill_tokens]
|
query = query[:num_prefill_query_tokens]
|
||||||
if key is not None and value is not None:
|
if key is not None and value is not None:
|
||||||
key = key[:num_encoder_tokens]
|
key = key[:num_prefill_kv_tokens]
|
||||||
value = value[:num_encoder_tokens]
|
value = value[:num_prefill_kv_tokens]
|
||||||
|
|
||||||
assert query.shape[0] == num_prefill_tokens
|
assert query.shape[0] == num_prefill_query_tokens
|
||||||
assert decode_query.shape[0] == num_decode_tokens
|
assert decode_query.shape[0] == num_decode_query_tokens
|
||||||
|
|
||||||
if prefill_meta := attn_metadata.prefill_metadata:
|
if prefill_meta := attn_metadata.prefill_metadata:
|
||||||
# Prompt run.
|
# Prompt run.
|
||||||
@ -622,8 +543,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
# prefix.
|
# prefix.
|
||||||
out = self._run_memory_efficient_xformers_forward(
|
out = self._run_memory_efficient_xformers_forward(
|
||||||
query, key, value, prefill_meta, attn_type=attn_type)
|
query, key, value, prefill_meta, attn_type=attn_type)
|
||||||
assert out.shape == output[:num_prefill_tokens].shape
|
assert out.shape == output[:num_prefill_query_tokens].shape
|
||||||
output[:num_prefill_tokens] = out
|
output[:num_prefill_query_tokens] = out
|
||||||
else:
|
else:
|
||||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||||
"Encoder-only models should not have prefix attention.")
|
"Encoder-only models should not have prefix attention.")
|
||||||
@ -652,8 +573,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
)
|
)
|
||||||
assert output[:num_prefill_tokens].shape == out.shape
|
assert output[:num_prefill_query_tokens].shape == out.shape
|
||||||
output[:num_prefill_tokens] = out
|
output[:num_prefill_query_tokens] = out
|
||||||
|
|
||||||
if decode_meta := attn_metadata.decode_metadata:
|
if decode_meta := attn_metadata.decode_metadata:
|
||||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||||
@ -663,9 +584,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
seq_lens_arg,
|
seq_lens_arg,
|
||||||
max_seq_len_arg,
|
max_seq_len_arg,
|
||||||
block_tables_arg,
|
block_tables_arg,
|
||||||
) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
|
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
|
||||||
|
|
||||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
|
||||||
decode_query,
|
decode_query,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
|
|||||||
@ -98,7 +98,6 @@ def get_attn_backend(
|
|||||||
is_blocksparse: bool = False,
|
is_blocksparse: bool = False,
|
||||||
) -> Type[AttentionBackend]:
|
) -> Type[AttentionBackend]:
|
||||||
"""Selects which attention backend to use and lazily imports it."""
|
"""Selects which attention backend to use and lazily imports it."""
|
||||||
|
|
||||||
if is_blocksparse:
|
if is_blocksparse:
|
||||||
logger.info("Using BlocksparseFlashAttention backend.")
|
logger.info("Using BlocksparseFlashAttention backend.")
|
||||||
from vllm.attention.backends.blocksparse_attn import (
|
from vllm.attention.backends.blocksparse_attn import (
|
||||||
@ -108,6 +107,7 @@ def get_attn_backend(
|
|||||||
backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
|
backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
|
||||||
is_attention_free)
|
is_attention_free)
|
||||||
if backend == _Backend.FLASH_ATTN:
|
if backend == _Backend.FLASH_ATTN:
|
||||||
|
logger.info("Using Flash Attention backend.")
|
||||||
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
||||||
FlashAttentionBackend)
|
FlashAttentionBackend)
|
||||||
return FlashAttentionBackend
|
return FlashAttentionBackend
|
||||||
|
|||||||
@ -624,8 +624,6 @@ class BartEncoder(nn.Module):
|
|||||||
Decoder output torch.Tensor
|
Decoder output torch.Tensor
|
||||||
"""
|
"""
|
||||||
# retrieve input_ids and inputs_embeds
|
# retrieve input_ids and inputs_embeds
|
||||||
|
|
||||||
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
embed_pos = self.embed_positions(
|
embed_pos = self.embed_positions(
|
||||||
|
|||||||
@ -80,8 +80,8 @@ STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
|
|||||||
"currently supported with encoder/"
|
"currently supported with encoder/"
|
||||||
"decoder models.")
|
"decoder models.")
|
||||||
|
|
||||||
STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
|
STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only "
|
||||||
"currently supported with encoder/"
|
"backends currently supported with encoder/"
|
||||||
"decoder models.")
|
"decoder models.")
|
||||||
|
|
||||||
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
|
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor import SamplingMetadata
|
from vllm.model_executor import SamplingMetadata
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
|
from vllm.model_executor.model_loader.utils import get_architecture_class_name
|
||||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
|
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
|
||||||
MultiModalRegistry)
|
MultiModalRegistry)
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
@ -36,6 +37,11 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# The Mllama model has PagedAttention specific logic because of which it
|
||||||
|
# can only be run with the XFORMERS backend
|
||||||
|
# TODO Make Mllama model work with Flash Attention backend.
|
||||||
|
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
|
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
|
||||||
@ -101,9 +107,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
|||||||
models) but these arguments are present here for compatibility with
|
models) but these arguments are present here for compatibility with
|
||||||
the base-class constructor.
|
the base-class constructor.
|
||||||
'''
|
'''
|
||||||
|
self._maybe_force_supported_attention_backend(model_config)
|
||||||
self._maybe_force_supported_attention_backend()
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_config,
|
model_config,
|
||||||
parallel_config,
|
parallel_config,
|
||||||
@ -119,7 +123,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
|||||||
# Crash for unsupported encoder/scenarios
|
# Crash for unsupported encoder/scenarios
|
||||||
assert_enc_dec_mr_supported_scenario(self)
|
assert_enc_dec_mr_supported_scenario(self)
|
||||||
|
|
||||||
def _maybe_force_supported_attention_backend(self):
|
def _is_xformers_only_encoder_decoder_model(self,
                                            model: ModelConfig) -> bool:
    '''
    Return True when the architecture class name resolved for `model` is
    listed in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS.
    '''
    architecture_name = get_architecture_class_name(model)
    return architecture_name in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS
|
||||||
|
|
||||||
|
def _maybe_force_supported_attention_backend(self, model: ModelConfig):
|
||||||
'''
|
'''
|
||||||
Force vLLM to use the XFormers attention backend,
|
Force vLLM to use the XFormers attention backend,
|
||||||
which is currently the only supported option.
|
which is currently the only supported option.
|
||||||
@ -135,22 +144,26 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
|||||||
is_forced_by_global = maybe_global_forced_backend is not None
|
is_forced_by_global = maybe_global_forced_backend is not None
|
||||||
is_forced_by_env_var = maybe_env_var_forced_backend is not None
|
is_forced_by_env_var = maybe_env_var_forced_backend is not None
|
||||||
|
|
||||||
if not (is_forced_by_global or is_forced_by_env_var):
|
if not (is_forced_by_global or is_forced_by_env_var) \
|
||||||
|
and self._is_xformers_only_encoder_decoder_model(model):
|
||||||
# The user has not already specified an attention backend
|
# The user has not already specified an attention backend
|
||||||
# override
|
# override
|
||||||
logger.info("EncoderDecoderModelRunner requires "
|
logger.info(
|
||||||
"XFormers backend; overriding backend "
|
"Encoder-Decoder Model Architecture %s requires XFormers "
|
||||||
"auto-selection and forcing XFormers.")
|
"backend; overriding backend auto-selection and "
|
||||||
|
"forcing XFormers.", get_architecture_class_name(model))
|
||||||
global_force_attn_backend(_Backend.XFORMERS)
|
global_force_attn_backend(_Backend.XFORMERS)
|
||||||
elif is_forced_by_global:
|
elif is_forced_by_global:
|
||||||
# Backend override enforced by global variable takes
|
# Backend override enforced by global variable takes
|
||||||
# precedence over vLLM backend environment variable.
|
# precedence over vLLM backend environment variable.
|
||||||
if maybe_global_forced_backend != _Backend.XFORMERS:
|
if maybe_global_forced_backend not in\
|
||||||
|
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
|
||||||
raise_backend_err()
|
raise_backend_err()
|
||||||
elif is_forced_by_env_var:
|
elif is_forced_by_env_var:
|
||||||
# Backend override enforced by vLLM backend
|
# Backend override enforced by vLLM backend
|
||||||
# environment variable
|
# environment variable
|
||||||
if maybe_env_var_forced_backend != _Backend.XFORMERS:
|
if maybe_env_var_forced_backend not in\
|
||||||
|
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
|
||||||
raise_backend_err()
|
raise_backend_err()
|
||||||
|
|
||||||
def _list_to_int32_tensor(
|
def _list_to_int32_tensor(
|
||||||
@ -532,6 +545,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
|||||||
attn_metadata.encoder_seq_lens,
|
attn_metadata.encoder_seq_lens,
|
||||||
attn_metadata.encoder_seq_lens_tensor,
|
attn_metadata.encoder_seq_lens_tensor,
|
||||||
attn_metadata.max_encoder_seq_len,
|
attn_metadata.max_encoder_seq_len,
|
||||||
|
attn_metadata.encoder_seq_start_loc,
|
||||||
attn_metadata.cross_slot_mapping,
|
attn_metadata.cross_slot_mapping,
|
||||||
attn_metadata.cross_block_tables,
|
attn_metadata.cross_block_tables,
|
||||||
) = (
|
) = (
|
||||||
@ -539,6 +553,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
|||||||
encoder_seq_lens,
|
encoder_seq_lens,
|
||||||
encoder_seq_lens_tensor,
|
encoder_seq_lens_tensor,
|
||||||
max_encoder_seq_len,
|
max_encoder_seq_len,
|
||||||
|
encoder_seq_start_loc,
|
||||||
cross_slot_mapping_tensor,
|
cross_slot_mapping_tensor,
|
||||||
cross_block_tables,
|
cross_block_tables,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user