mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 10:07:11 +08:00
[Core] Refactor _build_attention_metadata (#29628)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
b5945d49c0
commit
042da73244
@ -1534,28 +1534,13 @@ class GPUModelRunner(
|
||||
"""
|
||||
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
|
||||
"""
|
||||
# Attention metadata is not needed for attention free models
|
||||
if len(self.kv_cache_config.kv_cache_groups) == 0:
|
||||
return {}, None
|
||||
|
||||
num_tokens_padded = num_tokens_padded or num_tokens
|
||||
num_reqs_padded = num_reqs_padded or num_reqs
|
||||
|
||||
logits_indices_padded = None
|
||||
num_logits_indices = None
|
||||
if logits_indices is not None:
|
||||
num_logits_indices = logits_indices.size(0)
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
|
||||
logits_indices
|
||||
)
|
||||
|
||||
# update seq_lens of decode reqs under DCP.
|
||||
if self.dcp_world_size > 1:
|
||||
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
|
||||
self.seq_lens.cpu[:num_reqs],
|
||||
self.dcp_world_size,
|
||||
self.dcp_rank,
|
||||
self.parallel_config.cp_kv_cache_interleave_size,
|
||||
)
|
||||
self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
|
||||
self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)
|
||||
assert num_reqs_padded is not None and num_tokens_padded is not None
|
||||
|
||||
attn_metadata: PerLayerAttnMetadata = {}
|
||||
if ubatch_slices is not None:
|
||||
@ -1576,36 +1561,12 @@ class GPUModelRunner(
|
||||
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
||||
self.num_accepted_tokens.copy_to_gpu()
|
||||
|
||||
# Used in the below loop, uses padded shapes
|
||||
query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1]
|
||||
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
|
||||
seq_lens = self.seq_lens.gpu[:num_reqs_padded]
|
||||
seq_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
|
||||
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
|
||||
:num_reqs_padded
|
||||
]
|
||||
kv_cache_groups = self.kv_cache_config.kv_cache_groups
|
||||
|
||||
dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
|
||||
if self.dcp_world_size > 1:
|
||||
dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
|
||||
dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs_padded]
|
||||
|
||||
spec_decode_common_attn_metadata = None
|
||||
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
for kv_cache_gid, kv_cache_group in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups
|
||||
):
|
||||
encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
|
||||
num_scheduled_tokens or {},
|
||||
kv_cache_group.kv_cache_spec,
|
||||
num_reqs_padded,
|
||||
)
|
||||
|
||||
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||
# Encoder-only layers do not have KV cache, so we need to
|
||||
# create a dummy block table and slot mapping for them.
|
||||
def _get_block_table_and_slot_mapping(kv_cache_gid: int):
|
||||
assert num_reqs_padded is not None and num_tokens_padded is not None
|
||||
kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec
|
||||
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||
blk_table_tensor = torch.zeros(
|
||||
(num_reqs_padded, 1),
|
||||
dtype=torch.int32,
|
||||
@ -1621,92 +1582,129 @@ class GPUModelRunner(
|
||||
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
|
||||
slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
|
||||
slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
|
||||
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
|
||||
slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
|
||||
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
_seq_lens_cpu=seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_actual_tokens=num_tokens_padded,
|
||||
num_reqs=num_reqs_padded,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=blk_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
logits_indices_padded=logits_indices_padded,
|
||||
num_logits_indices=num_logits_indices,
|
||||
causal=True,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
encoder_seq_lens_cpu=encoder_seq_lens_cpu,
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
|
||||
return blk_table_tensor, slot_mapping
|
||||
|
||||
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
|
||||
cm_base = CommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
|
||||
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
|
||||
seq_lens=self.seq_lens.gpu[:num_reqs_padded],
|
||||
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded],
|
||||
_num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
|
||||
:num_reqs_padded
|
||||
],
|
||||
num_reqs=num_reqs_padded,
|
||||
num_actual_tokens=num_tokens_padded,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=block_table_gid_0,
|
||||
slot_mapping=slot_mapping_gid_0,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
|
||||
self.seq_lens.cpu[:num_reqs],
|
||||
self.dcp_world_size,
|
||||
self.dcp_rank,
|
||||
self.parallel_config.cp_kv_cache_interleave_size,
|
||||
)
|
||||
self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
|
||||
self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)
|
||||
|
||||
cm_base.dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
|
||||
cm_base.dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[
|
||||
:num_reqs_padded
|
||||
]
|
||||
|
||||
if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill:
|
||||
cm_base.num_logits_indices = logits_indices.size(0)
|
||||
cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
|
||||
logits_indices
|
||||
)
|
||||
|
||||
def _build_attn_group_metadata(
|
||||
kv_cache_gid: int,
|
||||
attn_gid: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
ubid: int | None = None,
|
||||
) -> None:
|
||||
attn_group = self.attn_groups[kv_cache_gid][attn_gid]
|
||||
cascade_attn_prefix_len = (
|
||||
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
|
||||
if cascade_attn_prefix_lens
|
||||
else 0
|
||||
)
|
||||
|
||||
builder = attn_group.get_metadata_builder(ubid or 0)
|
||||
extra_attn_metadata_args = {}
|
||||
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
|
||||
assert ubid is None, "UBatching not supported with GDN yet"
|
||||
extra_attn_metadata_args = dict(
|
||||
num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded],
|
||||
num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
|
||||
:num_reqs_padded
|
||||
],
|
||||
)
|
||||
|
||||
if for_cudagraph_capture:
|
||||
attn_metadata_i = builder.build_for_cudagraph_capture(
|
||||
common_attn_metadata
|
||||
)
|
||||
else:
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=cascade_attn_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
**extra_attn_metadata_args,
|
||||
)
|
||||
|
||||
if ubid is None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata_dict = attn_metadata
|
||||
else:
|
||||
assert isinstance(attn_metadata, list)
|
||||
attn_metadata_dict = attn_metadata[ubid]
|
||||
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata_dict[layer_name] = attn_metadata_i
|
||||
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
spec_decode_common_attn_metadata = None
|
||||
for kv_cache_gid, kv_cache_group in enumerate(kv_cache_groups):
|
||||
cm = copy(cm_base) # shallow copy
|
||||
|
||||
# Basically only the encoder seq_lens, block_table and slot_mapping change
|
||||
# for each kv_cache_group.
|
||||
cm.encoder_seq_lens, cm.encoder_seq_lens_cpu = self._get_encoder_seq_lens(
|
||||
num_scheduled_tokens or {},
|
||||
kv_cache_group.kv_cache_spec,
|
||||
num_reqs_padded,
|
||||
)
|
||||
if kv_cache_gid > 0:
|
||||
cm.block_table_tensor, cm.slot_mapping = (
|
||||
_get_block_table_and_slot_mapping(kv_cache_gid)
|
||||
)
|
||||
|
||||
if self.speculative_config and spec_decode_common_attn_metadata is None:
|
||||
if isinstance(self.drafter, EagleProposer):
|
||||
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
|
||||
spec_decode_common_attn_metadata = common_attn_metadata
|
||||
spec_decode_common_attn_metadata = cm
|
||||
else:
|
||||
spec_decode_common_attn_metadata = common_attn_metadata
|
||||
|
||||
for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]):
|
||||
cascade_attn_prefix_len = (
|
||||
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
|
||||
if cascade_attn_prefix_lens
|
||||
else 0
|
||||
)
|
||||
builder = attn_group.get_metadata_builder()
|
||||
|
||||
extra_attn_metadata_args = {}
|
||||
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
|
||||
extra_attn_metadata_args = dict(
|
||||
num_accepted_tokens=self.num_accepted_tokens.gpu[
|
||||
:num_reqs_padded
|
||||
],
|
||||
num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
|
||||
:num_reqs_padded
|
||||
],
|
||||
)
|
||||
spec_decode_common_attn_metadata = cm
|
||||
|
||||
for attn_gid in range(len(self.attn_groups[kv_cache_gid])):
|
||||
if ubatch_slices is not None:
|
||||
common_attn_metadata_list = split_attn_metadata(
|
||||
ubatch_slices, common_attn_metadata
|
||||
)
|
||||
for ubid, common_attn_metadata in enumerate(
|
||||
common_attn_metadata_list
|
||||
):
|
||||
builder = attn_group.get_metadata_builder(ubatch_id=ubid)
|
||||
if for_cudagraph_capture:
|
||||
attn_metadata_i = builder.build_for_cudagraph_capture(
|
||||
common_attn_metadata
|
||||
)
|
||||
else:
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=cascade_attn_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
assert type(attn_metadata) is list
|
||||
attn_metadata[ubid][layer_name] = attn_metadata_i
|
||||
for ubid, _cm in enumerate(split_attn_metadata(ubatch_slices, cm)):
|
||||
_build_attn_group_metadata(kv_cache_gid, attn_gid, _cm, ubid)
|
||||
|
||||
else:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
if for_cudagraph_capture:
|
||||
attn_metadata_i = builder.build_for_cudagraph_capture(
|
||||
common_attn_metadata
|
||||
)
|
||||
else:
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=cascade_attn_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
**extra_attn_metadata_args,
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
_build_attn_group_metadata(kv_cache_gid, attn_gid, cm)
|
||||
|
||||
if self.is_mm_prefix_lm:
|
||||
req_doc_ranges = {}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user