[Attention] Make local attention backend agnostic (#21093)
This commit is contained in:
parent b9a21e9173
commit 89cab4d01f
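
For orientation before the per-file hunks: this commit moves chunked local ("iRoPE") attention handling out of the individual attention backends and into shared code. make_local_attention_virtual_batches now lives in vllm/v1/attention/backends/utils.py, takes and returns a CommonAttentionMetadata, and is called once by the GPU model runner for KV-cache groups backed by the new ChunkedLocalAttentionSpec. Below is a minimal sketch of the resulting flow; the helper name and the exact build() call shape are illustrative assumptions, while the imported names come from the diff itself.

from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata, make_local_attention_virtual_batches)
from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec


def build_metadata_for_group(kv_cache_spec, builder,
                             common_attn_metadata: CommonAttentionMetadata,
                             block_size: int):
    # Chunked local attention is handled once, in backend-agnostic code: the
    # shared metadata is rewritten into chunk-sized "virtual batches".
    if isinstance(kv_cache_spec, ChunkedLocalAttentionSpec):
        common_attn_metadata = make_local_attention_virtual_batches(
            kv_cache_spec.attention_chunk_size, common_attn_metadata,
            block_size)
    # Every backend builder (FlashAttention, Triton, AITER FA, FlashInfer)
    # then sees ordinary CommonAttentionMetadata and needs no local-attention
    # branches of its own.
    return builder.build(common_prefix_len=0,
                         common_attn_metadata=common_attn_metadata)

Compare this with the per-backend LocalAttentionMetadata plumbing that the hunks below remove from the FlashAttention, Triton, and AITER FlashAttention builders and impls.
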
vllm/v1/attention/backends/flash_attn.py

@@ -25,9 +25,9 @@ if is_flash_attn_varlen_func_available():
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
    make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
                                              get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)

@@ -130,18 +130,6 @@ class FlashAttentionMetadata:
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0

    # for local attention
    @dataclass
    class LocalAttentionMetadata:
        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int
        local_scheduler_metadata: Optional[torch.Tensor]

    local_attn_metadata: Optional[LocalAttentionMetadata] = None


def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:

@@ -221,7 +209,6 @@ class FlashAttentionMetadataBuilder(
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor

@@ -266,40 +253,6 @@ class FlashAttentionMetadataBuilder(
            )
            return None

        # for local attention
        local_attn_metadata = None
        if self.model_config.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
                    self.model_config.attention_chunk_size,
                    query_start_loc_cpu.numpy(),
                    seq_lens_cpu.numpy(),
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
                self.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                self.device, non_blocking=True)
            local_max_query_len = seqlens_q_local_np.max()
            local_max_seq_len = virt_k_seqlens_np.max()
            local_scheduler_metadata = schedule(
                batch_size=local_query_start_loc.shape[0] - 1,
                cu_query_lens=local_query_start_loc,
                max_query_len=local_max_query_len,
                seqlens=local_seqused_k,
                max_seq_len=local_max_seq_len,
                causal=True)

            local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
                local_query_start_loc=local_query_start_loc,
                local_seqused_k=local_seqused_k,
                local_block_table=virt_block_table_tensor,
                local_max_query_len=local_max_query_len,
                local_max_seq_len=local_max_seq_len,
                local_scheduler_metadata=local_scheduler_metadata,
            )

        use_cascade = common_prefix_len > 0

        if use_cascade:

@@ -371,7 +324,6 @@ class FlashAttentionMetadataBuilder(
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            local_attn_metadata=local_attn_metadata,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
        )

@@ -517,27 +469,13 @@ class FlashAttentionImpl(AttentionImpl):
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        # Compute attention and update output up to `num_actual_tokens`.
        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if not attn_metadata.use_cascade or use_local_attn:
            if use_local_attn:
                assert attn_metadata.local_attn_metadata is not None
                local_metadata = attn_metadata.local_attn_metadata
                cu_seqlens_q = local_metadata.local_query_start_loc
                seqused_k = local_metadata.local_seqused_k
                max_seqlen_q = local_metadata.local_max_query_len
                max_seqlen_k = local_metadata.local_max_seq_len
                block_table = local_metadata.local_block_table
                scheduler_metadata = local_metadata.local_scheduler_metadata
            else:
                cu_seqlens_q = attn_metadata.query_start_loc
                seqused_k = attn_metadata.seq_lens
                max_seqlen_q = attn_metadata.max_query_len
                max_seqlen_k = attn_metadata.max_seq_len
                block_table = attn_metadata.block_table
                scheduler_metadata = attn_metadata.scheduler_metadata
        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

@@ -565,8 +503,6 @@ class FlashAttentionImpl(AttentionImpl):
            )
            return output

        assert not use_local_attn, (
            "Cascade attention does not support local attention.")
        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],

vllm/v1/attention/backends/flashinfer.py

@@ -496,10 +496,6 @@ class FlashInferImpl(AttentionImpl):
        kv_sharing_target_layer_name: Optional[int] = None,
        use_irope: bool = False,
    ) -> None:
        if use_irope:
            logger.warning_once(
                "Using irope in FlashInfer is not supported yet, it will fall"
                " back to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)

@@ -514,6 +510,7 @@ class FlashInferImpl(AttentionImpl):
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        self.use_irope = use_irope

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

vllm/v1/attention/backends/rocm_aiter_fa.py

@@ -13,8 +13,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
    make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec

@@ -201,9 +199,7 @@ class AiterFlashAttentionMetadataBuilder:
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

@@ -215,56 +211,6 @@ class AiterFlashAttentionMetadataBuilder:
                         dtype=cu_seq_lens.dtype,
                         out=cu_seq_lens[1:])

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            return None

        # for local attention
        local_attn_metadata = None
        if self.model_config.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
                    self.model_config.attention_chunk_size,
                    query_start_loc_cpu.numpy(),
                    seq_lens_cpu.numpy(),
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
                self.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                self.device, non_blocking=True)
            local_max_query_len = seqlens_q_local_np.max().item()
            local_max_seq_len = virt_k_seqlens_np.max().item()
            local_scheduler_metadata = schedule(
                batch_size=local_query_start_loc.shape[0] - 1,
                cu_query_lens=local_query_start_loc,
                max_query_len=local_max_query_len,
                seqlens=local_seqused_k,
                max_seq_len=local_max_seq_len,
                causal=True)

            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
                                            dtype=torch.int32,
                                            device=self.device)
            local_cu_seq_lens[1:] = torch.cumsum(
                torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
                                                       dtype=torch.int32,
                                                       non_blocking=True),
                dim=0)

            local_attn_metadata = \
                AiterFlashAttentionMetadata.LocalAttentionMetadata(
                    local_query_start_loc=local_query_start_loc,
                    local_seqused_k=local_seqused_k,
                    local_block_table=virt_block_table_tensor,
                    local_max_query_len=local_max_query_len,
                    local_max_seq_len=local_max_seq_len,
                    local_cu_seq_lens=local_cu_seq_lens,
                    local_scheduler_metadata=local_scheduler_metadata,
                )

        use_cascade = common_prefix_len > 0

        cu_prefix_query_lens = None

@@ -286,7 +232,6 @@ class AiterFlashAttentionMetadataBuilder:
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            local_attn_metadata=local_attn_metadata,
        )
        return attn_metadata

@@ -377,19 +322,6 @@ class AiterFlashAttentionMetadata:
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # for local attention
    @dataclass
    class LocalAttentionMetadata:
        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int
        local_cu_seq_lens: torch.Tensor
        local_scheduler_metadata: Optional[torch.Tensor]

    local_attn_metadata: Optional[LocalAttentionMetadata] = None


class AiterFlashAttentionImpl(AttentionImpl):

@@ -521,25 +453,12 @@ class AiterFlashAttentionImpl(AttentionImpl):
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        # Compute attention and update output up to `num_actual_tokens`.
        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if not attn_metadata.use_cascade or use_local_attn:
            if use_local_attn:
                assert attn_metadata.local_attn_metadata is not None
                local_metadata = attn_metadata.local_attn_metadata
                cu_seqlens_q = local_metadata.local_query_start_loc
                seqused_k = local_metadata.local_seqused_k
                max_seqlen_q = local_metadata.local_max_query_len
                max_seqlen_k = local_metadata.local_max_seq_len
                block_table = local_metadata.local_block_table
            else:
                cu_seqlens_q = attn_metadata.query_start_loc
                seqused_k = attn_metadata.seq_lens
                max_seqlen_q = attn_metadata.max_query_len
                max_seqlen_k = attn_metadata.max_seq_len
                block_table = attn_metadata.block_table
        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table

            if max_seqlen_q > 1:
                cu_seq_lens = attn_metadata.cu_seq_lens

@@ -557,9 +476,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    cu_seqlens_k=(cu_seq_lens if not use_local_attn else
                                  local_metadata.local_cu_seq_lens),
                )
                    cu_seqlens_k=cu_seq_lens)

            _, num_heads, head_size = query.shape
            _PARTITION_SIZE_ROCM = 256

vllm/v1/attention/backends/triton_attn.py

@@ -18,9 +18,8 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata,
    make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)

@@ -55,18 +54,6 @@ class TritonAttentionMetadata:
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None

    # for local attention
    @dataclass
    class LocalAttentionMetadata:
        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int
        local_scheduler_metadata: Optional[torch.Tensor]

    local_attn_metadata: Optional[LocalAttentionMetadata] = None


class TritonAttentionMetadataBuilder(
        AttentionMetadataBuilder[TritonAttentionMetadata]):

@@ -111,34 +98,6 @@ class TritonAttentionMetadataBuilder(
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        # for local attention
        local_attn_metadata = None
        if self.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
                    self.attention_chunk_size,
                    common_attn_metadata.query_start_loc_cpu.numpy(),
                    common_attn_metadata.seq_lens_cpu.numpy(),
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
                self.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                self.device, non_blocking=True)
            local_max_query_len = seqlens_q_local_np.max().item()
            local_max_seq_len = virt_k_seqlens_np.max().item()

            local_attn_metadata = TritonAttentionMetadata \
                .LocalAttentionMetadata(
                    local_query_start_loc=local_query_start_loc,
                    local_seqused_k=local_seqused_k,
                    local_block_table=virt_block_table_tensor,
                    local_max_query_len=local_max_query_len,
                    local_max_seq_len=local_max_seq_len,
                    local_scheduler_metadata=None,
                )

        use_cascade = common_prefix_len > 0

        if use_cascade:

@@ -170,7 +129,6 @@ class TritonAttentionMetadataBuilder(
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            local_attn_metadata=local_attn_metadata,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
        )
        return attn_metadata

@@ -384,23 +342,11 @@ class TritonAttentionImpl(AttentionImpl):
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if use_local_attn:
            assert attn_metadata.local_attn_metadata is not None
            local_metadata = attn_metadata.local_attn_metadata
            cu_seqlens_q = local_metadata.local_query_start_loc
            seqused_k = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
            max_seqlen_k = local_metadata.local_max_seq_len
            block_table = local_metadata.local_block_table
        else:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table

        if use_prefill_decode_attn:
            # Compute attention and update output up to `num_actual_tokens`.

vllm/v1/attention/backends/utils.py

@@ -272,11 +272,14 @@ def infer_global_hyperparameters(
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    query_start_loc_np: np.ndarray,
    seq_lens_np: np.ndarray,
    block_table: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
) -> CommonAttentionMetadata:
    query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
    seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
    block_table = common_attn_metadata.block_table_tensor
    device = common_attn_metadata.query_start_loc.device

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

@@ -339,6 +342,7 @@ def make_local_attention_virtual_batches(
                              attn_chunk_size,
                              dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
    num_computed_tokens_local = seqlens_k_local - seqlens_q_local

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
        (rarange * attn_chunk_size + \

@@ -380,8 +384,22 @@ def make_local_attention_virtual_batches(
    block_table_local = block_table[batch_indices, block_indices]\
        .view(virtual_batches, -1)

    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
        block_table_local
    query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
    seq_lens_cpu = torch.from_numpy(seqlens_k_local)

    return CommonAttentionMetadata(
        query_start_loc_cpu=query_start_loc_cpu,
        query_start_loc=query_start_loc_cpu.to(device=device,
                                               non_blocking=True),
        seq_lens_cpu=seq_lens_cpu,
        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
        num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
        num_reqs=len(seq_lens_cpu),
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        max_query_len=seqlens_q_local.max(),
        block_table_tensor=block_table_local,
        slot_mapping=common_attn_metadata.slot_mapping,
    )


def split_decodes_and_prefills(
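
The virtual-batch construction itself is unchanged by this commit; only its home and its interface are. For intuition, the toy NumPy script below (not vLLM code; every number is made up) reproduces the arithmetic for a single request: with chunk size C, a query token at absolute position p may only attend to keys in [C * (p // C), p], so the request is split into one virtual batch per chunk that its new queries touch.

import numpy as np

C = 4                    # attention_chunk_size
seq_len, q_len = 10, 7   # 10 tokens total in the sequence, the last 7 are new queries
positions = np.arange(seq_len - q_len, seq_len)   # absolute positions 3..9
chunk_ids = positions // C                        # which chunk each query falls into

for chunk in np.unique(chunk_ids):
    mask = chunk_ids == chunk
    q_local = int(mask.sum())                     # queries in this virtual batch
    last_pos = int(positions[mask].max())
    k_local = last_pos - chunk * C + 1            # keys visible inside the chunk
    print(f"chunk {chunk}: q_len_local={q_local}, k_len_local={k_local}")
# chunk 0: q_len_local=1, k_len_local=4
# chunk 1: q_len_local=4, k_len_local=4
# chunk 2: q_len_local=2, k_len_local=2

In the new code path these per-chunk query and key lengths become the query_start_loc and seq_lens fields of the returned CommonAttentionMetadata, and the block-table rows are re-gathered per virtual batch, as shown in the hunk above.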

vllm/v1/core/single_type_kv_cache_manager.py

@@ -7,7 +7,8 @@ from typing import Callable
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
                                        FullAttentionSpec, KVCacheSpec,
                                        MambaSpec, SlidingWindowSpec)
from vllm.v1.request import Request

@@ -256,8 +257,10 @@ class FullAttentionManager(SingleTypeKVCacheManager):
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
    ) -> tuple[list[KVCacheBlock], ...]:
        assert isinstance(kv_cache_spec, FullAttentionSpec), (
            "FullAttentionManager can only be used for full attention groups")
        assert isinstance(
            kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
        ), "FullAttentionManager can only be used for full attention " \
            "and chunked local attention groups"
        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in range(len(kv_cache_group_ids)))
        max_num_blocks = max_length // kv_cache_spec.block_size

@@ -432,6 +435,7 @@ class MambaManager(SingleTypeKVCacheManager):

spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
    FullAttentionSpec: FullAttentionManager,
    ChunkedLocalAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
    MambaSpec: MambaManager,
}

vllm/v1/kv_cache_interface.py

@@ -125,6 +125,21 @@ class FullAttentionSpec(AttentionSpec):
        return merged_spec


@dataclass
class ChunkedLocalAttentionSpec(AttentionSpec):
    attention_chunk_size: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @property
    def type_id(self) -> str:
        return (
            f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
        )  # noqa


@dataclass
class SlidingWindowSpec(AttentionSpec):
    sliding_window: int
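
ChunkedLocalAttentionSpec budgets KV-cache pages the same way full attention does, cdiv(max_model_len, block_size) * page_size_bytes, while its distinct type_id keeps chunked-local layers grouped separately from full-attention layers. A toy calculation with assumed values (none of these numbers come from the diff), using the usual non-MLA page layout of K plus V per block:

def cdiv(a: int, b: int) -> int:
    # ceiling division, as used by max_memory_usage_bytes
    return -(-a // b)

max_model_len = 8192   # assumed model context length
block_size = 16        # tokens per KV-cache block
num_kv_heads, head_size, dtype_bytes = 8, 128, 2   # assumed; dtype_bytes=2 for fp16/bf16
page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_bytes  # K and V planes

print(cdiv(max_model_len, block_size) * page_size_bytes)  # 33554432 bytes (~32 MiB) per layer

Note that attention_chunk_size does not shrink this bound; it changes which of those pages a given query reads, not how many are reserved.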

vllm/v1/worker/gpu_model_runner.py

@@ -44,11 +44,14 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                        is_pin_memory_available, round_up)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata,
    make_local_attention_virtual_batches)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                        KVCacheConfig, KVCacheSpec, MambaSpec,
from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        ChunkedLocalAttentionSpec,
                                        FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec, MambaSpec,
                                        SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput)

@@ -705,6 +708,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                    spec_decode_common_attn_metadata is None:
                spec_decode_common_attn_metadata = common_attn_metadata

            if isinstance(kv_cache_group_spec.kv_cache_spec,
                          ChunkedLocalAttentionSpec):
                common_attn_metadata = make_local_attention_virtual_batches(
                    kv_cache_group_spec.kv_cache_spec.attention_chunk_size,
                    common_attn_metadata, self.cache_config.block_size)

            # Prepare for cascade attention if enabled & beneficial.
            common_prefix_len = 0
            builder = self.attn_metadata_builders[kv_cache_group_id]

@@ -2589,6 +2598,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):

            # TODO: Support other attention modules, e.g., cross-attention
            if attn_module.attn_type == AttentionType.DECODER:
                use_local_attention = (self.attention_chunk_size is not None
                                       and attn_module.impl.use_irope)
                if attn_module.sliding_window is not None:
                    kv_cache_spec[layer_name] = SlidingWindowSpec(
                        block_size=block_size,

@@ -2597,6 +2608,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                        dtype=self.kv_cache_dtype,
                        sliding_window=attn_module.sliding_window,
                        use_mla=use_mla)
                elif use_local_attention:
                    kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        attention_chunk_size=self.attention_chunk_size,
                        use_mla=use_mla))
                else:
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,