[Core] Refactor _build_attention_metadata (#29628)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Author: Lucas Wilkinson
Date: 2025-12-11 20:54:12 -05:00 (committed by GitHub)
parent b5945d49c0
commit 042da73244

@@ -1534,28 +1534,13 @@ class GPUModelRunner(
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
# Attention metadata is not needed for attention free models
if len(self.kv_cache_config.kv_cache_groups) == 0:
return {}, None
num_tokens_padded = num_tokens_padded or num_tokens
num_reqs_padded = num_reqs_padded or num_reqs
logits_indices_padded = None
num_logits_indices = None
if logits_indices is not None:
num_logits_indices = logits_indices.size(0)
if self.cache_config.kv_sharing_fast_prefill:
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
logits_indices
)
# update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
self.seq_lens.cpu[:num_reqs],
self.dcp_world_size,
self.dcp_rank,
self.parallel_config.cp_kv_cache_interleave_size,
)
self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)
assert num_reqs_padded is not None and num_tokens_padded is not None
attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None:
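
Note: this hunk also tightens the padding contract. The old in-function fallbacks are dropped in favor of an assert, so the decision to pad now lives entirely with the caller. A minimal sketch of the before/after contract (names as in the diff):

# Before (sketch): missing padded sizes were silently defaulted.
num_tokens_padded = num_tokens_padded or num_tokens
num_reqs_padded = num_reqs_padded or num_reqs

# After (sketch): the caller must pass padded sizes; the function only checks.
assert num_reqs_padded is not None and num_tokens_padded is not None
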
@@ -1576,36 +1561,12 @@ class GPUModelRunner(
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
# Used in the below loop, uses padded shapes
query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1]
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
seq_lens = self.seq_lens.gpu[:num_reqs_padded]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
]
kv_cache_groups = self.kv_cache_config.kv_cache_groups
dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
if self.dcp_world_size > 1:
dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs_padded]
spec_decode_common_attn_metadata = None
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_gid, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups
):
encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
num_scheduled_tokens or {},
kv_cache_group.kv_cache_spec,
num_reqs_padded,
)
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them.
def _get_block_table_and_slot_mapping(kv_cache_gid: int):
assert num_reqs_padded is not None and num_tokens_padded is not None
kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
blk_table_tensor = torch.zeros(
(num_reqs_padded, 1),
dtype=torch.int32,
@@ -1621,92 +1582,129 @@
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
# Fill unused entries with -1. Needed for reshape_and_cache in full
# cuda graph mode; `blk_table_tensor` uses -1 to match mamba's
# PAD_SLOT_ID.
slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_actual_tokens=num_tokens_padded,
num_reqs=num_reqs_padded,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping,
logits_indices_padded=logits_indices_padded,
num_logits_indices=num_logits_indices,
causal=True,
encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_cpu=encoder_seq_lens_cpu,
dcp_local_seq_lens=dcp_local_seq_lens,
dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
return blk_table_tensor, slot_mapping
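
The -1 fill matters for full cuda graph capture: cudagraph-sized batches carry padded rows that must not write into real cache blocks. A standalone sketch of the convention (illustrative only, not vLLM's reshape_and_cache kernel):

import torch

def write_kv_cache(cache: torch.Tensor, kv: torch.Tensor, slot_mapping: torch.Tensor):
    # Slots set to -1 (PAD_SLOT_ID) are padding; mask them out so padded
    # rows write nowhere instead of clobbering a real slot.
    valid = slot_mapping >= 0
    cache[slot_mapping[valid]] = kv[valid]
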
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
seq_lens=self.seq_lens.gpu[:num_reqs_padded],
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded],
_num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
],
num_reqs=num_reqs_padded,
num_actual_tokens=num_tokens_padded,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping_gid_0,
causal=True,
)
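
cm_base carries the fields shared by every KV cache group; the per-group loop below takes a shallow copy and rebinds only the group-specific fields. The pattern in isolation (stand-in types, not the real CommonAttentionMetadata):

from copy import copy
from dataclasses import dataclass

@dataclass
class Meta:                          # stand-in for CommonAttentionMetadata
    seq_lens: list[int]
    block_table: object | None = None

base = Meta(seq_lens=[3, 5])
per_group = copy(base)               # shallow: seq_lens is shared, not cloned
per_group.block_table = "group-1 table"
assert base.block_table is None      # base is untouched by the rebind
assert per_group.seq_lens is base.seq_lens
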
if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
self.seq_lens.cpu[:num_reqs],
self.dcp_world_size,
self.dcp_rank,
self.parallel_config.cp_kv_cache_interleave_size,
)
self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)
cm_base.dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
cm_base.dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[
:num_reqs_padded
]
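
Under decode context parallelism the KV cache of a sequence is interleaved across ranks, so each rank sees a shorter local sequence. A hypothetical per-rank length computation under round-robin interleaving in blocks of interleave_size tokens (an assumed scheme for illustration; get_dcp_local_seq_lens and cp_kv_cache_interleave_size define the real one):

def local_seq_len(seq_len: int, world_size: int, rank: int, interleave_size: int) -> int:
    # Tokens are dealt to ranks round-robin in blocks of `interleave_size`;
    # rank r owns blocks r, r + world_size, r + 2 * world_size, ...
    full_blocks, rem = divmod(seq_len, interleave_size)
    base = (full_blocks // world_size) * interleave_size
    next_rank = full_blocks % world_size  # rank receiving the partial block
    if rank < next_rank:
        base += interleave_size
    elif rank == next_rank:
        base += rem
    return base
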
if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill:
cm_base.num_logits_indices = logits_indices.size(0)
cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
logits_indices
)
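
For context, padding a logits index tensor up to a fixed capture size is what keeps gathers in-bounds under full cuda graphs. A generic sketch of such padding (hypothetical helper; not what _prepare_kv_sharing_fast_prefill actually does internally):

import torch

def pad_indices(indices: torch.Tensor, padded_len: int) -> torch.Tensor:
    # Repeat the last valid index into the tail so every gather stays
    # in-bounds; padded rows are discarded downstream.
    out = indices.new_full((padded_len,), indices[-1].item())
    out[: indices.numel()] = indices
    return out
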
def _build_attn_group_metadata(
kv_cache_gid: int,
attn_gid: int,
common_attn_metadata: CommonAttentionMetadata,
ubid: int | None = None,
) -> None:
attn_group = self.attn_groups[kv_cache_gid][attn_gid]
cascade_attn_prefix_len = (
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
if cascade_attn_prefix_lens
else 0
)
builder = attn_group.get_metadata_builder(ubid or 0)
extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
assert ubid is None, "UBatching not supported with GDN yet"
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded],
num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
:num_reqs_padded
],
)
if for_cudagraph_capture:
attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata
)
else:
attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args,
)
if ubid is None:
assert isinstance(attn_metadata, dict)
attn_metadata_dict = attn_metadata
else:
assert isinstance(attn_metadata, list)
attn_metadata_dict = attn_metadata[ubid]
for layer_name in attn_group.layer_names:
attn_metadata_dict[layer_name] = attn_metadata_i
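
The ubid dispatch above writes into either the flat per-layer dict or the per-micro-batch list of dicts. Stripped to its essentials (names illustrative, not vLLM APIs):

def assign(store, layer_names, meta, ubid=None):
    target = store if ubid is None else store[ubid]
    for name in layer_names:
        target[name] = meta      # layers in one attn group share metadata

flat: dict[str, str] = {}
assign(flat, ["attn.0", "attn.1"], "meta")
per_ubatch = [{}, {}]
assign(per_ubatch, ["attn.0"], "meta", ubid=1)
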
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
spec_decode_common_attn_metadata = None
for kv_cache_gid, kv_cache_group in enumerate(kv_cache_groups):
cm = copy(cm_base) # shallow copy
# Only the encoder seq_lens, block_table and slot_mapping change
# per kv_cache_group; everything else is shared via cm_base.
cm.encoder_seq_lens, cm.encoder_seq_lens_cpu = self._get_encoder_seq_lens(
num_scheduled_tokens or {},
kv_cache_group.kv_cache_spec,
num_reqs_padded,
)
if kv_cache_gid > 0:
cm.block_table_tensor, cm.slot_mapping = (
_get_block_table_and_slot_mapping(kv_cache_gid)
)
if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
spec_decode_common_attn_metadata = common_attn_metadata
spec_decode_common_attn_metadata = cm
else:
spec_decode_common_attn_metadata = common_attn_metadata
for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]):
cascade_attn_prefix_len = (
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
if cascade_attn_prefix_lens
else 0
)
builder = attn_group.get_metadata_builder()
extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.gpu[
:num_reqs_padded
],
num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
:num_reqs_padded
],
)
spec_decode_common_attn_metadata = cm
for attn_gid in range(len(self.attn_groups[kv_cache_gid])):
if ubatch_slices is not None:
common_attn_metadata_list = split_attn_metadata(
ubatch_slices, common_attn_metadata
)
for ubid, common_attn_metadata in enumerate(
common_attn_metadata_list
):
builder = attn_group.get_metadata_builder(ubatch_id=ubid)
if for_cudagraph_capture:
attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata
)
else:
attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_group.layer_names:
assert type(attn_metadata) is list
attn_metadata[ubid][layer_name] = attn_metadata_i
for ubid, _cm in enumerate(split_attn_metadata(ubatch_slices, cm)):
_build_attn_group_metadata(kv_cache_gid, attn_gid, _cm, ubid)
else:
assert isinstance(attn_metadata, dict)
if for_cudagraph_capture:
attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata
)
else:
attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args,
)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
_build_attn_group_metadata(kv_cache_gid, attn_gid, cm)
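
When micro-batching is on, the shared metadata is split per micro-batch before each build. A rough sketch of how such a split can work for the core fields (the real split_attn_metadata covers more fields and is authoritative):

from copy import copy

def split_view(cm, req_start, req_end, tok_start, tok_end):
    # Narrow per-request and per-token tensors to the micro-batch range,
    # rebasing query_start_loc so the slice starts at token 0.
    sub = copy(cm)
    sub.query_start_loc = (
        cm.query_start_loc[req_start : req_end + 1] - cm.query_start_loc[req_start]
    )
    sub.seq_lens = cm.seq_lens[req_start:req_end]
    sub.slot_mapping = cm.slot_mapping[tok_start:tok_end]
    sub.num_reqs = req_end - req_start
    sub.num_actual_tokens = tok_end - tok_start
    return sub
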
if self.is_mm_prefix_lm:
req_doc_ranges = {}