mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 08:55:45 +08:00
[Attention] Make local attention backend agnostic (#21093)
This commit is contained in:
parent
b9a21e9173
commit
89cab4d01f
@ -25,9 +25,9 @@ if is_flash_attn_varlen_func_available():
|
|||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
|
CommonAttentionMetadata,
|
||||||
make_local_attention_virtual_batches)
|
get_kv_cache_layout)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -130,18 +130,6 @@ class FlashAttentionMetadata:
|
|||||||
prefix_scheduler_metadata: Optional[torch.Tensor] = None
|
prefix_scheduler_metadata: Optional[torch.Tensor] = None
|
||||||
max_num_splits: int = 0
|
max_num_splits: int = 0
|
||||||
|
|
||||||
# for local attention
|
|
||||||
@dataclass
|
|
||||||
class LocalAttentionMetadata:
|
|
||||||
local_query_start_loc: torch.Tensor
|
|
||||||
local_seqused_k: torch.Tensor
|
|
||||||
local_block_table: torch.Tensor
|
|
||||||
local_max_query_len: int
|
|
||||||
local_max_seq_len: int
|
|
||||||
local_scheduler_metadata: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_sliding_window_configs(
|
def _get_sliding_window_configs(
|
||||||
vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
|
vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
|
||||||
@ -221,7 +209,6 @@ class FlashAttentionMetadataBuilder(
|
|||||||
max_query_len = common_attn_metadata.max_query_len
|
max_query_len = common_attn_metadata.max_query_len
|
||||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
seq_lens = common_attn_metadata.seq_lens
|
||||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||||
@ -266,40 +253,6 @@ class FlashAttentionMetadataBuilder(
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# for local attention
|
|
||||||
local_attn_metadata = None
|
|
||||||
if self.model_config.attention_chunk_size is not None:
|
|
||||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
|
||||||
virt_block_table_tensor = make_local_attention_virtual_batches(
|
|
||||||
self.model_config.attention_chunk_size,
|
|
||||||
query_start_loc_cpu.numpy(),
|
|
||||||
seq_lens_cpu.numpy(),
|
|
||||||
block_table_tensor,
|
|
||||||
self.block_size,
|
|
||||||
)
|
|
||||||
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
|
|
||||||
self.device, non_blocking=True)
|
|
||||||
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
|
|
||||||
self.device, non_blocking=True)
|
|
||||||
local_max_query_len = seqlens_q_local_np.max()
|
|
||||||
local_max_seq_len = virt_k_seqlens_np.max()
|
|
||||||
local_scheduler_metadata = schedule(
|
|
||||||
batch_size=local_query_start_loc.shape[0] - 1,
|
|
||||||
cu_query_lens=local_query_start_loc,
|
|
||||||
max_query_len=local_max_query_len,
|
|
||||||
seqlens=local_seqused_k,
|
|
||||||
max_seq_len=local_max_seq_len,
|
|
||||||
causal=True)
|
|
||||||
|
|
||||||
local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
|
||||||
local_query_start_loc=local_query_start_loc,
|
|
||||||
local_seqused_k=local_seqused_k,
|
|
||||||
local_block_table=virt_block_table_tensor,
|
|
||||||
local_max_query_len=local_max_query_len,
|
|
||||||
local_max_seq_len=local_max_seq_len,
|
|
||||||
local_scheduler_metadata=local_scheduler_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
use_cascade = common_prefix_len > 0
|
use_cascade = common_prefix_len > 0
|
||||||
|
|
||||||
if use_cascade:
|
if use_cascade:
|
||||||
@ -371,7 +324,6 @@ class FlashAttentionMetadataBuilder(
|
|||||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||||
prefix_kv_lens=prefix_kv_lens,
|
prefix_kv_lens=prefix_kv_lens,
|
||||||
suffix_kv_lens=suffix_kv_lens,
|
suffix_kv_lens=suffix_kv_lens,
|
||||||
local_attn_metadata=local_attn_metadata,
|
|
||||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||||
max_num_splits=max_num_splits,
|
max_num_splits=max_num_splits,
|
||||||
)
|
)
|
||||||
@ -517,27 +469,13 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
layer._q_scale)
|
layer._q_scale)
|
||||||
query = query.reshape((num_tokens, num_heads, head_size))
|
query = query.reshape((num_tokens, num_heads, head_size))
|
||||||
|
|
||||||
# Compute attention and update output up to `num_actual_tokens`.
|
if not attn_metadata.use_cascade:
|
||||||
use_local_attn = \
|
cu_seqlens_q = attn_metadata.query_start_loc
|
||||||
(self.use_irope and attn_metadata.local_attn_metadata is not None)
|
seqused_k = attn_metadata.seq_lens
|
||||||
|
max_seqlen_q = attn_metadata.max_query_len
|
||||||
if not attn_metadata.use_cascade or use_local_attn:
|
max_seqlen_k = attn_metadata.max_seq_len
|
||||||
if use_local_attn:
|
block_table = attn_metadata.block_table
|
||||||
assert attn_metadata.local_attn_metadata is not None
|
scheduler_metadata = attn_metadata.scheduler_metadata
|
||||||
local_metadata = attn_metadata.local_attn_metadata
|
|
||||||
cu_seqlens_q = local_metadata.local_query_start_loc
|
|
||||||
seqused_k = local_metadata.local_seqused_k
|
|
||||||
max_seqlen_q = local_metadata.local_max_query_len
|
|
||||||
max_seqlen_k = local_metadata.local_max_seq_len
|
|
||||||
block_table = local_metadata.local_block_table
|
|
||||||
scheduler_metadata = local_metadata.local_scheduler_metadata
|
|
||||||
else:
|
|
||||||
cu_seqlens_q = attn_metadata.query_start_loc
|
|
||||||
seqused_k = attn_metadata.seq_lens
|
|
||||||
max_seqlen_q = attn_metadata.max_query_len
|
|
||||||
max_seqlen_k = attn_metadata.max_seq_len
|
|
||||||
block_table = attn_metadata.block_table
|
|
||||||
scheduler_metadata = attn_metadata.scheduler_metadata
|
|
||||||
|
|
||||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||||
|
|
||||||
@ -565,8 +503,6 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
assert not use_local_attn, (
|
|
||||||
"Cascade attention does not support local attention.")
|
|
||||||
# Cascade attention (rare case).
|
# Cascade attention (rare case).
|
||||||
cascade_attention(
|
cascade_attention(
|
||||||
output[:num_actual_tokens],
|
output[:num_actual_tokens],
|
||||||
|
|||||||
@ -496,10 +496,6 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
kv_sharing_target_layer_name: Optional[int] = None,
|
kv_sharing_target_layer_name: Optional[int] = None,
|
||||||
use_irope: bool = False,
|
use_irope: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if use_irope:
|
|
||||||
logger.warning_once(
|
|
||||||
"Using irope in FlashInfer is not supported yet, it will fall"
|
|
||||||
" back to global attention for long context.")
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
@ -514,6 +510,7 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
self.logits_soft_cap = logits_soft_cap
|
self.logits_soft_cap = logits_soft_cap
|
||||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||||
|
self.use_irope = use_irope
|
||||||
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
|
|||||||
@ -13,8 +13,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backends.flash_attn import (
|
|
||||||
make_local_attention_virtual_batches)
|
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
@ -201,9 +199,7 @@ class AiterFlashAttentionMetadataBuilder:
|
|||||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||||
total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
|
total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
seq_lens = common_attn_metadata.seq_lens
|
||||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
|
||||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||||
slot_mapping = common_attn_metadata.slot_mapping
|
slot_mapping = common_attn_metadata.slot_mapping
|
||||||
|
|
||||||
@ -215,56 +211,6 @@ class AiterFlashAttentionMetadataBuilder:
|
|||||||
dtype=cu_seq_lens.dtype,
|
dtype=cu_seq_lens.dtype,
|
||||||
out=cu_seq_lens[1:])
|
out=cu_seq_lens[1:])
|
||||||
|
|
||||||
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
|
|
||||||
max_seq_len, causal):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# for local attention
|
|
||||||
local_attn_metadata = None
|
|
||||||
if self.model_config.attention_chunk_size is not None:
|
|
||||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
|
||||||
virt_block_table_tensor = make_local_attention_virtual_batches(
|
|
||||||
self.model_config.attention_chunk_size,
|
|
||||||
query_start_loc_cpu.numpy(),
|
|
||||||
seq_lens_cpu.numpy(),
|
|
||||||
block_table_tensor,
|
|
||||||
self.block_size,
|
|
||||||
)
|
|
||||||
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
|
|
||||||
self.device, non_blocking=True)
|
|
||||||
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
|
|
||||||
self.device, non_blocking=True)
|
|
||||||
local_max_query_len = seqlens_q_local_np.max().item()
|
|
||||||
local_max_seq_len = virt_k_seqlens_np.max().item()
|
|
||||||
local_scheduler_metadata = schedule(
|
|
||||||
batch_size=local_query_start_loc.shape[0] - 1,
|
|
||||||
cu_query_lens=local_query_start_loc,
|
|
||||||
max_query_len=local_max_query_len,
|
|
||||||
seqlens=local_seqused_k,
|
|
||||||
max_seq_len=local_max_seq_len,
|
|
||||||
causal=True)
|
|
||||||
|
|
||||||
local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.device)
|
|
||||||
local_cu_seq_lens[1:] = torch.cumsum(
|
|
||||||
torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
|
|
||||||
dtype=torch.int32,
|
|
||||||
non_blocking=True),
|
|
||||||
dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
local_attn_metadata = \
|
|
||||||
AiterFlashAttentionMetadata.LocalAttentionMetadata(
|
|
||||||
local_query_start_loc=local_query_start_loc,
|
|
||||||
local_seqused_k=local_seqused_k,
|
|
||||||
local_block_table=virt_block_table_tensor,
|
|
||||||
local_max_query_len=local_max_query_len,
|
|
||||||
local_max_seq_len=local_max_seq_len,
|
|
||||||
local_cu_seq_lens=local_cu_seq_lens,
|
|
||||||
local_scheduler_metadata=local_scheduler_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
use_cascade = common_prefix_len > 0
|
use_cascade = common_prefix_len > 0
|
||||||
|
|
||||||
cu_prefix_query_lens = None
|
cu_prefix_query_lens = None
|
||||||
@ -286,7 +232,6 @@ class AiterFlashAttentionMetadataBuilder:
|
|||||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||||
prefix_kv_lens=prefix_kv_lens,
|
prefix_kv_lens=prefix_kv_lens,
|
||||||
suffix_kv_lens=suffix_kv_lens,
|
suffix_kv_lens=suffix_kv_lens,
|
||||||
local_attn_metadata=local_attn_metadata,
|
|
||||||
)
|
)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
@ -377,19 +322,6 @@ class AiterFlashAttentionMetadata:
|
|||||||
prefix_kv_lens: Optional[torch.Tensor]
|
prefix_kv_lens: Optional[torch.Tensor]
|
||||||
suffix_kv_lens: Optional[torch.Tensor]
|
suffix_kv_lens: Optional[torch.Tensor]
|
||||||
|
|
||||||
# for local attention
|
|
||||||
@dataclass
|
|
||||||
class LocalAttentionMetadata:
|
|
||||||
local_query_start_loc: torch.Tensor
|
|
||||||
local_seqused_k: torch.Tensor
|
|
||||||
local_block_table: torch.Tensor
|
|
||||||
local_max_query_len: int
|
|
||||||
local_max_seq_len: int
|
|
||||||
local_cu_seq_lens: torch.Tensor
|
|
||||||
local_scheduler_metadata: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
|
||||||
|
|
||||||
|
|
||||||
class AiterFlashAttentionImpl(AttentionImpl):
|
class AiterFlashAttentionImpl(AttentionImpl):
|
||||||
|
|
||||||
@ -521,25 +453,12 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
layer._q_scale)
|
layer._q_scale)
|
||||||
query = query.reshape((num_tokens, num_heads, head_size))
|
query = query.reshape((num_tokens, num_heads, head_size))
|
||||||
|
|
||||||
# Compute attention and update output up to `num_actual_tokens`.
|
if not attn_metadata.use_cascade:
|
||||||
use_local_attn = \
|
cu_seqlens_q = attn_metadata.query_start_loc
|
||||||
(self.use_irope and attn_metadata.local_attn_metadata is not None)
|
seqused_k = attn_metadata.seq_lens
|
||||||
|
max_seqlen_q = attn_metadata.max_query_len
|
||||||
if not attn_metadata.use_cascade or use_local_attn:
|
max_seqlen_k = attn_metadata.max_seq_len
|
||||||
if use_local_attn:
|
block_table = attn_metadata.block_table
|
||||||
assert attn_metadata.local_attn_metadata is not None
|
|
||||||
local_metadata = attn_metadata.local_attn_metadata
|
|
||||||
cu_seqlens_q = local_metadata.local_query_start_loc
|
|
||||||
seqused_k = local_metadata.local_seqused_k
|
|
||||||
max_seqlen_q = local_metadata.local_max_query_len
|
|
||||||
max_seqlen_k = local_metadata.local_max_seq_len
|
|
||||||
block_table = local_metadata.local_block_table
|
|
||||||
else:
|
|
||||||
cu_seqlens_q = attn_metadata.query_start_loc
|
|
||||||
seqused_k = attn_metadata.seq_lens
|
|
||||||
max_seqlen_q = attn_metadata.max_query_len
|
|
||||||
max_seqlen_k = attn_metadata.max_seq_len
|
|
||||||
block_table = attn_metadata.block_table
|
|
||||||
|
|
||||||
if max_seqlen_q > 1:
|
if max_seqlen_q > 1:
|
||||||
cu_seq_lens = attn_metadata.cu_seq_lens
|
cu_seq_lens = attn_metadata.cu_seq_lens
|
||||||
@ -557,9 +476,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
alibi_slopes=self.alibi_slopes,
|
alibi_slopes=self.alibi_slopes,
|
||||||
window_size=self.sliding_window,
|
window_size=self.sliding_window,
|
||||||
block_table=block_table,
|
block_table=block_table,
|
||||||
cu_seqlens_k=(cu_seq_lens if not use_local_attn else
|
cu_seqlens_k=cu_seq_lens)
|
||||||
local_metadata.local_cu_seq_lens),
|
|
||||||
)
|
|
||||||
|
|
||||||
_, num_heads, head_size = query.shape
|
_, num_heads, head_size = query.shape
|
||||||
_PARTITION_SIZE_ROCM = 256
|
_PARTITION_SIZE_ROCM = 256
|
||||||
|
|||||||
@ -18,9 +18,8 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
CommonAttentionMetadata)
|
||||||
make_local_attention_virtual_batches)
|
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -55,18 +54,6 @@ class TritonAttentionMetadata:
|
|||||||
scheduler_metadata: Optional[torch.Tensor] = None
|
scheduler_metadata: Optional[torch.Tensor] = None
|
||||||
prefix_scheduler_metadata: Optional[torch.Tensor] = None
|
prefix_scheduler_metadata: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# for local attention
|
|
||||||
@dataclass
|
|
||||||
class LocalAttentionMetadata:
|
|
||||||
local_query_start_loc: torch.Tensor
|
|
||||||
local_seqused_k: torch.Tensor
|
|
||||||
local_block_table: torch.Tensor
|
|
||||||
local_max_query_len: int
|
|
||||||
local_max_seq_len: int
|
|
||||||
local_scheduler_metadata: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
|
||||||
|
|
||||||
|
|
||||||
class TritonAttentionMetadataBuilder(
|
class TritonAttentionMetadataBuilder(
|
||||||
AttentionMetadataBuilder[TritonAttentionMetadata]):
|
AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||||
@ -111,34 +98,6 @@ class TritonAttentionMetadataBuilder(
|
|||||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||||
slot_mapping = common_attn_metadata.slot_mapping
|
slot_mapping = common_attn_metadata.slot_mapping
|
||||||
|
|
||||||
# for local attention
|
|
||||||
local_attn_metadata = None
|
|
||||||
if self.attention_chunk_size is not None:
|
|
||||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
|
||||||
virt_block_table_tensor = make_local_attention_virtual_batches(
|
|
||||||
self.attention_chunk_size,
|
|
||||||
common_attn_metadata.query_start_loc_cpu.numpy(),
|
|
||||||
common_attn_metadata.seq_lens_cpu.numpy(),
|
|
||||||
block_table_tensor,
|
|
||||||
self.block_size,
|
|
||||||
)
|
|
||||||
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
|
|
||||||
self.device, non_blocking=True)
|
|
||||||
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
|
|
||||||
self.device, non_blocking=True)
|
|
||||||
local_max_query_len = seqlens_q_local_np.max().item()
|
|
||||||
local_max_seq_len = virt_k_seqlens_np.max().item()
|
|
||||||
|
|
||||||
local_attn_metadata = TritonAttentionMetadata \
|
|
||||||
.LocalAttentionMetadata(
|
|
||||||
local_query_start_loc=local_query_start_loc,
|
|
||||||
local_seqused_k=local_seqused_k,
|
|
||||||
local_block_table=virt_block_table_tensor,
|
|
||||||
local_max_query_len=local_max_query_len,
|
|
||||||
local_max_seq_len=local_max_seq_len,
|
|
||||||
local_scheduler_metadata=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
use_cascade = common_prefix_len > 0
|
use_cascade = common_prefix_len > 0
|
||||||
|
|
||||||
if use_cascade:
|
if use_cascade:
|
||||||
@ -170,7 +129,6 @@ class TritonAttentionMetadataBuilder(
|
|||||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||||
prefix_kv_lens=prefix_kv_lens,
|
prefix_kv_lens=prefix_kv_lens,
|
||||||
suffix_kv_lens=suffix_kv_lens,
|
suffix_kv_lens=suffix_kv_lens,
|
||||||
local_attn_metadata=local_attn_metadata,
|
|
||||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||||
)
|
)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
@ -384,23 +342,11 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
layer._q_scale)
|
layer._q_scale)
|
||||||
query = query.reshape((num_tokens, num_heads, head_size))
|
query = query.reshape((num_tokens, num_heads, head_size))
|
||||||
|
|
||||||
use_local_attn = \
|
cu_seqlens_q = attn_metadata.query_start_loc
|
||||||
(self.use_irope and attn_metadata.local_attn_metadata is not None)
|
seqused_k = attn_metadata.seq_lens
|
||||||
|
max_seqlen_q = attn_metadata.max_query_len
|
||||||
if use_local_attn:
|
max_seqlen_k = attn_metadata.max_seq_len
|
||||||
assert attn_metadata.local_attn_metadata is not None
|
block_table = attn_metadata.block_table
|
||||||
local_metadata = attn_metadata.local_attn_metadata
|
|
||||||
cu_seqlens_q = local_metadata.local_query_start_loc
|
|
||||||
seqused_k = local_metadata.local_seqused_k
|
|
||||||
max_seqlen_q = local_metadata.local_max_query_len
|
|
||||||
max_seqlen_k = local_metadata.local_max_seq_len
|
|
||||||
block_table = local_metadata.local_block_table
|
|
||||||
else:
|
|
||||||
cu_seqlens_q = attn_metadata.query_start_loc
|
|
||||||
seqused_k = attn_metadata.seq_lens
|
|
||||||
max_seqlen_q = attn_metadata.max_query_len
|
|
||||||
max_seqlen_k = attn_metadata.max_seq_len
|
|
||||||
block_table = attn_metadata.block_table
|
|
||||||
|
|
||||||
if use_prefill_decode_attn:
|
if use_prefill_decode_attn:
|
||||||
# Compute attention and update output up to `num_actual_tokens`.
|
# Compute attention and update output up to `num_actual_tokens`.
|
||||||
|
|||||||
@ -272,11 +272,14 @@ def infer_global_hyperparameters(
|
|||||||
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
||||||
def make_local_attention_virtual_batches(
|
def make_local_attention_virtual_batches(
|
||||||
attn_chunk_size: int,
|
attn_chunk_size: int,
|
||||||
query_start_loc_np: np.ndarray,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
seq_lens_np: np.ndarray,
|
|
||||||
block_table: torch.Tensor,
|
|
||||||
block_size: int = 0,
|
block_size: int = 0,
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
|
) -> CommonAttentionMetadata:
|
||||||
|
query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
|
||||||
|
seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
|
||||||
|
block_table = common_attn_metadata.block_table_tensor
|
||||||
|
device = common_attn_metadata.query_start_loc.device
|
||||||
|
|
||||||
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
||||||
actual_batch_size = seq_lens_np.shape[0]
|
actual_batch_size = seq_lens_np.shape[0]
|
||||||
|
|
||||||
@ -339,6 +342,7 @@ def make_local_attention_virtual_batches(
|
|||||||
attn_chunk_size,
|
attn_chunk_size,
|
||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
|
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
|
||||||
|
num_computed_tokens_local = seqlens_k_local - seqlens_q_local
|
||||||
|
|
||||||
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
|
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
|
||||||
(rarange * attn_chunk_size + \
|
(rarange * attn_chunk_size + \
|
||||||
@ -380,8 +384,22 @@ def make_local_attention_virtual_batches(
|
|||||||
block_table_local = block_table[batch_indices, block_indices]\
|
block_table_local = block_table[batch_indices, block_indices]\
|
||||||
.view(virtual_batches, -1)
|
.view(virtual_batches, -1)
|
||||||
|
|
||||||
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
|
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
|
||||||
block_table_local
|
seq_lens_cpu = torch.from_numpy(seqlens_k_local)
|
||||||
|
|
||||||
|
return CommonAttentionMetadata(
|
||||||
|
query_start_loc_cpu=query_start_loc_cpu,
|
||||||
|
query_start_loc=query_start_loc_cpu.to(device=device,
|
||||||
|
non_blocking=True),
|
||||||
|
seq_lens_cpu=seq_lens_cpu,
|
||||||
|
seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
|
||||||
|
num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
|
||||||
|
num_reqs=len(seq_lens_cpu),
|
||||||
|
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||||
|
max_query_len=seqlens_q_local.max(),
|
||||||
|
block_table_tensor=block_table_local,
|
||||||
|
slot_mapping=common_attn_metadata.slot_mapping,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def split_decodes_and_prefills(
|
def split_decodes_and_prefills(
|
||||||
|
|||||||
@ -7,7 +7,8 @@ from typing import Callable
|
|||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||||
|
FullAttentionSpec, KVCacheSpec,
|
||||||
MambaSpec, SlidingWindowSpec)
|
MambaSpec, SlidingWindowSpec)
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
@ -256,8 +257,10 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
|||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
assert isinstance(kv_cache_spec, FullAttentionSpec), (
|
assert isinstance(
|
||||||
"FullAttentionManager can only be used for full attention groups")
|
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
|
||||||
|
), "FullAttentionManager can only be used for full attention " \
|
||||||
|
"and chunked local attention groups"
|
||||||
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
||||||
[] for _ in range(len(kv_cache_group_ids)))
|
[] for _ in range(len(kv_cache_group_ids)))
|
||||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||||
@ -432,6 +435,7 @@ class MambaManager(SingleTypeKVCacheManager):
|
|||||||
|
|
||||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||||
FullAttentionSpec: FullAttentionManager,
|
FullAttentionSpec: FullAttentionManager,
|
||||||
|
ChunkedLocalAttentionSpec: FullAttentionManager,
|
||||||
SlidingWindowSpec: SlidingWindowManager,
|
SlidingWindowSpec: SlidingWindowManager,
|
||||||
MambaSpec: MambaManager,
|
MambaSpec: MambaManager,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -125,6 +125,21 @@ class FullAttentionSpec(AttentionSpec):
|
|||||||
return merged_spec
|
return merged_spec
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChunkedLocalAttentionSpec(AttentionSpec):
|
||||||
|
attention_chunk_size: int
|
||||||
|
|
||||||
|
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||||
|
max_model_len = vllm_config.model_config.max_model_len
|
||||||
|
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type_id(self) -> str:
|
||||||
|
return (
|
||||||
|
f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
|
||||||
|
) # noqa
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SlidingWindowSpec(AttentionSpec):
|
class SlidingWindowSpec(AttentionSpec):
|
||||||
sliding_window: int
|
sliding_window: int
|
||||||
|
|||||||
@ -44,11 +44,14 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
|||||||
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
|
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
|
||||||
is_pin_memory_available, round_up)
|
is_pin_memory_available, round_up)
|
||||||
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
||||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
from vllm.v1.attention.backends.utils import (
|
||||||
CommonAttentionMetadata)
|
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||||
|
make_local_attention_virtual_batches)
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
|
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||||
KVCacheConfig, KVCacheSpec, MambaSpec,
|
ChunkedLocalAttentionSpec,
|
||||||
|
FullAttentionSpec, KVCacheConfig,
|
||||||
|
KVCacheSpec, MambaSpec,
|
||||||
SlidingWindowSpec)
|
SlidingWindowSpec)
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||||
ModelRunnerOutput)
|
ModelRunnerOutput)
|
||||||
@ -705,6 +708,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
spec_decode_common_attn_metadata is None:
|
spec_decode_common_attn_metadata is None:
|
||||||
spec_decode_common_attn_metadata = common_attn_metadata
|
spec_decode_common_attn_metadata = common_attn_metadata
|
||||||
|
|
||||||
|
if isinstance(kv_cache_group_spec.kv_cache_spec,
|
||||||
|
ChunkedLocalAttentionSpec):
|
||||||
|
common_attn_metadata = make_local_attention_virtual_batches(
|
||||||
|
kv_cache_group_spec.kv_cache_spec.attention_chunk_size,
|
||||||
|
common_attn_metadata, self.cache_config.block_size)
|
||||||
|
|
||||||
# Prepare for cascade attention if enabled & beneficial.
|
# Prepare for cascade attention if enabled & beneficial.
|
||||||
common_prefix_len = 0
|
common_prefix_len = 0
|
||||||
builder = self.attn_metadata_builders[kv_cache_group_id]
|
builder = self.attn_metadata_builders[kv_cache_group_id]
|
||||||
@ -2589,6 +2598,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# TODO: Support other attention modules, e.g., cross-attention
|
# TODO: Support other attention modules, e.g., cross-attention
|
||||||
if attn_module.attn_type == AttentionType.DECODER:
|
if attn_module.attn_type == AttentionType.DECODER:
|
||||||
|
use_local_attention = (self.attention_chunk_size is not None
|
||||||
|
and attn_module.impl.use_irope)
|
||||||
if attn_module.sliding_window is not None:
|
if attn_module.sliding_window is not None:
|
||||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
@ -2597,6 +2608,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
sliding_window=attn_module.sliding_window,
|
sliding_window=attn_module.sliding_window,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
|
elif use_local_attention:
|
||||||
|
kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
|
head_size=attn_module.head_size,
|
||||||
|
dtype=self.kv_cache_dtype,
|
||||||
|
attention_chunk_size=self.attention_chunk_size,
|
||||||
|
use_mla=use_mla))
|
||||||
else:
|
else:
|
||||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user