From 073a4bd1c04164af29843cb5478740e9839d2d8a Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 1 Dec 2024 17:55:39 -0800
Subject: [PATCH] [Kernel] Use `out` arg in flash_attn_varlen_func (#10811)

Signed-off-by: Woosuk Kwon
---
 CMakeLists.txt                           |  2 +-
 tests/kernels/test_flash_attn.py         | 20 +++++++++++++++++---
 vllm/v1/attention/backends/flash_attn.py |  6 +++---
 3 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index f43bf8143458..c78cdc77a7e4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -522,7 +522,7 @@ else()
     FetchContent_Declare(
             vllm-flash-attn
             GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-            GIT_TAG fdf6d72b48aea41f4ae6a89139a453dae554abc8
+            GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb
             GIT_PROGRESS TRUE
             # Don't share the vllm-flash-attn build between build types
             BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index a20c73345218..1ae78d7b46c5 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -71,6 +71,7 @@ def ref_paged_attn(
     return torch.cat(outputs, dim=0)
 
 
+@pytest.mark.parametrize("use_out", [True, False])
 @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -81,6 +82,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("sliding_window", [None, 256])
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
+    use_out: bool,
     kv_lens: List[int],
     num_heads: Tuple[int, int],
     head_size: int,
@@ -116,17 +118,22 @@ def test_flash_attn_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
+    q = query.unsqueeze(1)
+    out = torch.empty_like(q) if use_out else None
     output = flash_attn_with_kvcache(
-        q=query.unsqueeze(1),
+        q=q,
         k_cache=key_cache,
         v_cache=value_cache,
+        out=out,
         softmax_scale=scale,
         causal=True,
         block_table=block_tables,
         cache_seqlens=kv_lens_tensor,
         softcap=soft_cap if soft_cap is not None else 0,
         window_size=window_size,
-    ).squeeze(1)
+    )
+    output = output if not use_out else out
+    output = output.squeeze(1)
 
     ref_output = ref_paged_attn(query=query,
                                 key_cache=key_cache,
@@ -141,7 +148,10 @@ def test_flash_attn_with_paged_kv(
         f"{torch.max(torch.abs(output - ref_output))}"
 
 
-@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
+@pytest.mark.parametrize("use_out", [True, False])
+@pytest.mark.parametrize("seq_lens",
+                         [[(1, 1328), (5, 18),
+                           (129, 463)], [(1, 523), (1, 37), (1, 2011)]])
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@@ -151,6 +161,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @torch.inference_mode()
 def test_varlen_with_paged_kv(
+    use_out: bool,
     seq_lens: List[Tuple[int, int]],
     num_heads: Tuple[int, int],
     head_size: int,
@@ -197,10 +208,12 @@ def test_varlen_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
+    out = torch.empty_like(query) if use_out else None
     output = flash_attn_varlen_func(
         q=query,
         k=key_cache,
         v=value_cache,
+        out=out,
         cu_seqlens_q=cu_query_lens,
         cu_seqlens_k=cu_kv_lens,
         max_seqlen_q=max_query_len,
@@ -211,6 +224,7 @@ def test_varlen_with_paged_kv(
         block_table=block_tables,
         softcap=soft_cap if soft_cap is not None else 0,
     )
+    output = output if not use_out else out
 
     ref_output = ref_paged_attn(
         query=query,
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index e618edf7d35b..4aa4b296f0ef 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -205,10 +205,12 @@ def unified_v1_flash_attention(
         v_scale,
     )
 
-    attn_output = flash_attn_varlen_func(
+    # Compute attention and update output up to `num_actual_tokens`.
+    flash_attn_varlen_func(
         q=query[:num_actual_tokens],
         k=key_cache,
         v=value_cache,
+        out=output[:num_actual_tokens],
         cu_seqlens_q=attn_metadata.query_start_loc,
         max_seqlen_q=attn_metadata.max_query_len,
         cu_seqlens_k=attn_metadata.seq_start_loc,
@@ -220,8 +222,6 @@ def unified_v1_flash_attention(
         block_table=attn_metadata.block_table,
         softcap=logits_soft_cap,
     )
-    # TODO(woosuk): Remove this unnecessary copy.
-    output[:num_actual_tokens].copy_(attn_output)
 
 
 def unified_v1_flash_attention_fake(
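
---

Note on the new call pattern: with this patch the wrapper accepts an `out`
argument, so the v1 backend can hand `flash_attn_varlen_func` a slice of its
preallocated output buffer and drop the trailing
`output[:num_actual_tokens].copy_(attn_output)`. Below is a minimal sketch of
that pattern, assuming a CUDA device and the `vllm.vllm_flash_attn` import
path used by the tests above; the shapes, sizes, and variable names are
illustrative only and not taken from the patch:

    import torch
    from vllm.vllm_flash_attn import flash_attn_varlen_func

    # Illustrative sizes: one "sequence" of 16 tokens, 4 heads of dim 128.
    num_tokens, num_heads, head_size = 16, 4, 128
    q = torch.randn(num_tokens, num_heads, head_size,
                    device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Varlen layout: cumulative token offsets per sequence.
    cu_seqlens = torch.tensor([0, num_tokens],
                              dtype=torch.int32, device="cuda")

    # Preallocate the destination once and let the kernel write into it
    # directly, instead of materializing a fresh tensor and copying it over
    # afterwards (the copy this patch removes from the v1 backend).
    out = torch.empty_like(q)
    flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        out=out,  # new in this patch: kernel writes the result here
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=num_tokens,
        softmax_scale=head_size**-0.5,
        causal=True,
    )
    # `out` now holds the attention output; no extra copy_ is needed.

Passing `out=None` (or omitting it) preserves the old behavior of allocating
and returning a new tensor, which is what the `use_out` test parameter above
exercises in both modes.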