From 636efd10a5b472c6016f744b30adeb12514c0acf Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Sun, 9 Nov 2025 13:51:43 -0500
Subject: [PATCH] [Core] Separate out attention metadata building logic from prepare inputs (#26764)

Signed-off-by: Lucas Wilkinson
---
 vllm/v1/worker/gpu_model_runner.py | 339 ++++++++++++++++-------
 1 file changed, 184 insertions(+), 155 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 3709710ef42e7..de9f32687635e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1054,7 +1054,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def _get_encoder_seq_lens(
         self,
-        scheduler_output: "SchedulerOutput",
+        scheduled_encoder_inputs: dict[str, list[int]],
         kv_cache_spec: KVCacheSpec,
         num_reqs: int,
     ) -> np.ndarray | None:
@@ -1064,31 +1064,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Build encoder_seq_lens array mapping request indices to
         # encoder lengths for inputs scheduled in this batch
         encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
-        for req_id in scheduler_output.scheduled_encoder_inputs:
+        for req_id in scheduled_encoder_inputs:
             req_index = self.input_batch.req_id_to_index[req_id]
             encoder_seq_lens[req_index] = self.max_encoder_len
         return encoder_seq_lens
 
     def _prepare_inputs(
-        self, scheduler_output: "SchedulerOutput"
+        self,
+        scheduler_output: "SchedulerOutput",
+        num_scheduled_tokens: np.ndarray,
+        max_num_scheduled_tokens: int,
     ) -> tuple[
-        PerLayerAttnMetadata,
         torch.Tensor,
         SpecDecodeMetadata | None,
-        np.ndarray,
-        CommonAttentionMetadata | None,
-        int,
         UBatchSlices | None,
         torch.Tensor | None,
-        bool,
     ]:
         """
         :return: tuple[
-            attn_metadata: layer-to-attention_metadata mapping,
             logits_indices, spec_decode_metadata,
-            num_scheduled_tokens, spec_decode_common_attn_metadata,
-            max_num_scheduled_tokens, use_cascade_attn
+            ubatch_slices, num_tokens_across_dp,
         ]
         """
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1100,12 +1096,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # This way, we can overlap the copy with the following CPU operations.
         self.input_batch.block_table.commit_block_table(num_reqs)
 
-        # Get the number of scheduled tokens for each request.
-        req_ids = self.input_batch.req_ids
-        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
-        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
-        max_num_scheduled_tokens = max(tokens)
-
         # Get request indices.
         # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
         req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
@@ -1232,8 +1222,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Fill unused with 0 for full cuda graph mode.
         self.seq_lens.np[num_reqs:].fill(0)
         self.seq_lens.copy_to_gpu()
-        seq_lens = self.seq_lens.gpu[:num_reqs]
-        max_seq_len = self.seq_lens.np[:num_reqs].max().item()
 
         num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
         num_tokens_np = np.array(num_tokens, dtype=np.int32)
@@ -1305,11 +1293,46 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
         self.num_decode_draft_tokens.copy_to_gpu()
 
-        logits_indices_padded = None
-        if self.cache_config.kv_sharing_fast_prefill:
-            logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
-                logits_indices
+        # Hot-Swap lora model
+        if self.lora_config:
+            assert (
+                np.sum(num_sampled_tokens)
+                <= self.vllm_config.scheduler_config.max_num_batched_tokens
             )
+            self.set_active_loras(
+                self.input_batch, num_scheduled_tokens, num_sampled_tokens
+            )
+
+        return (
+            logits_indices,
+            spec_decode_metadata,
+            ubatch_slices,
+            num_tokens_across_dp,
+        )
+
+    def _build_attention_metadata(
+        self,
+        total_num_scheduled_tokens: int,
+        max_num_scheduled_tokens: int,
+        num_reqs: int,
+        ubatch_slices: UBatchSlices | None = None,
+        logits_indices: torch.Tensor | None = None,
+        use_spec_decode: bool = False,
+        for_cudagraph_capture: bool = False,
+        scheduled_encoder_inputs: dict[str, list[int]] | None = None,
+        cascade_attn_prefix_lens: list[list[int]] | None = None,
+    ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
+        """
+        :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
+        """
+        logits_indices_padded = None
+        num_logits_indices = 0
+        if logits_indices is not None:
+            num_logits_indices = logits_indices.size(0)
+            if self.cache_config.kv_sharing_fast_prefill:
+                logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
+                    logits_indices
+                )
 
         # update seq_lens of decode reqs under DCP.
         if self.dcp_world_size > 1:
@@ -1324,15 +1347,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         attn_metadata: PerLayerAttnMetadata = {}
         if ubatch_slices is not None:
             attn_metadata = [dict() for _ in range(len(ubatch_slices))]
 
-        use_cascade_attn = False
-        # Used in the below loop.
+        # Used in the below loop
+        query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
         query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
+        seq_lens = self.seq_lens.gpu[:num_reqs]
         seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
         num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
             :num_reqs
         ]
+        dcp_local_seq_lens = (
+            self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
+        )
         spec_decode_common_attn_metadata = None
+
+        if for_cudagraph_capture:
+            # For some attention backends (e.g. FA) with sliding window models we need
+            # to make sure the backend sees a max_seq_len that is larger than the
+            # sliding window size when capturing, so the correct kernel is selected.
+            max_seq_len = self.max_model_len
+        else:
+            max_seq_len = self.seq_lens.np[:num_reqs].max().item()
+
         if use_spec_decode:
             self.num_accepted_tokens.np[:num_reqs] = (
                 self.input_batch.num_accepted_tokens_cpu[:num_reqs]
@@ -1342,14 +1378,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
-        for kv_cache_group_id, kv_cache_group_spec in enumerate(
+        for kv_cache_gid, kv_cache_group in enumerate(
             self.kv_cache_config.kv_cache_groups
         ):
             encoder_seq_lens = self._get_encoder_seq_lens(
-                scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs
+                scheduled_encoder_inputs or {},
+                kv_cache_group.kv_cache_spec,
+                num_reqs,
             )
 
-            if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec):
+            if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
                 # Encoder-only layers do not have KV cache, so we need to
                 # create a dummy block table and slot mapping for them.
                 blk_table_tensor = torch.zeros(
@@ -1362,18 +1400,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     dtype=torch.int64,
                     device=self.device,
                 )
-                num_common_prefix_blocks = 0
             else:
-                blk_table = self.input_batch.block_table[kv_cache_group_id]
+                blk_table = self.input_batch.block_table[kv_cache_gid]
                 blk_table_tensor = blk_table.get_device_tensor(num_reqs)
                 slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]
                 # Fill unused with -1. Needed for reshape_and_cache in full cuda
                 # graph mode.
                 blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
-                num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[
-                    kv_cache_group_id
-                ]
 
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=query_start_loc,
@@ -1388,35 +1422,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 block_table_tensor=blk_table_tensor,
                 slot_mapping=slot_mapping,
                 logits_indices_padded=logits_indices_padded,
-                num_logits_indices=logits_indices.size(0),
+                num_logits_indices=num_logits_indices,
                 causal=True,
                 encoder_seq_lens=encoder_seq_lens,
-                dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
-                if self.dcp_world_size > 1
-                else None,
+                dcp_local_seq_lens=dcp_local_seq_lens,
             )
 
             if self.speculative_config and spec_decode_common_attn_metadata is None:
                 if isinstance(self.drafter, EagleProposer):
-                    if (
-                        self.drafter.attn_layer_names[0]
-                        in kv_cache_group_spec.layer_names
-                    ):
+                    if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
                         spec_decode_common_attn_metadata = common_attn_metadata
                 else:
                     spec_decode_common_attn_metadata = common_attn_metadata
 
-            for attn_group in self.attn_groups[kv_cache_group_id]:
-                # Prepare for cascade attention if enabled & beneficial.
-                common_prefix_len = 0
+            for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]):
+                cascade_attn_prefix_len = (
+                    cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
+                    if cascade_attn_prefix_lens
+                    else 0
+                )
                 builder = attn_group.get_metadata_builder()
-                if self.cascade_attn_enabled:
-                    common_prefix_len = self._compute_cascade_attn_prefix_len(
-                        num_scheduled_tokens,
-                        num_common_prefix_blocks,
-                        attn_group.kv_cache_spec,
-                        builder,
-                    )
 
                 extra_attn_metadata_args = {}
                 if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
@@ -1434,51 +1459,69 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     for ubid, common_attn_metadata in enumerate(
                         common_attn_metadata_list
                     ):
-                        attn_metadata_i = attn_group.get_metadata_builder(
-                            ubatch_id=ubid
-                        ).build(
-                            common_prefix_len=common_prefix_len,
-                            common_attn_metadata=common_attn_metadata,
-                        )
-                        for layer_name in kv_cache_group_spec.layer_names:
+                        builder = attn_group.get_metadata_builder(ubatch_id=ubid)
+                        if for_cudagraph_capture:
+                            attn_metadata_i = builder.build_for_cudagraph_capture(
+                                common_attn_metadata
+                            )
+                        else:
+                            attn_metadata_i = builder.build(
+                                common_prefix_len=cascade_attn_prefix_len,
+                                common_attn_metadata=common_attn_metadata,
+                            )
+                        for layer_name in kv_cache_group.layer_names:
                             assert type(attn_metadata) is list
                             attn_metadata[ubid][layer_name] = attn_metadata_i
                 else:
                     assert isinstance(attn_metadata, dict)
-                    attn_metadata_i = builder.build(
-                        common_prefix_len=common_prefix_len,
-                        common_attn_metadata=common_attn_metadata,
-                        **extra_attn_metadata_args,
-                    )
-                    use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False)
+                    if for_cudagraph_capture:
+                        attn_metadata_i = builder.build_for_cudagraph_capture(
+                            common_attn_metadata
+                        )
+                    else:
+                        attn_metadata_i = builder.build(
+                            common_prefix_len=cascade_attn_prefix_len,
+                            common_attn_metadata=common_attn_metadata,
+                            **extra_attn_metadata_args,
+                        )
                     for layer_name in attn_group.layer_names:
                         attn_metadata[layer_name] = attn_metadata_i
 
-        # disable cascade attention when DBO
-        if ubatch_slices is not None:
-            use_cascade_attn = False
+        return attn_metadata, spec_decode_common_attn_metadata
 
-        # Hot-Swap lora model
-        if self.lora_config:
-            assert (
-                np.sum(num_sampled_tokens)
-                <= self.vllm_config.scheduler_config.max_num_batched_tokens
-            )
-            self.set_active_loras(
-                self.input_batch, num_scheduled_tokens, num_sampled_tokens
-            )
+    def _compute_cascade_attn_prefix_lens(
+        self,
+        num_scheduled_tokens: np.ndarray,
+        num_common_prefix_blocks: list[int],
+    ) -> list[list[int]] | None:
+        """
+        :return: Optional[cascade_attn_prefix_lens]
+            cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
+            None if we should not use cascade attention
+        """
 
-        return (
-            attn_metadata,
-            logits_indices,
-            spec_decode_metadata,
-            num_scheduled_tokens,
-            spec_decode_common_attn_metadata,
-            max_num_scheduled_tokens,
-            ubatch_slices,
-            num_tokens_across_dp,
-            use_cascade_attn,
-        )
+        use_cascade_attn = False
+        num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
+        cascade_attn_prefix_lens: list[list[int]] = [
+            [] for _ in range(num_kv_cache_groups)
+        ]
+
+        for kv_cache_gid in range(num_kv_cache_groups):
+            for attn_group in self.attn_groups[kv_cache_gid]:
+                if isinstance(attn_group.kv_cache_spec, EncoderOnlyAttentionSpec):
+                    cascade_attn_prefix_len = 0
+                else:
+                    # 0 if cascade attention should not be used
+                    cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
+                        num_scheduled_tokens,
+                        num_common_prefix_blocks[kv_cache_gid],
+                        attn_group.kv_cache_spec,
+                        attn_group.get_metadata_builder(),
+                    )
+                cascade_attn_prefix_lens[kv_cache_gid].append(cascade_attn_prefix_len)
+                use_cascade_attn |= cascade_attn_prefix_len > 0
+
+        return cascade_attn_prefix_lens if use_cascade_attn else None
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -1504,6 +1547,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         Returns:
             int: Length of common prefix in tokens.
         """
+
         common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
         if common_prefix_len == 0:
             # Common case.
@@ -2497,18 +2541,48 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 "it when the requests need prompt logprobs"
             )
 
-        # Prepare the decoder inputs.
+        num_reqs = self.input_batch.num_reqs
+        req_ids = self.input_batch.req_ids
+        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+        num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
+        max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
+
         (
-            attn_metadata,
             logits_indices,
             spec_decode_metadata,
-            num_scheduled_tokens_np,
-            spec_decode_common_attn_metadata,
-            max_query_len,
             ubatch_slices,
             num_tokens_across_dp,
-            use_cascade_attn,
-        ) = self._prepare_inputs(scheduler_output)
+        ) = self._prepare_inputs(
+            scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
+        )
+
+        cascade_attn_prefix_lens = None
+        # Disable cascade attention when using microbatching (DBO)
+        if self.cascade_attn_enabled and ubatch_slices is None:
+            # Pre-compute cascade attention prefix lengths
+            # NOTE: must run AFTER _prepare_inputs since it uses self.input_batch state
+            cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
+                num_scheduled_tokens_np,
+                scheduler_output.num_common_prefix_blocks,
+            )
+
+        # TODO(lucas): move cudagraph dispatching here:
+        #   https://github.com/vllm-project/vllm/issues/23789
+
+        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
+        attn_metadata, spec_decode_common_attn_metadata = (
+            self._build_attention_metadata(
+                total_num_scheduled_tokens=total_num_scheduled_tokens,
+                max_num_scheduled_tokens=max_num_scheduled_tokens,
+                num_reqs=num_reqs,
+                ubatch_slices=ubatch_slices,
+                logits_indices=logits_indices,
+                use_spec_decode=use_spec_decode,
+                scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
+                cascade_attn_prefix_lens=cascade_attn_prefix_lens,
+            )
+        )
 
         dp_rank = self.parallel_config.data_parallel_rank
         if ubatch_slices:
@@ -2532,16 +2606,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 scheduler_output, num_input_tokens, intermediate_tensors
             )
 
-        uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
-            num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
-        )
+        uniform_decode = (
+            max_num_scheduled_tokens == self.uniform_decode_query_len
+        ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
         batch_descriptor = BatchDescriptor(
             num_tokens=num_input_tokens,
             uniform_decode=uniform_decode,
            has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
         )
         cudagraph_runtime_mode, batch_descriptor = (
-            self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)
+            self.cudagraph_dispatcher.dispatch(
+                batch_descriptor,
+                use_cascade_attn=cascade_attn_prefix_lens is not None,
+            )
         )
 
         # Set cudagraph mode to none if calc_kv_scales is true.
@@ -3437,10 +3514,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # If force_attention is True, we always capture attention. Otherwise,
         # it only happens for cudagraph_runtime_mode=FULL.
         if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            attn_metadata = {}
-            if ubatch_slices is not None:
-                attn_metadata = [dict() for _ in range(len(ubatch_slices))]
-
             if create_mixed_batch:
                 # In the mixed batch mode (used for FI warmup), we use
                 # shorter sequence lengths to run faster.
@@ -3456,55 +3529,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
             self.query_start_loc.copy_to_gpu()
 
-            for kv_cache_group_id, kv_cache_group_spec in enumerate(
-                self.kv_cache_config.kv_cache_groups
-            ):
-                common_attn_metadata = CommonAttentionMetadata(
-                    query_start_loc=self.query_start_loc.gpu[: num_reqs + 1],
-                    query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1],
-                    seq_lens=self.seq_lens.gpu[:num_reqs],
-                    seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
-                    num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
-                        :num_reqs
-                    ],
-                    num_reqs=num_reqs,
-                    num_actual_tokens=num_tokens,
-                    max_query_len=max_query_len,
-                    max_seq_len=self.max_model_len,
-                    block_table_tensor=self.input_batch.block_table[
-                        kv_cache_group_id
-                    ].get_device_tensor(num_reqs),
-                    slot_mapping=self.input_batch.block_table[
-                        kv_cache_group_id
-                    ].slot_mapping.gpu[:num_tokens],
-                    causal=True,
-                    dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
-                    if self.dcp_world_size > 1
-                    else None,
-                )
-                for attn_group in self.attn_groups[kv_cache_group_id]:
-                    if ubatch_slices is not None:
-                        common_attn_metadata_list = split_attn_metadata(
-                            ubatch_slices, common_attn_metadata
-                        )
-                        for ubid, common_attn_metadata in enumerate(
-                            common_attn_metadata_list
-                        ):
-                            assert common_attn_metadata.max_query_len == 1
-                            attn_metadata_i = attn_group.get_metadata_builder(
-                                ubatch_id=ubid
-                            ).build_for_cudagraph_capture(common_attn_metadata)
-                            for layer_name in attn_group.layer_names:
-                                assert type(attn_metadata) is list
-                                attn_metadata[ubid][layer_name] = attn_metadata_i
-                    else:
-                        assert type(attn_metadata) is dict
-                        metadata_builder = attn_group.get_metadata_builder()
-                        attn_metadata_i = metadata_builder.build_for_cudagraph_capture(
-                            common_attn_metadata
-                        )
-                        for layer_name in attn_group.layer_names:
-                            attn_metadata[layer_name] = attn_metadata_i
+            attn_metadata, _ = self._build_attention_metadata(
+                total_num_scheduled_tokens=num_tokens,
+                max_num_scheduled_tokens=max_query_len,
+                num_reqs=num_reqs,
+                ubatch_slices=ubatch_slices,
+                for_cudagraph_capture=True,
+            )
 
         with self.maybe_dummy_run_with_lora(
             self.lora_config,
@@ -4478,9 +4509,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             list[int]: List of kernel block sizes for each cache group.
         """
         kernel_block_sizes = []
-        for kv_cache_group_id, kv_cache_group in enumerate(
-            kv_cache_config.kv_cache_groups
-        ):
+        for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
             kv_cache_spec = kv_cache_group.kv_cache_spec
             if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
                 # All layers in the UniformTypeKVCacheSpecs have the same type,
@@ -4492,7 +4521,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     # This is an attention backend that supports virtual
                    # block splitting. Get the supported block sizes from
                    # all backends in the group.
-                    attn_groups = self.attn_groups[kv_cache_group_id]
+                    attn_groups = self.attn_groups[kv_cache_gid]
                     kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
                     selected_kernel_size = self.select_common_block_size(
                         kv_manager_block_size, attn_groups