Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 23:54:56 +08:00)
[Core] Generalize Encoder-Decoder seq_lens computation to avoid Whisper hardcoded logic (#29268)
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
parent de6889946b
commit 798e87db5c
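In essence, the cross-attention metadata builder previously assumed every request used the model's hardcoded maximum encoder length (Whisper-style), taken from the scheduler config; after this commit, the model runner supplies per-request encoder sequence lengths and the builder derives the batch maximum from them. A rough standalone sketch of the idea, with hypothetical lengths and plain numpy rather than vLLM's actual classes:

import numpy as np

# Hypothetical per-request encoder input lengths gathered by the model runner,
# e.g. one request with 1500 encoder frames and a shorter one with 376.
encoder_seq_lens_cpu = np.array([1500, 376], dtype=np.int32)

# Old behaviour (sketch): every request attended over one hardcoded maximum
# taken from the scheduler config, regardless of its real encoder length.
hardcoded_max = 1500
old_seq_lens = np.full_like(encoder_seq_lens_cpu, hardcoded_max)

# New behaviour (sketch): each request attends to exactly the number of encoder
# tokens it actually produced; the batch maximum is derived from those values.
new_seq_lens = encoder_seq_lens_cpu
max_encoder_len = int(encoder_seq_lens_cpu.max())

print(old_seq_lens, new_seq_lens, max_encoder_len)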
@@ -25,15 +25,6 @@ from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
 logger = init_logger(__name__)
 
 
-def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
-    """Gets the max number of encoder input tokens from the config."""
-    sc = vllm_config.scheduler_config
-    assert sc and isinstance(sc.max_num_encoder_input_tokens, int), (
-        "max_num_encoder_input_tokens must be int for enc-dec models"
-    )
-    return sc.max_num_encoder_input_tokens
-
-
 def _get_cross_slot_mapping(
     encoder_seq_lens: np.ndarray,
     block_table_tensor: torch.Tensor,
@@ -93,23 +84,32 @@ def create_cross_attention_backend(
     ) -> AttentionMetadata:
         new_metadata = copy(common_attn_metadata)
         new_metadata.causal = False
-        max_encoder_len = _get_max_encoder_len(self.vllm_config)
+        max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
         new_metadata.max_seq_len = max_encoder_len
+        # Any computed tokens indicated decode step>1 (no chunked prefill)
+        num_cache_decodes = (
+            (common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
+        )
+        if num_cache_decodes > 0:
+            # CrossAttn KV cache has already been populated on first decoder step,
+            # skip slot_mapping calculation for requests that do not need
+            # reshape_and_cache.
+            num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
+            new_metadata.encoder_seq_lens_cpu = np.where(
+                num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
+            )
+
-        new_metadata.seq_lens = torch.full(
-            (new_metadata.num_reqs,),
-            max_encoder_len,
-            dtype=torch.int32,
-            device=self.device,
-        )
-        new_metadata.seq_lens_cpu = torch.full(
-            (new_metadata.num_reqs,),
-            max_encoder_len,
-            dtype=torch.int32,
-            device="cpu",
-        )
+        # seq_lens is provided by model runner: initial encoder input length is
+        # needed here to know how many tokens to attend to from the cached
+        # cross-attention KV cache.
+        new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
+        new_metadata.seq_lens_cpu = torch.from_numpy(
+            common_attn_metadata.encoder_seq_lens_cpu
+        )
+
+        # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
         new_metadata.slot_mapping = _get_cross_slot_mapping(
-            new_metadata.encoder_seq_lens,
+            new_metadata.encoder_seq_lens_cpu,
             new_metadata.block_table_tensor,
             self.kv_cache_spec,
             self.device,
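For requests past their first decoder step the cross-attention KV cache is already populated, so the builder zeroes those entries in encoder_seq_lens_cpu before computing the slot mapping and only first-step requests get reshape_and_cache slots. A small numpy sketch of that masking, with hypothetical values rather than vLLM internals:

import numpy as np

# Hypothetical batch: request 0 is on its first decoder step (no computed
# tokens yet), requests 1 and 2 are continuing decodes.
num_computed_tokens_cpu = np.array([0, 7, 12], dtype=np.int32)
encoder_seq_lens_cpu = np.array([1500, 376, 900], dtype=np.int32)

# Requests with computed tokens already wrote their cross-attention KV cache
# on the first step, so their slot-mapping contribution is masked to zero.
masked = np.where(num_computed_tokens_cpu > 0, 0, encoder_seq_lens_cpu)

print(masked)  # [1500    0    0]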
@@ -89,7 +89,8 @@ class CommonAttentionMetadata:
     num_logits_indices: int | None = None
 
     # Needed by CrossAttentionBuilder
-    encoder_seq_lens: np.ndarray | None = None
+    encoder_seq_lens: torch.Tensor | None = None
+    encoder_seq_lens_cpu: np.ndarray | None = None
 
     dcp_local_seq_lens: torch.Tensor | None = None
     dcp_local_seq_lens_cpu: torch.Tensor | None = None
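CommonAttentionMetadata now carries the encoder lengths twice, mirroring the existing dcp_local_seq_lens pair: a device tensor for kernels and a host numpy copy for cheap CPU-side work (masking, max computation) without device syncs. A simplified, self-contained sketch of that pattern; EncoderLensExample is a hypothetical stand-in, not the real dataclass:

from dataclasses import dataclass

import numpy as np
import torch


@dataclass
class EncoderLensExample:
    # Simplified stand-in for the two new CommonAttentionMetadata fields.
    encoder_seq_lens: torch.Tensor | None = None
    encoder_seq_lens_cpu: np.ndarray | None = None


lens_cpu = np.array([1500, 376], dtype=np.int32)
meta = EncoderLensExample(
    encoder_seq_lens=torch.from_numpy(lens_cpu),  # would live on the GPU in vLLM
    encoder_seq_lens_cpu=lens_cpu,
)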
@@ -475,6 +475,7 @@ class GPUModelRunner(
             self.max_num_reqs + 1, dtype=torch.int32
         )
         self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
+        self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
         if self.dcp_world_size > 1:
             self.dcp_local_seq_lens = self._make_buffer(
                 self.max_num_reqs, dtype=torch.int32
@@ -1202,21 +1203,35 @@ class GPUModelRunner(
 
     def _get_encoder_seq_lens(
         self,
-        scheduled_encoder_inputs: dict[str, list[int]],
+        num_scheduled_tokens: dict[str, int],
         kv_cache_spec: KVCacheSpec,
         num_reqs: int,
-    ) -> np.ndarray | None:
+    ) -> tuple[torch.Tensor | None, np.ndarray | None]:
         if not isinstance(kv_cache_spec, CrossAttentionSpec):
-            return None
+            return None, None
 
         # Build encoder_seq_lens array mapping request indices to
         # encoder lengths for inputs scheduled in this batch
-        encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
-        for req_id in scheduled_encoder_inputs:
+        for req_id in num_scheduled_tokens:
             req_index = self.input_batch.req_id_to_index[req_id]
-            encoder_seq_lens[req_index] = self.max_encoder_len
+            req_state = self.requests[req_id]
+            if req_state.mm_features is None:
+                self.encoder_seq_lens.np[req_index] = 0
+                continue
 
-        return encoder_seq_lens
+            # Get the total number of encoder input tokens for running encoder requests
+            # whether encoding is finished or not so that cross-attention knows how
+            # many encoder tokens to attend to.
+            encoder_input_tokens = sum(
+                feature.mm_position.length for feature in req_state.mm_features
+            )
+            self.encoder_seq_lens.np[req_index] = encoder_input_tokens
+
+        self.encoder_seq_lens.copy_to_gpu(num_reqs)
+        encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs]
+        encoder_seq_lens_cpu = self.encoder_seq_lens.np[:num_reqs]
+
+        return encoder_seq_lens, encoder_seq_lens_cpu
 
     def _prepare_inputs(
         self,
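The per-request encoder length is now the sum of the request's multimodal feature lengths rather than the model-wide maximum, so shorter encoder inputs no longer attend over padding. A rough standalone sketch of that computation; FakeFeature and FakePlaceholder are hypothetical records, not vLLM's real multimodal feature types:

from dataclasses import dataclass


@dataclass
class FakePlaceholder:
    length: int


@dataclass
class FakeFeature:
    mm_position: FakePlaceholder


def encoder_seq_len(mm_features: list[FakeFeature] | None) -> int:
    # Requests without multimodal features contribute no encoder tokens;
    # otherwise the encoder length is the total of all feature placeholders.
    if not mm_features:
        return 0
    return sum(f.mm_position.length for f in mm_features)


print(encoder_seq_len([FakeFeature(FakePlaceholder(1500))]))  # 1500
print(encoder_seq_len(None))  # 0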
@@ -1482,7 +1497,7 @@ class GPUModelRunner(
         logits_indices: torch.Tensor | None = None,
         use_spec_decode: bool = False,
         for_cudagraph_capture: bool = False,
-        scheduled_encoder_inputs: dict[str, list[int]] | None = None,
+        num_scheduled_tokens: dict[str, int] | None = None,
         cascade_attn_prefix_lens: list[list[int]] | None = None,
     ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
         """
@@ -1547,8 +1562,8 @@ class GPUModelRunner(
         for kv_cache_gid, kv_cache_group in enumerate(
             self.kv_cache_config.kv_cache_groups
         ):
-            encoder_seq_lens = self._get_encoder_seq_lens(
-                scheduled_encoder_inputs or {},
+            encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
+                num_scheduled_tokens or {},
                 kv_cache_group.kv_cache_spec,
                 num_reqs,
             )
@@ -1591,6 +1606,7 @@ class GPUModelRunner(
                 num_logits_indices=num_logits_indices,
                 causal=True,
                 encoder_seq_lens=encoder_seq_lens,
+                encoder_seq_lens_cpu=encoder_seq_lens_cpu,
                 dcp_local_seq_lens=dcp_local_seq_lens,
                 dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
             )
@@ -2828,7 +2844,7 @@ class GPUModelRunner(
                 ubatch_slices=ubatch_slices,
                 logits_indices=logits_indices,
                 use_spec_decode=use_spec_decode,
-                scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
+                num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
                 cascade_attn_prefix_lens=cascade_attn_prefix_lens,
             )
         )