[Core] Separate out attention metadata building logic from prepare inputs (#26764)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson 2025-11-09 13:51:43 -05:00 committed by GitHub
parent 289eb6c537
commit 636efd10a5


@@ -1054,7 +1054,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _get_encoder_seq_lens(
self,
scheduler_output: "SchedulerOutput",
scheduled_encoder_inputs: dict[str, list[int]],
kv_cache_spec: KVCacheSpec,
num_reqs: int,
) -> np.ndarray | None:
@@ -1064,31 +1064,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch
encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
for req_id in scheduler_output.scheduled_encoder_inputs:
for req_id in scheduled_encoder_inputs:
req_index = self.input_batch.req_id_to_index[req_id]
encoder_seq_lens[req_index] = self.max_encoder_len
return encoder_seq_lens
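The mapping this helper builds can be illustrated with plain NumPy; the request ids, indices, and max encoder length below are made-up stand-ins for the runner's real state:

```python
import numpy as np

# Toy stand-ins for the runner's state (hypothetical request ids and values,
# not taken from the diff).
req_id_to_index = {"req-a": 0, "req-b": 1, "req-c": 2}
scheduled_encoder_inputs = {"req-b": [0], "req-c": [0]}  # reqs with encoder work
max_encoder_len = 1500  # stand-in for self.max_encoder_len
num_reqs = 3

# Mirror of the loop above: zero for requests without scheduled encoder
# inputs, max_encoder_len for those that have them.
encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
for req_id in scheduled_encoder_inputs:
    encoder_seq_lens[req_id_to_index[req_id]] = max_encoder_len

print(encoder_seq_lens)  # [   0 1500 1500]
```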
def _prepare_inputs(
self, scheduler_output: "SchedulerOutput"
self,
scheduler_output: "SchedulerOutput",
num_scheduled_tokens: np.ndarray,
max_num_scheduled_tokens: int,
) -> tuple[
PerLayerAttnMetadata,
torch.Tensor,
SpecDecodeMetadata | None,
np.ndarray,
CommonAttentionMetadata | None,
int,
UBatchSlices | None,
torch.Tensor | None,
bool,
]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
logits_indices, spec_decode_metadata,
num_scheduled_tokens, spec_decode_common_attn_metadata,
max_num_scheduled_tokens, use_cascade_attn
ubatch_slices, num_tokens_across_dp,
]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1100,12 +1096,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit_block_table(num_reqs)
# Get the number of scheduled tokens for each request.
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
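The `[2, 5, 3]` example in the comment above can be reproduced with a tiny NumPy snippet; `arange_np` here is a local stand-in for the runner's cached arange buffer:

```python
import numpy as np

# Per-request scheduled token counts, matching the example in the comment.
num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
num_reqs = len(num_scheduled_tokens)
arange_np = np.arange(num_reqs)  # stand-in for self.arange_np[:num_reqs]

# Repeat each request index once per scheduled token.
req_indices = np.repeat(arange_np, num_scheduled_tokens)
print(req_indices.tolist())  # [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
```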
@@ -1232,8 +1222,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Fill unused with 0 for full cuda graph mode.
self.seq_lens.np[num_reqs:].fill(0)
self.seq_lens.copy_to_gpu()
seq_lens = self.seq_lens.gpu[:num_reqs]
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
num_tokens_np = np.array(num_tokens, dtype=np.int32)
@@ -1305,11 +1293,46 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
self.num_decode_draft_tokens.copy_to_gpu()
logits_indices_padded = None
if self.cache_config.kv_sharing_fast_prefill:
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
logits_indices
# Hot-swap LoRA model
if self.lora_config:
assert (
np.sum(num_sampled_tokens)
<= self.vllm_config.scheduler_config.max_num_batched_tokens
)
self.set_active_loras(
self.input_batch, num_scheduled_tokens, num_sampled_tokens
)
return (
logits_indices,
spec_decode_metadata,
ubatch_slices,
num_tokens_across_dp,
)
def _build_attention_metadata(
self,
total_num_scheduled_tokens: int,
max_num_scheduled_tokens: int,
num_reqs: int,
ubatch_slices: UBatchSlices | None = None,
logits_indices: torch.Tensor | None = None,
use_spec_decode: bool = False,
for_cudagraph_capture: bool = False,
scheduled_encoder_inputs: dict[str, list[int]] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
logits_indices_padded = None
num_logits_indices = 0
if logits_indices is not None:
num_logits_indices = logits_indices.size(0)
if self.cache_config.kv_sharing_fast_prefill:
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
logits_indices
)
# update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
@@ -1324,15 +1347,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
use_cascade_attn = False
# Used in the below loop.
# Used in the below loop
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
seq_lens = self.seq_lens.gpu[:num_reqs]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs
]
dcp_local_seq_lens = (
self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
)
spec_decode_common_attn_metadata = None
if for_cudagraph_capture:
# For some attention backends (e.g. FA) with sliding-window models, we need
# to make sure the backend sees a max_seq_len that is larger than the sliding
# window size when capturing, so that the correct kernel is selected.
max_seq_len = self.max_model_len
else:
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
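A minimal illustration of the branch above, with hypothetical sequence lengths and model limit standing in for `self.seq_lens` and `self.max_model_len`:

```python
import numpy as np

seq_lens_np = np.array([7, 123, 64], dtype=np.int32)  # hypothetical per-request lengths
max_model_len = 8192                                  # hypothetical model limit

for_cudagraph_capture = True
if for_cudagraph_capture:
    # Capture with the model's maximum length so a sliding-window backend
    # selects a kernel that can also serve long sequences at runtime.
    max_seq_len = max_model_len
else:
    max_seq_len = int(seq_lens_np.max())

print(max_seq_len)  # 8192 during capture, 123 otherwise
```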
if use_spec_decode:
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
@@ -1342,14 +1378,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
for kv_cache_gid, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups
):
encoder_seq_lens = self._get_encoder_seq_lens(
scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs
scheduled_encoder_inputs or {},
kv_cache_group.kv_cache_spec,
num_reqs,
)
if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec):
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them.
blk_table_tensor = torch.zeros(
@@ -1362,18 +1400,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int64,
device=self.device,
)
num_common_prefix_blocks = 0
else:
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table = self.input_batch.block_table[kv_cache_gid]
blk_table_tensor = blk_table.get_device_tensor(num_reqs)
slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[
kv_cache_group_id
]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
@@ -1388,35 +1422,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping,
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
num_logits_indices=num_logits_indices,
causal=True,
encoder_seq_lens=encoder_seq_lens,
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
if self.dcp_world_size > 1
else None,
dcp_local_seq_lens=dcp_local_seq_lens,
)
if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer):
if (
self.drafter.attn_layer_names[0]
in kv_cache_group_spec.layer_names
):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
spec_decode_common_attn_metadata = common_attn_metadata
else:
spec_decode_common_attn_metadata = common_attn_metadata
for attn_group in self.attn_groups[kv_cache_group_id]:
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]):
cascade_attn_prefix_len = (
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
if cascade_attn_prefix_lens
else 0
)
builder = attn_group.get_metadata_builder()
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
num_common_prefix_blocks,
attn_group.kv_cache_spec,
builder,
)
extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
@@ -1434,51 +1459,69 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for ubid, common_attn_metadata in enumerate(
common_attn_metadata_list
):
attn_metadata_i = attn_group.get_metadata_builder(
ubatch_id=ubid
).build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_group_spec.layer_names:
builder = attn_group.get_metadata_builder(ubatch_id=ubid)
if for_cudagraph_capture:
attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata
)
else:
attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_group.layer_names:
assert type(attn_metadata) is list
attn_metadata[ubid][layer_name] = attn_metadata_i
else:
assert isinstance(attn_metadata, dict)
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args,
)
use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False)
if for_cudagraph_capture:
attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata
)
else:
attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args,
)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
# disable cascade attention when DBO
if ubatch_slices is not None:
use_cascade_attn = False
return attn_metadata, spec_decode_common_attn_metadata
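The per-group loop above builds one metadata object per attention group and then assigns it to every layer in that group, rather than building per-layer metadata. A toy sketch of that sharing, with hypothetical layer names and a stand-in metadata class:

```python
from dataclasses import dataclass


@dataclass
class ToyAttnMetadata:
    """Stand-in for a backend-specific attention metadata class."""

    max_seq_len: int


# Hypothetical grouping of layer names into KV cache / attention groups.
layer_groups = {
    "full_attn": ["layers.0.attn", "layers.1.attn"],
    "sliding_window": ["layers.2.attn"],
}

attn_metadata: dict[str, ToyAttnMetadata] = {}
for group_name, layer_names in layer_groups.items():
    metadata = ToyAttnMetadata(max_seq_len=123)  # built once per group
    for layer_name in layer_names:
        attn_metadata[layer_name] = metadata     # every layer shares the object

assert attn_metadata["layers.0.attn"] is attn_metadata["layers.1.attn"]
```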
# Hot-swap LoRA model
if self.lora_config:
assert (
np.sum(num_sampled_tokens)
<= self.vllm_config.scheduler_config.max_num_batched_tokens
)
self.set_active_loras(
self.input_batch, num_scheduled_tokens, num_sampled_tokens
)
def _compute_cascade_attn_prefix_lens(
self,
num_scheduled_tokens: np.ndarray,
num_common_prefix_blocks: list[int],
) -> list[list[int]] | None:
"""
:return: Optional[cascade_attn_prefix_lens]
cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
None if we should not use cascade attention
"""
return (
attn_metadata,
logits_indices,
spec_decode_metadata,
num_scheduled_tokens,
spec_decode_common_attn_metadata,
max_num_scheduled_tokens,
ubatch_slices,
num_tokens_across_dp,
use_cascade_attn,
)
use_cascade_attn = False
num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
cascade_attn_prefix_lens: list[list[int]] = [
[] for _ in range(num_kv_cache_groups)
]
for kv_cache_gid in range(num_kv_cache_groups):
for attn_group in self.attn_groups[kv_cache_gid]:
if isinstance(attn_group.kv_cache_spec, EncoderOnlyAttentionSpec):
cascade_attn_prefix_len = 0
else:
# 0 if cascade attention should not be used
cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
num_common_prefix_blocks[kv_cache_gid],
attn_group.kv_cache_spec,
attn_group.get_metadata_builder(),
)
cascade_attn_prefix_lens[kv_cache_gid].append(cascade_attn_prefix_len)
use_cascade_attn |= cascade_attn_prefix_len > 0
return cascade_attn_prefix_lens if use_cascade_attn else None
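The helper's return value reads as a nested list indexed by KV cache group and then attention group; an illustrative example of the shape and of the "return None when nothing can use cascade attention" contract (all names and numbers are made up):

```python
# Two hypothetical KV cache groups: the first has two attention (backend)
# groups, the second has one. A positive entry is the shared-prefix length
# (in tokens) that group can use for cascade attention; 0 means "don't".
cascade_attn_prefix_lens = [
    [256, 0],  # kv_cache_group 0
    [0],       # kv_cache_group 1
]

use_cascade_attn = any(
    prefix_len > 0 for group in cascade_attn_prefix_lens for prefix_len in group
)
# Mirrors the helper's contract: return the table, or None when no group
# can use cascade attention.
result = cascade_attn_prefix_lens if use_cascade_attn else None
print(result)  # [[256, 0], [0]]
```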
def _compute_cascade_attn_prefix_len(
self,
@@ -1504,6 +1547,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Returns:
int: Length of common prefix in tokens.
"""
common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
if common_prefix_len == 0:
# Common case.
@@ -2497,18 +2541,48 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"it when the requests need prompt logprobs"
)
# Prepare the decoder inputs.
num_reqs = self.input_batch.num_reqs
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
(
attn_metadata,
logits_indices,
spec_decode_metadata,
num_scheduled_tokens_np,
spec_decode_common_attn_metadata,
max_query_len,
ubatch_slices,
num_tokens_across_dp,
use_cascade_attn,
) = self._prepare_inputs(scheduler_output)
) = self._prepare_inputs(
scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
)
cascade_attn_prefix_lens = None
# Disable cascade attention when using microbatching (DBO)
if self.cascade_attn_enabled and ubatch_slices is None:
# Pre-compute cascade attention prefix lengths
# NOTE: must run AFTER _prepare_inputs (relies on self.input_batch state set up there)
cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
num_scheduled_tokens_np,
scheduler_output.num_common_prefix_blocks,
)
# TODO(lucas): move cudagraph dispatching here:
# https://github.com/vllm-project/vllm/issues/23789
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
attn_metadata, spec_decode_common_attn_metadata = (
self._build_attention_metadata(
total_num_scheduled_tokens=total_num_scheduled_tokens,
max_num_scheduled_tokens=max_num_scheduled_tokens,
num_reqs=num_reqs,
ubatch_slices=ubatch_slices,
logits_indices=logits_indices,
use_spec_decode=use_spec_decode,
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
)
)
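Taken together, this hunk replaces the old single `_prepare_inputs` call with a three-step flow: prepare inputs, optionally pre-compute cascade prefix lengths, then build attention metadata. A compressed, runnable sketch of the ordering, with stub data standing in for `scheduler_output` and the runner-specific arguments elided into comments (only the helper names are taken from the diff):

```python
import numpy as np

# Stub scheduler output: per-request scheduled token counts keyed by request id.
num_scheduled_tokens_by_req = {"req-a": 2, "req-b": 5, "req-c": 3}
req_ids = ["req-a", "req-b", "req-c"]  # stand-in for self.input_batch.req_ids

tokens = [num_scheduled_tokens_by_req[r] for r in req_ids]
num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())

# The runner then proceeds roughly as:
#   1. _prepare_inputs(scheduler_output, num_scheduled_tokens_np,
#                      max_num_scheduled_tokens)
#      -> logits_indices, spec_decode_metadata, ubatch_slices, num_tokens_across_dp
#   2. _compute_cascade_attn_prefix_lens(...)  # skipped when ubatching (DBO) is on
#   3. _build_attention_metadata(...)          # also reused for CUDA graph capture
print(num_scheduled_tokens_np, max_num_scheduled_tokens)  # [2 5 3] 5
```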
dp_rank = self.parallel_config.data_parallel_rank
if ubatch_slices:
@@ -2532,16 +2606,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output, num_input_tokens, intermediate_tensors
)
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
)
uniform_decode = (
max_num_scheduled_tokens == self.uniform_decode_query_len
) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
batch_descriptor = BatchDescriptor(
num_tokens=num_input_tokens,
uniform_decode=uniform_decode,
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
)
cudagraph_runtime_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)
self.cudagraph_dispatcher.dispatch(
batch_descriptor,
use_cascade_attn=cascade_attn_prefix_lens is not None,
)
)
# Set cudagraph mode to none if calc_kv_scales is true.
@@ -3437,10 +3514,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# If force_attention is True, we always capture attention. Otherwise,
# it only happens for cudagraph_runtime_mode=FULL.
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
attn_metadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
if create_mixed_batch:
# In the mixed batch mode (used for FI warmup), we use
# shorter sequence lengths to run faster.
@@ -3456,55 +3529,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
self.query_start_loc.copy_to_gpu()
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups
):
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1],
seq_lens=self.seq_lens.gpu[:num_reqs],
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs
],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=self.max_model_len,
block_table_tensor=self.input_batch.block_table[
kv_cache_group_id
].get_device_tensor(num_reqs),
slot_mapping=self.input_batch.block_table[
kv_cache_group_id
].slot_mapping.gpu[:num_tokens],
causal=True,
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
if self.dcp_world_size > 1
else None,
)
for attn_group in self.attn_groups[kv_cache_group_id]:
if ubatch_slices is not None:
common_attn_metadata_list = split_attn_metadata(
ubatch_slices, common_attn_metadata
)
for ubid, common_attn_metadata in enumerate(
common_attn_metadata_list
):
assert common_attn_metadata.max_query_len == 1
attn_metadata_i = attn_group.get_metadata_builder(
ubatch_id=ubid
).build_for_cudagraph_capture(common_attn_metadata)
for layer_name in attn_group.layer_names:
assert type(attn_metadata) is list
attn_metadata[ubid][layer_name] = attn_metadata_i
else:
assert type(attn_metadata) is dict
metadata_builder = attn_group.get_metadata_builder()
attn_metadata_i = metadata_builder.build_for_cudagraph_capture(
common_attn_metadata
)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
attn_metadata, _ = self._build_attention_metadata(
total_num_scheduled_tokens=num_tokens,
max_num_scheduled_tokens=max_query_len,
num_reqs=num_reqs,
ubatch_slices=ubatch_slices,
for_cudagraph_capture=True,
)
with self.maybe_dummy_run_with_lora(
self.lora_config,
@@ -4478,9 +4509,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
list[int]: List of kernel block sizes for each cache group.
"""
kernel_block_sizes = []
for kv_cache_group_id, kv_cache_group in enumerate(
kv_cache_config.kv_cache_groups
):
for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
# All layers in the UniformTypeKVCacheSpecs have the same type,
@@ -4492,7 +4521,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# This is an attention backend that supports virtual
# block splitting. Get the supported block sizes from
# all backends in the group.
attn_groups = self.attn_groups[kv_cache_group_id]
attn_groups = self.attn_groups[kv_cache_gid]
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
selected_kernel_size = self.select_common_block_size(
kv_manager_block_size, attn_groups