[Kernel][AMD] Avoid D2H copy and cumsum kernel (#22683)
Signed-off-by: Xiaozhu <mxz297@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent dab4f9f764
commit 6bd8ebf026
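For orientation: per the commit title, the change avoids a device-to-host copy and a per-layer cumsum kernel launch. The diff moves the cumulative-KV-length computation out of every `AiterFlashAttentionImpl.forward` call and into the metadata builder, which runs once per step; each layer then reuses the cached `cu_seq_lens` tensor and the host-side `num_actual_kv_tokens` int. A minimal sketch of the resulting builder-side pattern (standalone toy with a hypothetical helper name, not vLLM's actual API):

```python
import torch

def precompute_kv_cumsum(seq_lens: torch.Tensor):
    """Build cu_seq_lens once per step so layers need no cumsum or D2H copy."""
    cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                              dtype=torch.int32,
                              device=seq_lens.device)
    # Write the running totals into positions 1..N; index 0 stays 0.
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32, out=cu_seq_lens[1:])
    # The single device-to-host copy, paid here instead of in every layer.
    num_actual_kv_tokens = int(cu_seq_lens[-1].item())
    return cu_seq_lens, num_actual_kv_tokens
```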
```diff
@@ -214,12 +214,14 @@ class AiterFlashAttentionMetadata:
     # |-- query_len ---|

     num_actual_tokens: int  # Number of tokens excluding padding.
+    num_actual_kv_tokens: int
     max_query_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
     seq_lens: torch.Tensor
     slot_mapping: torch.Tensor
     block_table: torch.Tensor
+    cu_seq_lens: Optional[torch.Tensor]

     # For cascade attention.
     use_cascade: bool
```
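For reference, the two added fields in a trimmed-down, self-contained sketch (a hypothetical reduction of the dataclass; the real one carries more fields):

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class AiterFlashAttentionMetadata:
    num_actual_tokens: int                # tokens excluding padding
    num_actual_kv_tokens: int             # new: total KV tokens as a host int
    cu_seq_lens: Optional[torch.Tensor]   # new: None on decode-only batches
```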
```diff
@@ -272,6 +274,20 @@ class AiterFlashAttentionMetadataBuilder(
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
+        if max_query_len > 1:
+            # We pre-compute cumulative seq len needed for prefill attention
+            # here to avoid recomputing it for every layer
+            cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=seq_lens.device)
+            torch.cumsum(seq_lens,
+                         dim=0,
+                         dtype=cu_seq_lens.dtype,
+                         out=cu_seq_lens[1:])
+            num_actual_kv_tokens = int(cu_seq_lens[-1].item())
+        else:
+            cu_seq_lens = None
+            num_actual_kv_tokens = 0

         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
```
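To make the hunk's arithmetic concrete, a runnable toy example with made-up lengths; `out=cu_seq_lens[1:]` leaves index 0 at zero, giving the standard `cu_seqlens` layout of N+1 entries starting at 0:

```python
import torch

seq_lens = torch.tensor([4, 1, 7])   # made-up per-request KV lengths
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=torch.int32, out=cu_seq_lens[1:])

print(cu_seq_lens)                   # tensor([ 0,  4,  5, 12], dtype=torch.int32)
print(int(cu_seq_lens[-1].item()))   # 12 -> num_actual_kv_tokens
```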
```diff
@@ -281,12 +297,14 @@ class AiterFlashAttentionMetadataBuilder(

         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
+            num_actual_kv_tokens=num_actual_kv_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
+            cu_seq_lens=cu_seq_lens,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             total_tokens=self.total_tokens,
```
```diff
@@ -475,16 +493,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
         block_table = attn_metadata.block_table

         if max_seqlen_q > 1:
-
-            cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=query.device)
-
-            torch.cumsum(seqused_k,
-                         dim=0,
-                         dtype=cu_seq_lens.dtype,
-                         out=cu_seq_lens[1:])
-
             torch.ops.vllm.flash_attn_varlen_func(
                 query[:num_actual_tokens],
                 key_cache,
```
```diff
@@ -497,10 +505,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=cu_seq_lens,
+                cu_seqlens_k=attn_metadata.cu_seq_lens,
                 k_scale=layer._k_scale,
                 v_scale=layer._v_scale,
-                total_tokens=attn_metadata.total_tokens,
+                total_tokens=attn_metadata.num_actual_kv_tokens,
             )

         _, num_heads, head_size = query.shape
```