From 11a89cf95caaec8dec13fab1e8e3d64c9a852a08 Mon Sep 17 00:00:00 2001
From: Yifan Qiao
Date: Wed, 17 Dec 2025 22:42:21 -0800
Subject: [PATCH] [Fix][FlexAttention] return max logical block index to handle reused blocks (#30915)

Signed-off-by: Yifan Qiao
---
 tests/kernels/test_flex_attention.py         | 31 +++++++++++++++++++-
 vllm/v1/attention/backends/flex_attention.py | 15 ++++++++--
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py
index f6987d54399d2..7053a8697e190 100644
--- a/tests/kernels/test_flex_attention.py
+++ b/tests/kernels/test_flex_attention.py
@@ -15,7 +15,10 @@ from tests.v1.attention.utils import (
     create_standard_kv_cache_spec,
     create_vllm_config,
 )
-from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadataBuilder
+from vllm.v1.attention.backends.flex_attention import (
+    FlexAttentionMetadataBuilder,
+    physical_to_logical_mapping,
+)
 
 from ..models.utils import check_embeddings_close, check_logprobs_close
 
@@ -205,5 +208,31 @@ def test_block_mask_direct_vs_slow_path():
     )
 
 
+def test_physical_to_logical_mapping_handles_reused_blocks():
+    """Regression test: reused physical blocks map to the latest logical block.
+
+    For sliding-window / hybrid attention layers, physical KV-cache blocks can be
+    reused over time. The inverse mapping must therefore select the latest
+    logical block index for a physical block id.
+    """
+    # Padding should not make physical block 0 look live.
+    block_table = torch.tensor([[6, 0, 0, 0]], dtype=torch.int32)
+    seq_lens = torch.tensor([1 * 16], dtype=torch.int32)  # only 1 block valid
+    out = physical_to_logical_mapping(
+        block_table=block_table, seq_lens=seq_lens, block_size=16, total_blocks=10
+    )
+    assert out[0, 0].item() == -1
+    assert out[0, 6].item() == 0
+
+    # If a physical block id appears multiple times (block reuse), mapping should
+    # point to the latest logical block index.
+    block_table2 = torch.tensor([[2, 2, 5]], dtype=torch.int32)
+    seq_lens2 = torch.tensor([3 * 16], dtype=torch.int32)
+    out2 = physical_to_logical_mapping(
+        block_table=block_table2, seq_lens=seq_lens2, block_size=16, total_blocks=8
+    )
+    assert out2[0, 2].item() == 1
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index d8dbe4cbae013..8193c05c2b1ab 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -160,7 +160,7 @@ def physical_to_logical_mapping(
     └───────────────────────────────────────────┘
 
     If multiple logical blocks map to the same physical block,
-    this function returns the first (minimum) logical block index.
+    this function returns the latest (maximum) logical block index.
 
     If a physical block is not mapped to by any logical block,
     its value in the result will be -1.
@@ -183,6 +183,15 @@
     To prevent this, we use seq_lens and block_size to mask out
     unused entries, ensuring only valid block references are processed.
 
+    IMPORTANT: Reused physical blocks (sliding-window / hybrid attention)
+    ────────────────────────────────────────────────────────────────────
+    For some attention types, physical cache blocks can be reused over time.
+    This can cause the same physical block id to appear multiple times in a row
+    of `block_table` at different logical block indices. In that case, only the
+    latest logical block index corresponds to the current contents of that
+    physical block. Therefore, the inverse mapping must pick the maximum logical
+    block index for each physical block id.
+
     Args:
         block_table: Tensor of shape [max_reqs, max_num_blocks]
             mapping logical blocks to physical locations. May contain
@@ -217,8 +226,8 @@
         mask, torch.arange(max_num_blocks, device=device)[None, :], 0
     )
 
-    physical_to_logical.scatter_(
-        -1, valid_block_table.to(torch.int64), valid_logical_indices
+    physical_to_logical.scatter_reduce_(
+        -1, valid_block_table.to(torch.int64), valid_logical_indices, reduce="amax"
     )
     # NB - Seems like block 0 is always empty so we reset it manually
     physical_to_logical[:, 0] = -1
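
A standalone sketch, separate from the diff above, of why
`scatter_reduce_(..., reduce="amax")` yields the latest logical block index when
a physical block id is reused. It assumes a PyTorch version that provides
`Tensor.scatter_reduce_` (which the patch itself relies on); the variable names
below are illustrative and only mirror the shapes used by
`physical_to_logical_mapping`, they are not copied from the vLLM source.

    import torch

    max_reqs, total_blocks = 1, 8
    # One block-table row: logical blocks 0, 1, 2 live in physical blocks 2, 2, 5.
    # Physical block 2 has been reused, so only logical index 1 reflects its
    # current contents.
    block_table = torch.tensor([[2, 2, 5]], dtype=torch.int64)
    logical_indices = torch.arange(block_table.shape[1]).expand(max_reqs, -1)

    # Unmapped physical blocks stay at -1; "amax" keeps the largest (latest)
    # logical index scattered into each slot.
    physical_to_logical = torch.full((max_reqs, total_blocks), -1, dtype=torch.int64)
    physical_to_logical.scatter_reduce_(-1, block_table, logical_indices, reduce="amax")

    print(physical_to_logical)
    # tensor([[-1, -1,  1, -1, -1,  2, -1, -1]])

With plain `scatter_`, duplicate indices leave the result order-dependent
(whichever write lands last wins, with no guarantee it is the maximum), which is
the ambiguity the `reduce="amax"` change removes for reused blocks.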