[Attention] Make local attention backend agnostic (#21093)

Lucas Wilkinson 2025-07-18 00:10:42 -04:00 committed by GitHub
parent b9a21e9173
commit 89cab4d01f
8 changed files with 94 additions and 242 deletions
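For orientation, a minimal sketch of the new flow, condensed from the utils.py and GPUModelRunner hunks below (the wrapper name localize_for_group is hypothetical; only the isinstance check and the make_local_attention_virtual_batches call come from this commit): chunked local attention is no longer special-cased inside each attention backend. The model runner rewrites the per-group CommonAttentionMetadata once, and every backend builds its metadata from the already-localized view.

# Sketch only; the body mirrors the new code in the GPUModelRunner hunk
# at the end of this commit.
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata, make_local_attention_virtual_batches)
from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec

def localize_for_group(kv_cache_spec,
                       common_attn_metadata: CommonAttentionMetadata,
                       block_size: int) -> CommonAttentionMetadata:
    # Chunked-local groups: expand each request into per-chunk "virtual
    # batches" here, once, instead of inside every attention backend.
    if isinstance(kv_cache_spec, ChunkedLocalAttentionSpec):
        return make_local_attention_virtual_batches(
            kv_cache_spec.attention_chunk_size, common_attn_metadata,
            block_size)
    # Full-attention (and other) groups pass through untouched.
    return common_attn_metadata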

View File

@@ -25,9 +25,9 @@ if is_flash_attn_varlen_func_available():
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
@@ -130,18 +130,6 @@ class FlashAttentionMetadata:
prefix_scheduler_metadata: Optional[torch.Tensor] = None
max_num_splits: int = 0
# for local attention
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor
local_seqused_k: torch.Tensor
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
local_scheduler_metadata: Optional[torch.Tensor]
local_attn_metadata: Optional[LocalAttentionMetadata] = None
def _get_sliding_window_configs(
vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
@@ -221,7 +209,6 @@ class FlashAttentionMetadataBuilder(
max_query_len = common_attn_metadata.max_query_len
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
@@ -266,40 +253,6 @@ class FlashAttentionMetadataBuilder(
)
return None
# for local attention
local_attn_metadata = None
if self.model_config.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.model_config.attention_chunk_size,
query_start_loc_cpu.numpy(),
seq_lens_cpu.numpy(),
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max()
local_max_seq_len = virt_k_seqlens_np.max()
local_scheduler_metadata = schedule(
batch_size=local_query_start_loc.shape[0] - 1,
cu_query_lens=local_query_start_loc,
max_query_len=local_max_query_len,
seqlens=local_seqused_k,
max_seq_len=local_max_seq_len,
causal=True)
local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_scheduler_metadata=local_scheduler_metadata,
)
use_cascade = common_prefix_len > 0
if use_cascade:
@@ -371,7 +324,6 @@ class FlashAttentionMetadataBuilder(
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
local_attn_metadata=local_attn_metadata,
prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
)
@@ -517,27 +469,13 @@ class FlashAttentionImpl(AttentionImpl):
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# Compute attention and update output up to `num_actual_tokens`.
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
if not attn_metadata.use_cascade or use_local_attn:
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
scheduler_metadata = local_metadata.local_scheduler_metadata
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
@@ -565,8 +503,6 @@ class FlashAttentionImpl(AttentionImpl):
)
return output
assert not use_local_attn, (
"Cascade attention does not support local attention.")
# Cascade attention (rare case).
cascade_attention(
output[:num_actual_tokens],

View File

@@ -496,10 +496,6 @@ class FlashInferImpl(AttentionImpl):
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
) -> None:
if use_irope:
logger.warning_once(
"Using irope in FlashInfer is not supported yet, it will fall"
" back to global attention for long context.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -514,6 +510,7 @@ class FlashInferImpl(AttentionImpl):
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.use_irope = use_irope
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

View File

@@ -13,8 +13,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -201,9 +199,7 @@ class AiterFlashAttentionMetadataBuilder:
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
@@ -215,56 +211,6 @@ class AiterFlashAttentionMetadataBuilder:
dtype=cu_seq_lens.dtype,
out=cu_seq_lens[1:])
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
return None
# for local attention
local_attn_metadata = None
if self.model_config.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.model_config.attention_chunk_size,
query_start_loc_cpu.numpy(),
seq_lens_cpu.numpy(),
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max().item()
local_max_seq_len = virt_k_seqlens_np.max().item()
local_scheduler_metadata = schedule(
batch_size=local_query_start_loc.shape[0] - 1,
cu_query_lens=local_query_start_loc,
max_query_len=local_max_query_len,
seqlens=local_seqused_k,
max_seq_len=local_max_seq_len,
causal=True)
local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
dtype=torch.int32,
device=self.device)
local_cu_seq_lens[1:] = torch.cumsum(
torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
dtype=torch.int32,
non_blocking=True),
dim=0)
local_attn_metadata = \
AiterFlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_cu_seq_lens=local_cu_seq_lens,
local_scheduler_metadata=local_scheduler_metadata,
)
use_cascade = common_prefix_len > 0
cu_prefix_query_lens = None
@@ -286,7 +232,6 @@ class AiterFlashAttentionMetadataBuilder:
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
local_attn_metadata=local_attn_metadata,
)
return attn_metadata
@@ -377,19 +322,6 @@ class AiterFlashAttentionMetadata:
prefix_kv_lens: Optional[torch.Tensor]
suffix_kv_lens: Optional[torch.Tensor]
# for local attention
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor
local_seqused_k: torch.Tensor
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
local_cu_seq_lens: torch.Tensor
local_scheduler_metadata: Optional[torch.Tensor]
local_attn_metadata: Optional[LocalAttentionMetadata] = None
class AiterFlashAttentionImpl(AttentionImpl):
@@ -521,25 +453,12 @@ class AiterFlashAttentionImpl(AttentionImpl):
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# Compute attention and update output up to `num_actual_tokens`.
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
if not attn_metadata.use_cascade or use_local_attn:
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
if max_seqlen_q > 1:
cu_seq_lens = attn_metadata.cu_seq_lens
@@ -557,9 +476,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
cu_seqlens_k=(cu_seq_lens if not use_local_attn else
local_metadata.local_cu_seq_lens),
)
cu_seqlens_k=cu_seq_lens)
_, num_heads, head_size = query.shape
_PARTITION_SIZE_ROCM = 256

View File

@@ -18,9 +18,8 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
@@ -55,18 +54,6 @@ class TritonAttentionMetadata:
scheduler_metadata: Optional[torch.Tensor] = None
prefix_scheduler_metadata: Optional[torch.Tensor] = None
# for local attention
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor
local_seqused_k: torch.Tensor
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
local_scheduler_metadata: Optional[torch.Tensor]
local_attn_metadata: Optional[LocalAttentionMetadata] = None
class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
@@ -111,34 +98,6 @@ class TritonAttentionMetadataBuilder(
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
# for local attention
local_attn_metadata = None
if self.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.attention_chunk_size,
common_attn_metadata.query_start_loc_cpu.numpy(),
common_attn_metadata.seq_lens_cpu.numpy(),
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max().item()
local_max_seq_len = virt_k_seqlens_np.max().item()
local_attn_metadata = TritonAttentionMetadata \
.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_scheduler_metadata=None,
)
use_cascade = common_prefix_len > 0
if use_cascade:
@@ -170,7 +129,6 @@ class TritonAttentionMetadataBuilder(
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
local_attn_metadata=local_attn_metadata,
prefix_scheduler_metadata=prefix_scheduler_metadata,
)
return attn_metadata
@@ -384,23 +342,11 @@ class TritonAttentionImpl(AttentionImpl):
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
if use_prefill_decode_attn:
# Compute attention and update output up to `num_actual_tokens`.

View File

@@ -272,11 +272,14 @@ def infer_global_hyperparameters(
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
attn_chunk_size: int,
query_start_loc_np: np.ndarray,
seq_lens_np: np.ndarray,
block_table: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
block_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
) -> CommonAttentionMetadata:
query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
block_table = common_attn_metadata.block_table_tensor
device = common_attn_metadata.query_start_loc.device
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0]
@@ -339,6 +342,7 @@ def make_local_attention_virtual_batches(
attn_chunk_size,
dtype=np.int32)
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
num_computed_tokens_local = seqlens_k_local - seqlens_q_local
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
(rarange * attn_chunk_size + \
@@ -380,8 +384,22 @@ def make_local_attention_virtual_batches(
block_table_local = block_table[batch_indices, block_indices]\
.view(virtual_batches, -1)
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
block_table_local
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
seq_lens_cpu = torch.from_numpy(seqlens_k_local)
return CommonAttentionMetadata(
query_start_loc_cpu=query_start_loc_cpu,
query_start_loc=query_start_loc_cpu.to(device=device,
non_blocking=True),
seq_lens_cpu=seq_lens_cpu,
seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
num_reqs=len(seq_lens_cpu),
num_actual_tokens=common_attn_metadata.num_actual_tokens,
max_query_len=seqlens_q_local.max(),
block_table_tensor=block_table_local,
slot_mapping=common_attn_metadata.slot_mapping,
)
def split_decodes_and_prefills(
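To make the transformation above concrete, here is a small standalone sketch of the chunking arithmetic performed by make_local_attention_virtual_batches (illustrative only; the real helper also remaps the block table and packs everything into the CommonAttentionMetadata returned above). The example sizes are hypothetical:

import numpy as np

def cdiv(a, b):  # ceiling division, as in vllm.utils.cdiv
    return -(-a // b)

attn_chunk_size = 4
q_seqlens = np.array([4, 10, 5])   # new query tokens per request
k_seqlens = np.array([6, 17, 9])   # total KV length per request

# Tokens of each request that land in its first (possibly partial) chunk.
q_first = np.minimum(
    attn_chunk_size - (k_seqlens - q_seqlens) % attn_chunk_size, q_seqlens)
# Number of chunks, i.e. "virtual batches", each request expands into.
local_blocks = 1 + cdiv(q_seqlens - q_first, attn_chunk_size)

print(q_first)       # [2 1 4]
print(local_blocks)  # [2 4 2] -> 8 virtual batches in total
# Each virtual batch then attends causally within a single chunk, giving
#   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]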

View File

@@ -7,7 +7,8 @@ from typing import Callable
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
from vllm.v1.request import Request
@@ -256,8 +257,10 @@ class FullAttentionManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, FullAttentionSpec), (
"FullAttentionManager can only be used for full attention groups")
assert isinstance(
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
), "FullAttentionManager can only be used for full attention " \
"and chunked local attention groups"
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids)))
max_num_blocks = max_length // kv_cache_spec.block_size
@@ -432,6 +435,7 @@ class MambaManager(SingleTypeKVCacheManager):
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
ChunkedLocalAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
MambaSpec: MambaManager,
}

View File

@@ -125,6 +125,21 @@ class FullAttentionSpec(AttentionSpec):
return merged_spec
@dataclass
class ChunkedLocalAttentionSpec(AttentionSpec):
attention_chunk_size: int
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@property
def type_id(self) -> str:
return (
f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
) # noqa
@dataclass
class SlidingWindowSpec(AttentionSpec):
sliding_window: int
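A short sketch of what the new spec reports, using hypothetical layer sizes; the constructor arguments mirror how the GPUModelRunner hunk at the end of this commit builds it:

import torch
from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec

# Hypothetical geometry for one chunked-local layer.
spec = ChunkedLocalAttentionSpec(
    block_size=16,
    num_kv_heads=8,
    head_size=128,
    dtype=torch.bfloat16,
    attention_chunk_size=8192,
    use_mla=False,
)
# The chunk size is baked into type_id, so chunked-local layers are kept
# distinct from full-attention layers when KV cache groups are formed:
#   "local_attention_8192_16_<page_size_bytes>"
# max_memory_usage_bytes is still bounded by the full context length:
#   cdiv(max_model_len, block_size) * page_size_bytes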

View File

@@ -44,11 +44,14 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
make_local_attention_virtual_batches)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec, MambaSpec,
from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheSpec, MambaSpec,
SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
@@ -705,6 +708,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_decode_common_attn_metadata is None:
spec_decode_common_attn_metadata = common_attn_metadata
if isinstance(kv_cache_group_spec.kv_cache_spec,
ChunkedLocalAttentionSpec):
common_attn_metadata = make_local_attention_virtual_batches(
kv_cache_group_spec.kv_cache_spec.attention_chunk_size,
common_attn_metadata, self.cache_config.block_size)
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
builder = self.attn_metadata_builders[kv_cache_group_id]
@@ -2589,6 +2598,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO: Support other attention modules, e.g., cross-attention
if attn_module.attn_type == AttentionType.DECODER:
use_local_attention = (self.attention_chunk_size is not None
and attn_module.impl.use_irope)
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
@@ -2597,6 +2608,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla)
elif use_local_attention:
kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
attention_chunk_size=self.attention_chunk_size,
use_mla=use_mla))
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,