[Fix][FlexAttention] return max logical block index to handle reused blocks (#30915)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Author: Yifan Qiao, 2025-12-17 22:42:21 -08:00 (committed by GitHub)
parent e3ab93c896
commit 11a89cf95c
2 changed files with 42 additions and 4 deletions
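
For context, a minimal standalone sketch of the idea behind this fix, using plain PyTorch only (this is not the vLLM helper itself; the real physical_to_logical_mapping additionally masks padded entries via seq_lens and block_size). When the same physical block id appears at several logical positions, scatter_reduce_ with reduce="amax" keeps the latest (maximum) logical index:

import torch

# Toy inverse mapping: logical block index -> physical block id.
# Physical block 2 is reused at logical positions 0 and 1.
block_table = torch.tensor([[2, 2, 5]])
total_blocks = 8

physical_to_logical = torch.full((1, total_blocks), -1, dtype=torch.int64)
logical_indices = torch.arange(block_table.shape[1])[None, :]

# reduce="amax" keeps the maximum (latest) logical index whenever a physical
# block id occurs more than once, e.g. after sliding-window block reuse.
physical_to_logical.scatter_reduce_(
    -1, block_table.to(torch.int64), logical_indices, reduce="amax"
)

print(physical_to_logical)
# tensor([[-1, -1,  1, -1, -1,  2, -1, -1]])  -> physical block 2 maps to logical 1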


@@ -15,7 +15,10 @@ from tests.v1.attention.utils import (
     create_standard_kv_cache_spec,
     create_vllm_config,
 )
-from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadataBuilder
+from vllm.v1.attention.backends.flex_attention import (
+    FlexAttentionMetadataBuilder,
+    physical_to_logical_mapping,
+)
 
 from ..models.utils import check_embeddings_close, check_logprobs_close
@@ -205,5 +208,31 @@ def test_block_mask_direct_vs_slow_path():
     )
 
+
+def test_physical_to_logical_mapping_handles_reused_blocks():
+    """Regression test: reused physical blocks map to the latest logical block.
+
+    For sliding-window / hybrid attention layers, physical KV-cache blocks can be
+    reused over time. The inverse mapping must therefore select the latest
+    logical block index for a physical block id.
+    """
+    # Padding should not make physical block 0 look live.
+    block_table = torch.tensor([[6, 0, 0, 0]], dtype=torch.int32)
+    seq_lens = torch.tensor([1 * 16], dtype=torch.int32)  # only 1 block valid
+    out = physical_to_logical_mapping(
+        block_table=block_table, seq_lens=seq_lens, block_size=16, total_blocks=10
+    )
+    assert out[0, 0].item() == -1
+    assert out[0, 6].item() == 0
+
+    # If a physical block id appears multiple times (block reuse), mapping should
+    # point to the latest logical block index.
+    block_table2 = torch.tensor([[2, 2, 5]], dtype=torch.int32)
+    seq_lens2 = torch.tensor([3 * 16], dtype=torch.int32)
+    out2 = physical_to_logical_mapping(
+        block_table=block_table2, seq_lens=seq_lens2, block_size=16, total_blocks=8
+    )
+    assert out2[0, 2].item() == 1
+
 
 if __name__ == "__main__":
     pytest.main([__file__])


@@ -160,7 +160,7 @@ def physical_to_logical_mapping(
     If multiple logical blocks map to the same physical block,
-    this function returns the first (minimum) logical block index.
+    this function returns the latest (maximum) logical block index.
 
     If a physical block is not mapped to by any logical block,
     its value in the result will be -1.
@@ -183,6 +183,15 @@ def physical_to_logical_mapping(
     To prevent this, we use seq_lens and block_size to mask out unused
     entries, ensuring only valid block references are processed.
 
+    IMPORTANT: Reused physical blocks (sliding-window / hybrid attention)
+
+    For some attention types, physical cache blocks can be reused over time.
+    This can cause the same physical block id to appear multiple times in a row
+    of `block_table` at different logical block indices. In that case, only the
+    latest logical block index corresponds to the current contents of that
+    physical block. Therefore, the inverse mapping must pick the maximum logical
+    block index for each physical block id.
+
     Args:
         block_table: Tensor of shape [max_reqs, max_num_blocks]
             mapping logical blocks to physical locations. May contain
@@ -217,8 +226,8 @@ def physical_to_logical_mapping(
         mask, torch.arange(max_num_blocks, device=device)[None, :], 0
     )
 
-    physical_to_logical.scatter_(
-        -1, valid_block_table.to(torch.int64), valid_logical_indices
+    physical_to_logical.scatter_reduce_(
+        -1, valid_block_table.to(torch.int64), valid_logical_indices, reduce="amax"
     )
     # NB - Seems like block 0 is always empty so we reset it manually
     physical_to_logical[:, 0] = -1
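
Why scatter_ alone was not enough: with duplicate indices, plain scatter_ leaves it unspecified which of the colliding writes survives, so a reused physical block could end up mapped to a stale logical index. A small standalone comparison (plain PyTorch, not vLLM code; variable names mirror the hunk above) illustrating the difference:

import torch

# The same physical block id (2) appears at logical indices 0 and 1.
valid_block_table = torch.tensor([[2, 2, 5]], dtype=torch.int64)
valid_logical_indices = torch.tensor([[0, 1, 2]])

# Plain scatter_: with duplicate indices, the surviving value is unspecified,
# so physical block 2 may map to logical 0 or 1.
out_scatter = torch.full((1, 8), -1).scatter_(
    -1, valid_block_table, valid_logical_indices
)

# scatter_reduce_ with reduce="amax": deterministically keeps the maximum,
# i.e. the latest logical block index, which is what the fixed code relies on.
out_amax = torch.full((1, 8), -1).scatter_reduce_(
    -1, valid_block_table, valid_logical_indices, reduce="amax"
)

print(out_scatter[0, 2].item())  # 0 or 1, depending on write order
print(out_amax[0, 2].item())     # always 1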