mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:06:06 +08:00
[Bugfix][V1][ROCm] Fix AITER Flash Attention Backend (Fix API Break and Local Attention Logic: affecting Llama4) (#19904)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
84c260caeb
commit
27c065df50
@ -306,12 +306,16 @@ class MultiHeadAttention(nn.Module):
|
||||
block_size=16,
|
||||
is_attention_free=False)
|
||||
backend = backend_name_to_enum(attn_backend.get_name())
|
||||
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
|
||||
backend = _Backend.XFORMERS
|
||||
if current_platform.is_rocm():
|
||||
# currently, only torch_sdpa is supported on rocm
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
else:
|
||||
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
|
||||
backend = _Backend.XFORMERS
|
||||
|
||||
self.attn_backend = backend if backend in {
|
||||
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
|
||||
} else _Backend.TORCH_SDPA
|
||||
self.attn_backend = backend if backend in {
|
||||
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
|
||||
} else _Backend.TORCH_SDPA
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -243,8 +243,8 @@ class AiterFlashAttentionMetadataBuilder:
|
||||
self.runner.device, non_blocking=True)
|
||||
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
local_max_query_len = seqlens_q_local_np.max()
|
||||
local_max_seq_len = virt_k_seqlens_np.max()
|
||||
local_max_query_len = int(seqlens_q_local_np.max())
|
||||
local_max_seq_len = int(virt_k_seqlens_np.max())
|
||||
local_scheduler_metadata = schedule(
|
||||
batch_size=local_query_start_loc.shape[0] - 1,
|
||||
cu_query_lens=local_query_start_loc,
|
||||
@ -253,6 +253,17 @@ class AiterFlashAttentionMetadataBuilder:
|
||||
max_seq_len=local_max_seq_len,
|
||||
causal=True)
|
||||
|
||||
local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
local_cu_seq_lens[1:] = torch.cumsum(
|
||||
torch.from_numpy(virt_k_seqlens_np).to(
|
||||
device=self.runner.device,
|
||||
dtype=torch.int32,
|
||||
non_blocking=True),
|
||||
dim=0)
|
||||
|
||||
|
||||
local_attn_metadata = \
|
||||
AiterFlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=local_query_start_loc,
|
||||
@ -260,6 +271,7 @@ class AiterFlashAttentionMetadataBuilder:
|
||||
local_block_table=virt_block_table_tensor,
|
||||
local_max_query_len=local_max_query_len,
|
||||
local_max_seq_len=local_max_seq_len,
|
||||
local_cu_seq_lens=local_cu_seq_lens,
|
||||
local_scheduler_metadata=local_scheduler_metadata,
|
||||
)
|
||||
|
||||
@ -368,6 +380,7 @@ class AiterFlashAttentionMetadata:
|
||||
local_block_table: torch.Tensor
|
||||
local_max_query_len: int
|
||||
local_max_seq_len: int
|
||||
local_cu_seq_lens: torch.Tensor
|
||||
local_scheduler_metadata: Optional[torch.Tensor]
|
||||
|
||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
||||
@ -387,6 +400,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
@ -408,6 +422,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0.
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
@ -478,22 +493,25 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens] and
|
||||
# value[:num_actual_tokens] because the reshape_and_cache_flash op uses
|
||||
# the slot_mapping's shape to determine the number of actual tokens.
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(torch.float8_e4m3fnuz)
|
||||
@ -541,7 +559,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
cu_seqlens_k=cu_seq_lens,
|
||||
cu_seqlens_k=(cu_seq_lens if not use_local_attn else
|
||||
local_metadata.local_cu_seq_lens),
|
||||
)
|
||||
|
||||
_, num_heads, head_size = query.shape
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user