diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py
index 5b44c7e3e7ec..068fd0a0eb7d 100644
--- a/vllm/attention/layers/cross_attention.py
+++ b/vllm/attention/layers/cross_attention.py
@@ -25,15 +25,6 @@ from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
 logger = init_logger(__name__)


-def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
-    """Gets the max number of encoder input tokens from the config."""
-    sc = vllm_config.scheduler_config
-    assert sc and isinstance(sc.max_num_encoder_input_tokens, int), (
-        "max_num_encoder_input_tokens must be int for enc-dec models"
-    )
-    return sc.max_num_encoder_input_tokens
-
-
 def _get_cross_slot_mapping(
     encoder_seq_lens: np.ndarray,
     block_table_tensor: torch.Tensor,
@@ -93,23 +84,32 @@ def create_cross_attention_backend(
     ) -> AttentionMetadata:
         new_metadata = copy(common_attn_metadata)
         new_metadata.causal = False
-        max_encoder_len = _get_max_encoder_len(self.vllm_config)
+        max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
         new_metadata.max_seq_len = max_encoder_len
+        # Any computed tokens indicate decode step > 1 (no chunked prefill)
+        num_cache_decodes = (
+            (common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
+        )
+        if num_cache_decodes > 0:
+            # CrossAttn KV cache has already been populated on first decoder step,
+            # skip slot_mapping calculation for requests that do not need
+            # reshape_and_cache.
+            num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
+            new_metadata.encoder_seq_lens_cpu = np.where(
+                num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
+            )

-        new_metadata.seq_lens = torch.full(
-            (new_metadata.num_reqs,),
-            max_encoder_len,
-            dtype=torch.int32,
-            device=self.device,
-        )
-        new_metadata.seq_lens_cpu = torch.full(
-            (new_metadata.num_reqs,),
-            max_encoder_len,
-            dtype=torch.int32,
-            device="cpu",
+        # seq_lens is provided by model runner: initial encoder input length is
+        # needed here to know how many tokens to attend to from the cached
+        # cross-attention KV cache.
+        new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
+        new_metadata.seq_lens_cpu = torch.from_numpy(
+            common_attn_metadata.encoder_seq_lens_cpu
         )
+
+        # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
         new_metadata.slot_mapping = _get_cross_slot_mapping(
-            new_metadata.encoder_seq_lens,
+            new_metadata.encoder_seq_lens_cpu,
             new_metadata.block_table_tensor,
             self.kv_cache_spec,
             self.device,
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 540a8e2b1d01..cebfe8a3ff04 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -89,7 +89,8 @@ class CommonAttentionMetadata:
     num_logits_indices: int | None = None

     # Needed by CrossAttentionBuilder
-    encoder_seq_lens: np.ndarray | None = None
+    encoder_seq_lens: torch.Tensor | None = None
+    encoder_seq_lens_cpu: np.ndarray | None = None

     dcp_local_seq_lens: torch.Tensor | None = None
     dcp_local_seq_lens_cpu: torch.Tensor | None = None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 74fd2a1e2a2c..0ce6c4a3204b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -475,6 +475,7 @@ class GPUModelRunner(
             self.max_num_reqs + 1, dtype=torch.int32
         )
         self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
+        self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
         if self.dcp_world_size > 1:
             self.dcp_local_seq_lens = self._make_buffer(
                 self.max_num_reqs, dtype=torch.int32
             )
@@ -1202,21 +1203,35 @@ class GPUModelRunner(

     def _get_encoder_seq_lens(
         self,
-        scheduled_encoder_inputs: dict[str, list[int]],
+        num_scheduled_tokens: dict[str, int],
         kv_cache_spec: KVCacheSpec,
         num_reqs: int,
-    ) -> np.ndarray | None:
+    ) -> tuple[torch.Tensor | None, np.ndarray | None]:
         if not isinstance(kv_cache_spec, CrossAttentionSpec):
-            return None
+            return None, None

         # Build encoder_seq_lens array mapping request indices to
         # encoder lengths for inputs scheduled in this batch
-        encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
-        for req_id in scheduled_encoder_inputs:
+        for req_id in num_scheduled_tokens:
             req_index = self.input_batch.req_id_to_index[req_id]
-            encoder_seq_lens[req_index] = self.max_encoder_len
+            req_state = self.requests[req_id]
+            if req_state.mm_features is None:
+                self.encoder_seq_lens.np[req_index] = 0
+                continue

-        return encoder_seq_lens
+            # Get the total number of encoder input tokens for running encoder requests,
+            # whether encoding is finished or not, so that cross-attention knows how
+            # many encoder tokens to attend to.
+            encoder_input_tokens = sum(
+                feature.mm_position.length for feature in req_state.mm_features
+            )
+            self.encoder_seq_lens.np[req_index] = encoder_input_tokens
+
+        self.encoder_seq_lens.copy_to_gpu(num_reqs)
+        encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs]
+        encoder_seq_lens_cpu = self.encoder_seq_lens.np[:num_reqs]
+
+        return encoder_seq_lens, encoder_seq_lens_cpu

     def _prepare_inputs(
         self,
@@ -1482,7 +1497,7 @@
         logits_indices: torch.Tensor | None = None,
         use_spec_decode: bool = False,
         for_cudagraph_capture: bool = False,
-        scheduled_encoder_inputs: dict[str, list[int]] | None = None,
+        num_scheduled_tokens: dict[str, int] | None = None,
         cascade_attn_prefix_lens: list[list[int]] | None = None,
     ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
         """
@@ -1547,8 +1562,8 @@
         for kv_cache_gid, kv_cache_group in enumerate(
             self.kv_cache_config.kv_cache_groups
         ):
-            encoder_seq_lens = self._get_encoder_seq_lens(
-                scheduled_encoder_inputs or {},
+            encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
+                num_scheduled_tokens or {},
                 kv_cache_group.kv_cache_spec,
                 num_reqs,
             )
@@ -1591,6 +1606,7 @@
                 num_logits_indices=num_logits_indices,
                 causal=True,
                 encoder_seq_lens=encoder_seq_lens,
+                encoder_seq_lens_cpu=encoder_seq_lens_cpu,
                 dcp_local_seq_lens=dcp_local_seq_lens,
                 dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
             )
@@ -2828,7 +2844,7 @@
                 ubatch_slices=ubatch_slices,
                 logits_indices=logits_indices,
                 use_spec_decode=use_spec_decode,
-                scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
+                num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
                 cascade_attn_prefix_lens=cascade_attn_prefix_lens,
             )
         )
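Not part of the patch, just an illustrative sketch with made-up batch values: the slot-mapping skip introduced in the cross_attention.py hunk above. A request that already has computed decoder tokens is past its first decode step, so its cross-attention KV cache is already populated; zeroing its encoder length means _get_cross_slot_mapping produces no slots for it, while max_seq_len is now taken from the batch's encoder lengths instead of the scheduler config.

import numpy as np

# Hypothetical batch of three encoder-decoder requests.
encoder_seq_lens_cpu = np.array([1500, 1500, 1500], dtype=np.int32)  # encoder input lengths
num_computed_tokens_cpu = np.array([0, 4, 7], dtype=np.int32)  # decoder tokens already computed

# max_seq_len is derived from the batch rather than from the scheduler config.
max_encoder_len = int(encoder_seq_lens_cpu.max())  # 1500

# Same masking as the patch: requests 1 and 2 already wrote their cross-attention
# KV cache on the first decode step, so their lengths become 0 and they get no
# slot-mapping entries (no reshape_and_cache for them); request 0 keeps 1500.
masked_lens = np.where(num_computed_tokens_cpu > 0, 0, encoder_seq_lens_cpu)
print(max_encoder_len, masked_lens.tolist())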