[Bugfix][V1][ROCm] Fix AITER Flash Attention Backend (Fix API Break and Local Attention Logic: affecting Llama4) (#19904)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>

commit 27c065df50 (parent 84c260caeb)
@@ -306,6 +306,10 @@ class MultiHeadAttention(nn.Module):
                                         block_size=16,
                                         is_attention_free=False)
         backend = backend_name_to_enum(attn_backend.get_name())
-        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            backend = _Backend.XFORMERS
+        if current_platform.is_rocm():
+            # currently, only torch_sdpa is supported on rocm
+            self.attn_backend = _Backend.TORCH_SDPA
+        else:
+            if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+                backend = _Backend.XFORMERS

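For reference, the _Backend.TORCH_SDPA path chosen on ROCm is backed by PyTorch's built-in scaled dot-product attention, which is available in ROCm builds of PyTorch. A minimal, standalone sketch of that primitive (illustrative only, not part of this commit; the shapes are made up):

import torch
import torch.nn.functional as F

# Hypothetical shapes: batch=2, heads=8, seq_len=16, head_dim=64.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# scaled_dot_product_attention is the primitive the TORCH_SDPA backend
# relies on, which is why it is a safe fallback on ROCm.
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 16, 64])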
@@ -243,8 +243,8 @@ class AiterFlashAttentionMetadataBuilder:
                 self.runner.device, non_blocking=True)
             local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                 self.runner.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max()
-            local_max_seq_len = virt_k_seqlens_np.max()
+            local_max_query_len = int(seqlens_q_local_np.max())
+            local_max_seq_len = int(virt_k_seqlens_np.max())
             local_scheduler_metadata = schedule(
                 batch_size=local_query_start_loc.shape[0] - 1,
                 cu_query_lens=local_query_start_loc,
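Note on the int(...) casts above: NumPy's .max() returns a NumPy scalar (e.g. numpy.int32 or numpy.int64) rather than a plain Python int, which can break downstream code that expects a real int. A minimal illustration of the difference (not vLLM code):

import numpy as np

seqlens = np.array([3, 7, 5], dtype=np.int32)

raw_max = seqlens.max()        # NumPy scalar, not a Python int
cast_max = int(seqlens.max())  # plain Python int

print(type(raw_max))              # <class 'numpy.int32'>
print(type(cast_max))             # <class 'int'>
print(isinstance(raw_max, int))   # False
print(isinstance(cast_max, int))  # True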
@@ -253,6 +253,17 @@ class AiterFlashAttentionMetadataBuilder:
                 max_seq_len=local_max_seq_len,
                 causal=True)

+            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
+                                            dtype=torch.int32,
+                                            device=self.runner.device)
+            local_cu_seq_lens[1:] = torch.cumsum(
+                torch.from_numpy(virt_k_seqlens_np).to(
+                    device=self.runner.device,
+                    dtype=torch.int32,
+                    non_blocking=True),
+                dim=0)
+
+
             local_attn_metadata = \
                 AiterFlashAttentionMetadata.LocalAttentionMetadata(
                 local_query_start_loc=local_query_start_loc,
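The new local_cu_seq_lens tensor follows the usual "cumulative sequence lengths" layout used by varlen attention kernels: a leading zero followed by the running sum of the per-chunk key lengths, so entries i and i+1 bracket chunk i. A small standalone sketch of the same construction with made-up lengths (illustrative only, device handling omitted):

import numpy as np
import torch

# Hypothetical per-chunk key sequence lengths for local attention.
virt_k_seqlens_np = np.array([4, 2, 3], dtype=np.int32)

local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
                                dtype=torch.int32)
local_cu_seq_lens[1:] = torch.cumsum(
    torch.from_numpy(virt_k_seqlens_np).to(dtype=torch.int32), dim=0)

print(local_cu_seq_lens)  # [0, 4, 6, 9]
# The keys of chunk i live in the half-open range
# [local_cu_seq_lens[i], local_cu_seq_lens[i + 1]).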
@@ -260,6 +271,7 @@ class AiterFlashAttentionMetadataBuilder:
                 local_block_table=virt_block_table_tensor,
                 local_max_query_len=local_max_query_len,
                 local_max_seq_len=local_max_seq_len,
+                local_cu_seq_lens=local_cu_seq_lens,
                 local_scheduler_metadata=local_scheduler_metadata,
             )

@@ -368,6 +380,7 @@ class AiterFlashAttentionMetadata:
         local_block_table: torch.Tensor
         local_max_query_len: int
         local_max_seq_len: int
+        local_cu_seq_lens: torch.Tensor
         local_scheduler_metadata: Optional[torch.Tensor]

     local_attn_metadata: Optional[LocalAttentionMetadata] = None
@@ -387,6 +400,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
         blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[int] = None,
         use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
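The new kv_sharing_target_layer_name constructor parameter is the API-break part of the fix: newer engine code appears to pass this keyword when constructing attention implementations, so a backend whose __init__ does not accept it fails with a TypeError at startup. A simplified, self-contained illustration of that failure mode (not vLLM code; class names are hypothetical):

# Old constructor signature vs. the updated one.
class OldImpl:
    def __init__(self, num_heads: int, head_size: int) -> None:
        self.num_heads = num_heads
        self.head_size = head_size

class NewImpl:
    def __init__(self, num_heads: int, head_size: int,
                 kv_sharing_target_layer_name=None) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

# The caller now supplies the extra keyword argument.
kwargs = dict(num_heads=8, head_size=64, kv_sharing_target_layer_name=None)

try:
    OldImpl(**kwargs)
except TypeError as exc:
    print("old signature rejects the new kwarg:", exc)

print(NewImpl(**kwargs).kv_sharing_target_layer_name)  # None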
@@ -408,6 +422,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0.
         self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -478,12 +493,15 @@ class AiterFlashAttentionImpl(AttentionImpl):
         # performance to make sure it does not introduce any overhead.

         num_actual_tokens = attn_metadata.num_actual_tokens
-        # Reshape the input keys and values and store them in the cache.
-        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-        # not padded. However, we don't need to do key[:num_actual_tokens] and
-        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
-        # the slot_mapping's shape to determine the number of actual tokens.
         key_cache, value_cache = kv_cache.unbind(0)
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+            # not padded. However, we don't need to do key[:num_actual_tokens]
+            # and value[:num_actual_tokens] because the reshape_and_cache_flash
+            # op uses the slot_mapping's shape to determine the number of
+            # actual tokens.
             torch.ops._C_cache_ops.reshape_and_cache_flash(
                 key,
                 value,
@@ -541,7 +559,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=cu_seq_lens,
+                cu_seqlens_k=(cu_seq_lens if not use_local_attn else
+                              local_metadata.local_cu_seq_lens),
             )

             _, num_heads, head_size = query.shape
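This last hunk is the local-attention part of the fix for Llama 4's local-attention (iRoPE) layers: when use_local_attn is set, the kernel has to be driven by the per-chunk cumulative key lengths built earlier rather than the global cu_seq_lens, so the key boundaries match the chunked layout of the queries. A standalone sketch of the same selection with made-up values (illustrative only):

import torch

# Hypothetical cumulative key lengths: global per-request boundaries vs.
# per-chunk boundaries after splitting the sequences for local attention.
cu_seq_lens = torch.tensor([0, 6, 10], dtype=torch.int32)
local_cu_seq_lens = torch.tensor([0, 4, 6, 8, 10], dtype=torch.int32)

use_local_attn = True

# Same selection pattern as in the diff: pick the chunk-level boundaries
# whenever local attention is in use.
cu_seqlens_k = (cu_seq_lens if not use_local_attn else local_cu_seq_lens)

print(cu_seqlens_k)  # the chunk-level boundaries [0, 4, 6, 8, 10]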