[Attention][Async] Eliminate seq_lens_cpu in FlashAttention metadata building with DCP > 1 (#29449)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni 2025-11-26 21:48:43 -05:00 committed by GitHub
parent df01eda4dc
commit 77740191de
2 changed files with 18 additions and 15 deletions
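
The change follows one pattern throughout: stop reading the CPU-side `seq_lens_cpu` mirror (and the matching `query_start_loc_cpu` slice) in the DCP > 1 path, and derive the same per-request lengths from the GPU tensors already carried in `common_attn_metadata`, so metadata building no longer waits on a device-to-host transfer. A minimal sketch of that pattern with made-up tensor values (not vLLM code; only the two expressions mirror the diff):

# Illustrative sketch only: derive per-request lengths from GPU-resident
# metadata instead of a CPU mirror, so nothing blocks on a device-to-host copy.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Cumulative query offsets for 3 requests with 4, 1, and 2 new tokens.
query_start_loc = torch.tensor([0, 4, 5, 7], dtype=torch.int32, device=device)
# Total sequence length (context + new tokens) per request.
seq_lens = torch.tensor([36, 9, 18], dtype=torch.int32, device=device)

# Old pattern: index a *_cpu copy of each tensor, then .to(device).
# New pattern: plain tensor arithmetic that stays on the device end to end.
query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]  # -> [4, 1, 2]
dcp_context_kv_lens = seq_lens - query_kv_lens              # -> [32, 8, 16]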

@@ -328,7 +328,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
         causal = common_attn_metadata.causal
@@ -401,20 +400,23 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         prefix_scheduler_metadata = None
         if self.dcp_world_size > 1:
-            query_kv_lens_cpu = (
-                common_attn_metadata.query_start_loc_cpu[1:]
-                - common_attn_metadata.query_start_loc_cpu[:-1]
-            )
-            dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
+            query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
+            dcp_context_kv_lens = seq_lens - query_kv_lens
-            dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
-                dcp_context_kv_lens_cpu,
+            dcp_context_kv_lens = get_dcp_local_seq_lens(
+                dcp_context_kv_lens,
                 self.dcp_world_size,
                 self.dcp_rank,
                 self.cp_kv_cache_interleave_size,
             )
-            dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
-            max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
+            # After DCP distribution, the maximum number of tokens for any rank is
+            # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
+            # and I is cp_kv_cache_interleave_size.
+            # This eliminates GPU->CPU sync while minimizing workspace over-allocation.
+            num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
+            max_dcp_context_kv_len = (
+                (max_seq_len + num_partitions - 1) // num_partitions
+            ) * self.cp_kv_cache_interleave_size
             scheduler_metadata = schedule(
                 batch_size=num_reqs,
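
Before the hunk above, the per-rank maximum came from `dcp_context_kv_lens.max().item()`, and that `.item()` stalls the host until the GPU result arrives. The new code replaces it with the analytic bound ceil(L / (N * I)) * I described in the comment. A small self-contained check of that bound, assuming (purely for illustration; this is not the vLLM implementation) that KV positions are dealt to ranks round-robin in blocks of I tokens:

# Sanity check of the bound: with L = max_seq_len, N = dcp_world_size, and
# I = cp_kv_cache_interleave_size, no rank holds more than ceil(L / (N * I)) * I
# context tokens. The assignment rank = (pos // I) % N below is an assumption
# made for this sketch.
import math

def per_rank_counts(L: int, N: int, I: int) -> list[int]:
    counts = [0] * N
    for pos in range(L):
        counts[(pos // I) % N] += 1
    return counts

for L, N, I in [(1000, 4, 16), (37, 2, 8), (4096, 8, 1), (5, 4, 64)]:
    counts = per_rank_counts(L, N, I)
    bound = math.ceil(L / (N * I)) * I
    assert max(counts) <= bound, (L, N, I)
    print(f"L={L} N={N} I={I} per-rank={counts} bound={bound}")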
@@ -431,9 +433,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
             prefix_kv_lens = torch.tensor(
                 [common_prefix_len], dtype=torch.int32, device=self.device
             )
-            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
-                self.device, non_blocking=True
-            )
+            # Use GPU tensor directly - no CPU sync needed
+            suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
             prefix_scheduler_metadata = schedule(
                 batch_size=1,
                 cu_query_lens=cu_prefix_query_lens,
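
The cascade-attention branch gets the same treatment: `suffix_kv_lens` was previously sliced out of `seq_lens_cpu` and copied back to the device, and is now a single elementwise op on the GPU tensor. A sketch with made-up values (not vLLM code) showing that the two forms agree:

# Illustrative sketch only: old vs. new construction of suffix_kv_lens in the
# cascade branch. Only the tensor expressions mirror the diff.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
num_reqs = 3
common_prefix_len = 8                          # tokens shared by every request
seq_lens = torch.tensor([36, 9, 18], dtype=torch.int32, device=device)
seq_lens_cpu = seq_lens.cpu()                  # the mirror the old path relied on

# Old: slice the CPU mirror, then launch a host-to-device copy.
suffix_old = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(device, non_blocking=True)

# New: one elementwise op on the tensor that is already on the device.
suffix_new = seq_lens[:num_reqs] - common_prefix_len
assert torch.equal(suffix_old, suffix_new)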

@@ -1095,12 +1095,14 @@ def get_dcp_local_seq_lens(
     num_requests = seq_lens.size(0)
     if dcp_rank is None:
         rank_offsets = (
-            torch.arange(dcp_size, dtype=torch.int32)
+            torch.arange(dcp_size, dtype=torch.int32, device=seq_lens.device)
             .unsqueeze(0)
             .repeat(num_requests, 1)
         )
     else:
-        rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
+        rank_offsets = torch.tensor(
+            [[dcp_rank]], dtype=torch.int32, device=seq_lens.device
+        )
     seq_lens_tiled = (
         seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
     )
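
The second changed file, which defines `get_dcp_local_seq_lens`, gets a device fix: `rank_offsets` is now built with `torch.tensor(..., dtype=torch.int32, device=seq_lens.device)` instead of the legacy `torch.Tensor(...)` constructor plus a cast, so it starts with the right dtype and on the same device as `seq_lens`, rather than as a default-dtype CPU tensor that has to be converted (and, when `seq_lens` lives on the GPU, copied) afterwards. A minimal comparison with made-up values, not tied to vLLM:

# Illustrative comparison of the two constructions changed above. The legacy
# torch.Tensor(...) constructor yields a CPU tensor in the default floating
# dtype, so it needs a cast (and possibly a later copy); torch.tensor() can be
# given the dtype and device up front.
import torch

seq_lens = torch.tensor([36, 9, 18], dtype=torch.int32)   # stand-in for the real input
dcp_rank = 1

old = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)    # float CPU tensor, then cast
new = torch.tensor([[dcp_rank]], dtype=torch.int32, device=seq_lens.device)

assert old.dtype == new.dtype == torch.int32
assert new.device == seq_lens.device                      # matches seq_lens from the start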