[Core] Generalize Encoder-Decoder seq_lens computation to avoid Whisper hardcoded logic (#29268)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Nicolò Lucchesi 2025-11-25 12:32:11 +01:00 committed by GitHub
parent de6889946b
commit 798e87db5c
3 changed files with 51 additions and 34 deletions
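
The gist of the change: the cross-attention metadata path used to pull a single max_encoder_len out of scheduler_config.max_num_encoder_input_tokens and apply it to every request, which effectively hardcodes Whisper's behaviour (Whisper pads audio to its 30 s window, so the configured maximum was always correct there). After this commit the model runner computes a per-request encoder length from the request's multimodal features and threads it through CommonAttentionMetadata. A minimal standalone sketch of the before/after; MMPosition and MMFeature below are simplified stand-ins for illustration, not vLLM's types:

from dataclasses import dataclass

import numpy as np


@dataclass
class MMPosition:
    length: int  # encoder tokens contributed by this feature


@dataclass
class MMFeature:
    mm_position: MMPosition


def encoder_seq_lens_old(num_reqs: int, max_encoder_len: int) -> np.ndarray:
    # Old behaviour: every request is assumed to use the model-wide maximum.
    return np.full(num_reqs, max_encoder_len, dtype=np.int32)


def encoder_seq_lens_new(reqs: list[list[MMFeature]]) -> np.ndarray:
    # New behaviour: sum the encoder tokens of each request's multimodal
    # features, so models without a fixed encoder length get the real value.
    return np.array(
        [sum(f.mm_position.length for f in feats) for feats in reqs],
        dtype=np.int32,
    )


reqs = [[MMFeature(MMPosition(300))], [MMFeature(MMPosition(1500))]]
print(encoder_seq_lens_old(len(reqs), 1500))  # -> [1500 1500]
print(encoder_seq_lens_new(reqs))             # -> [300 1500]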

@@ -25,15 +25,6 @@ from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
logger = init_logger(__name__)
def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
"""Gets the max number of encoder input tokens from the config."""
sc = vllm_config.scheduler_config
assert sc and isinstance(sc.max_num_encoder_input_tokens, int), (
"max_num_encoder_input_tokens must be int for enc-dec models"
)
return sc.max_num_encoder_input_tokens
def _get_cross_slot_mapping(
encoder_seq_lens: np.ndarray,
block_table_tensor: torch.Tensor,
@@ -93,23 +84,32 @@ def create_cross_attention_backend(
) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata)
new_metadata.causal = False
max_encoder_len = _get_max_encoder_len(self.vllm_config)
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
new_metadata.max_seq_len = max_encoder_len
# Any computed tokens indicate decode step > 1 (no chunked prefill)
num_cache_decodes = (
(common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
)
if num_cache_decodes > 0:
# The cross-attn KV cache has already been populated on the first decoder
# step, so skip slot_mapping calculation for requests that do not need
# reshape_and_cache.
num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
new_metadata.encoder_seq_lens_cpu = np.where(
num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
)
new_metadata.seq_lens = torch.full(
(new_metadata.num_reqs,),
max_encoder_len,
dtype=torch.int32,
device=self.device,
)
new_metadata.seq_lens_cpu = torch.full(
(new_metadata.num_reqs,),
max_encoder_len,
dtype=torch.int32,
device="cpu",
# seq_lens is provided by the model runner: the initial encoder input
# length is needed here to know how many tokens to attend to in the
# cross-attention KV cache.
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
new_metadata.seq_lens_cpu = torch.from_numpy(
common_attn_metadata.encoder_seq_lens_cpu
)
# NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
new_metadata.slot_mapping = _get_cross_slot_mapping(
new_metadata.encoder_seq_lens,
new_metadata.encoder_seq_lens_cpu,
new_metadata.block_table_tensor,
self.kv_cache_spec,
self.device,
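
Two different lengths come out of this builder. encoder_seq_lens_cpu is zeroed for any request that already has computed tokens (i.e. it is past its first decoder step), because that array only drives the slot mapping / reshape_and_cache writes, which happen once. seq_lens, on the other hand, keeps the full per-request encoder length on every step so the decoder keeps attending over the whole cached encoder output. A small sketch of that split, assuming plain numpy/torch values instead of vLLM's metadata objects:

import numpy as np
import torch

# Hypothetical batch of three requests and their encoder lengths.
encoder_seq_lens_cpu = np.array([300, 1500, 750], dtype=np.int32)
num_computed_tokens = np.array([0, 12, 4])  # >0 means past the first decode step

# Only first-step requests still need their cross-attention KV written,
# so zero out the lengths used for the slot mapping for everyone else.
slot_mapping_lens = np.where(num_computed_tokens > 0, 0, encoder_seq_lens_cpu)
print(slot_mapping_lens)  # -> [300 0 0]

# Attention still spans the full cached encoder output for every request.
seq_lens_cpu = torch.from_numpy(encoder_seq_lens_cpu)
print(seq_lens_cpu)  # -> [300, 1500, 750] as an int32 tensor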

@@ -89,7 +89,8 @@ class CommonAttentionMetadata:
num_logits_indices: int | None = None
# Needed by CrossAttentionBuilder
encoder_seq_lens: np.ndarray | None = None
encoder_seq_lens: torch.Tensor | None = None
encoder_seq_lens_cpu: np.ndarray | None = None
dcp_local_seq_lens: torch.Tensor | None = None
dcp_local_seq_lens_cpu: torch.Tensor | None = None
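
CommonAttentionMetadata now carries the same per-request encoder lengths in two forms: encoder_seq_lens is the device tensor the builder can assign straight to seq_lens for the attention kernels, while encoder_seq_lens_cpu is a host-side numpy copy used for max(), the decode-step masking and the slot-mapping math without forcing a device sync. A rough sketch of that pairing (not the runner's internal buffer helper, just the idea):

import torch

num_reqs = 3
device = "cuda" if torch.cuda.is_available() else "cpu"

# Host buffer with a numpy view, plus a device-side mirror: fill on the CPU,
# then copy across once per step.
cpu = torch.zeros(num_reqs, dtype=torch.int32)
cpu_np = cpu.numpy()
gpu = torch.zeros(num_reqs, dtype=torch.int32, device=device)

cpu_np[:] = [300, 1500, 750]           # per-request encoder lengths
gpu.copy_(cpu, non_blocking=True)

encoder_seq_lens = gpu                 # tensor handed to the kernels
encoder_seq_lens_cpu = cpu_np          # host copy for builder-side math
print(int(encoder_seq_lens_cpu.max()))  # -> 1500, no device sync needed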

@@ -475,6 +475,7 @@ class GPUModelRunner(
self.max_num_reqs + 1, dtype=torch.int32
)
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
if self.dcp_world_size > 1:
self.dcp_local_seq_lens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
@@ -1202,21 +1203,35 @@ class GPUModelRunner(
def _get_encoder_seq_lens(
self,
scheduled_encoder_inputs: dict[str, list[int]],
num_scheduled_tokens: dict[str, int],
kv_cache_spec: KVCacheSpec,
num_reqs: int,
) -> np.ndarray | None:
) -> tuple[torch.Tensor | None, np.ndarray | None]:
if not isinstance(kv_cache_spec, CrossAttentionSpec):
return None
return None, None
# Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch
encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
for req_id in scheduled_encoder_inputs:
for req_id in num_scheduled_tokens:
req_index = self.input_batch.req_id_to_index[req_id]
encoder_seq_lens[req_index] = self.max_encoder_len
req_state = self.requests[req_id]
if req_state.mm_features is None:
self.encoder_seq_lens.np[req_index] = 0
continue
return encoder_seq_lens
# Get the total number of encoder input tokens for running encoder requests,
# whether or not encoding has finished, so that cross-attention knows how
# many encoder tokens to attend to.
encoder_input_tokens = sum(
feature.mm_position.length for feature in req_state.mm_features
)
self.encoder_seq_lens.np[req_index] = encoder_input_tokens
self.encoder_seq_lens.copy_to_gpu(num_reqs)
encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs]
encoder_seq_lens_cpu = self.encoder_seq_lens.np[:num_reqs]
return encoder_seq_lens, encoder_seq_lens_cpu
def _prepare_inputs(
self,
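
A note on the loop change above: iterating num_scheduled_tokens covers every request that has work this step, including pure decode steps, whereas the old scheduled_encoder_inputs only listed requests whose encoder input actually runs this step. The old code could get away with that because seq_lens was later filled with the hardcoded maximum anyway; now that the real per-request length is used, it has to be produced on every step. A toy illustration with made-up request ids:

# Both requests are past their encoder pass and are just decoding.
num_scheduled_tokens = {"req-a": 1, "req-b": 1}
scheduled_encoder_inputs: dict[str, list[int]] = {}  # nothing to encode now

encoder_len = {"req-a": 300, "req-b": 1500}  # summed from mm_features

old_keys = list(scheduled_encoder_inputs)  # [] -> per-request lengths unavailable
new_keys = list(num_scheduled_tokens)      # ['req-a', 'req-b'] -> lengths available
print({r: encoder_len[r] for r in new_keys})
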
@@ -1482,7 +1497,7 @@ class GPUModelRunner(
logits_indices: torch.Tensor | None = None,
use_spec_decode: bool = False,
for_cudagraph_capture: bool = False,
scheduled_encoder_inputs: dict[str, list[int]] | None = None,
num_scheduled_tokens: dict[str, int] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
"""
@@ -1547,8 +1562,8 @@ class GPUModelRunner(
for kv_cache_gid, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups
):
encoder_seq_lens = self._get_encoder_seq_lens(
scheduled_encoder_inputs or {},
encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
num_scheduled_tokens or {},
kv_cache_group.kv_cache_spec,
num_reqs,
)
@@ -1591,6 +1606,7 @@ class GPUModelRunner(
num_logits_indices=num_logits_indices,
causal=True,
encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_cpu=encoder_seq_lens_cpu,
dcp_local_seq_lens=dcp_local_seq_lens,
dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
)
@@ -2828,7 +2844,7 @@ class GPUModelRunner(
ubatch_slices=ubatch_slices,
logits_indices=logits_indices,
use_spec_decode=use_spec_decode,
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
)
)