diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 1a8d2420db7a7..e1858149dde9e 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -264,7 +264,7 @@ def make_local_attention_virtual_batches( np.arange(pages_per_local_batch, dtype=np.int32), (virtual_batches, pages_per_local_batch)) \ + np.expand_dims(block_starts, axis=1) - block_indices = block_indices.flatten() + block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), local_blocks * pages_per_local_batch) block_table_local = block_table[batch_indices, block_indices]\