[Attention][Async] Eliminate seq_lens_cpu in FlashAttention metadata building with DCP > 1 (#29449)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent df01eda4dc
commit 77740191de
@@ -328,7 +328,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]):
         max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
         causal = common_attn_metadata.causal
@@ -401,20 +400,23 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]):
         prefix_scheduler_metadata = None

         if self.dcp_world_size > 1:
-            query_kv_lens_cpu = (
-                common_attn_metadata.query_start_loc_cpu[1:]
-                - common_attn_metadata.query_start_loc_cpu[:-1]
-            )
-            dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
+            query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
+            dcp_context_kv_lens = seq_lens - query_kv_lens

-            dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
-                dcp_context_kv_lens_cpu,
+            dcp_context_kv_lens = get_dcp_local_seq_lens(
+                dcp_context_kv_lens,
                 self.dcp_world_size,
                 self.dcp_rank,
                 self.cp_kv_cache_interleave_size,
             )
-            dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
-            max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
+            # After DCP distribution, the maximum number of tokens for any rank is
+            # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
+            # and I is cp_kv_cache_interleave_size.
+            # This eliminates GPU->CPU sync while minimizing workspace over-allocation.
+            num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
+            max_dcp_context_kv_len = (
+                (max_seq_len + num_partitions - 1) // num_partitions
+            ) * self.cp_kv_cache_interleave_size

             scheduler_metadata = schedule(
                 batch_size=num_reqs,
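For intuition on the workspace bound in the new comment above: under block-interleaved DCP assignment, token i of a sequence lands on rank (i // I) % N, so no rank ever holds more than ceil(L / (N * I)) * I tokens. A minimal standalone check of that claim (pure Python sketch with illustrative names, not vLLM code):

# Sketch only: assumes block-interleaved round-robin assignment of KV tokens,
# i.e. token i is stored on rank (i // I) % N.
def max_local_tokens(L: int, N: int, I: int) -> int:
    """Brute-force token count of the busiest rank."""
    counts = [0] * N
    for i in range(L):
        counts[(i // I) % N] += 1
    return max(counts)

for L in (1, 7, 64, 1000, 4097):   # max_seq_len
    for N in (1, 2, 4):            # dcp_world_size
        for I in (1, 16):          # cp_kv_cache_interleave_size
            bound = ((L + N * I - 1) // (N * I)) * I  # ceil(L / (N*I)) * I
            assert max_local_tokens(L, N, I) <= bound

The bound is tight whenever the remainder of L modulo N * I is zero or at least I, so the over-allocation relative to the true maximum is at most I - 1 tokens per request.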
@@ -431,9 +433,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]):
             prefix_kv_lens = torch.tensor(
                 [common_prefix_len], dtype=torch.int32, device=self.device
             )
-            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
-                self.device, non_blocking=True
-            )
+            # Use GPU tensor directly - no CPU sync needed
+            suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
             prefix_scheduler_metadata = schedule(
                 batch_size=1,
                 cu_query_lens=cu_prefix_query_lens,
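The rewrite above matters because seq_lens_cpu only exists after a device-to-host transfer, which the async path wants to avoid; slicing and subtracting on the GPU-resident seq_lens enqueues asynchronously instead. A hedged sketch of the two paths (hypothetical values, requires a CUDA device, not the builder's actual surrounding code):

import torch

seq_lens = torch.randint(1, 4096, (256,), dtype=torch.int32, device="cuda")
num_reqs, common_prefix_len = 256, 128

# Old path: read from a host-side copy, then launch an H2D transfer.
# Producing seq_lens_cpu in the first place is what forces the GPU->CPU sync.
seq_lens_cpu = seq_lens.cpu()
old = (seq_lens_cpu[:num_reqs] - common_prefix_len).to("cuda", non_blocking=True)

# New path: slice and subtract on the device tensor; no host round-trip.
new = seq_lens[:num_reqs] - common_prefix_len

torch.cuda.synchronize()
assert torch.equal(old, new)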
@@ -1095,12 +1095,14 @@ def get_dcp_local_seq_lens(
     num_requests = seq_lens.size(0)
     if dcp_rank is None:
         rank_offsets = (
-            torch.arange(dcp_size, dtype=torch.int32)
+            torch.arange(dcp_size, dtype=torch.int32, device=seq_lens.device)
             .unsqueeze(0)
             .repeat(num_requests, 1)
         )
     else:
-        rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
+        rank_offsets = torch.tensor(
+            [[dcp_rank]], dtype=torch.int32, device=seq_lens.device
+        )
     seq_lens_tiled = (
         seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
     )
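For reference, a scalar sketch of what get_dcp_local_seq_lens computes per request, assuming the same block-interleaved layout as above; this closed form illustrates the intended semantics and is not the vectorized torch implementation in the hunk:

# Hypothetical scalar counterpart of get_dcp_local_seq_lens, assuming token i
# of a sequence is stored on rank (i // interleave) % dcp_size.
def dcp_local_seq_len(seq_len: int, dcp_rank: int, dcp_size: int, interleave: int) -> int:
    full_cycles, rem = divmod(seq_len, dcp_size * interleave)
    # Each full cycle of dcp_size * interleave tokens gives every rank exactly
    # `interleave` tokens; the tail gives rank r between 0 and `interleave`.
    tail = min(max(rem - dcp_rank * interleave, 0), interleave)
    return full_cycles * interleave + tail

# Cross-check against direct counting for small cases.
for L in range(200):
    for r in range(4):
        assert dcp_local_seq_len(L, r, dcp_size=4, interleave=8) == sum(
            1 for i in range(L) if (i // 8) % 4 == r
        )

Note the two changes in this hunk both pin rank_offsets to seq_lens.device (so no implicit cross-device op can sneak in a sync), and the else-branch additionally swaps the legacy torch.Tensor(...).to(...) constructor for the modern torch.tensor(..., dtype=..., device=...) form.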