[Kernel][AMD] Avoid D2H copy and cumsum kernel (#22683)

Signed-off-by: Xiaozhu <mxz297@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Xiaozhu Meng 2025-08-12 12:53:36 -07:00 committed by GitHub
parent dab4f9f764
commit 6bd8ebf026
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -214,12 +214,14 @@ class AiterFlashAttentionMetadata:
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
num_actual_kv_tokens: int
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
slot_mapping: torch.Tensor
block_table: torch.Tensor
cu_seq_lens: Optional[torch.Tensor]
# For cascade attention.
use_cascade: bool
@ -272,6 +274,20 @@ class AiterFlashAttentionMetadataBuilder(
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
if max_query_len > 1:
# We pre-compute cumulative seq len needed for prefill attention
# here to avoid recomputing it for every layer
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
dtype=torch.int32,
device=seq_lens.device)
torch.cumsum(seq_lens,
dim=0,
dtype=cu_seq_lens.dtype,
out=cu_seq_lens[1:])
num_actual_kv_tokens = int(cu_seq_lens[-1].item())
else:
cu_seq_lens = None
num_actual_kv_tokens = 0
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
@ -281,12 +297,14 @@ class AiterFlashAttentionMetadataBuilder(
attn_metadata = AiterFlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
num_actual_kv_tokens=num_actual_kv_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
cu_seq_lens=cu_seq_lens,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
total_tokens=self.total_tokens,
@ -475,16 +493,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
block_table = attn_metadata.block_table
if max_seqlen_q > 1:
cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
dtype=torch.int32,
device=query.device)
torch.cumsum(seqused_k,
dim=0,
dtype=cu_seq_lens.dtype,
out=cu_seq_lens[1:])
torch.ops.vllm.flash_attn_varlen_func(
query[:num_actual_tokens],
key_cache,
@ -497,10 +505,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
cu_seqlens_k=cu_seq_lens,
cu_seqlens_k=attn_metadata.cu_seq_lens,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
total_tokens=attn_metadata.total_tokens,
total_tokens=attn_metadata.num_actual_kv_tokens,
)
_, num_heads, head_size = query.shape