[v1] Add Whisper model support (encoder-decoder) (#21088)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Russell Bryant 2025-09-10 16:53:35 -04:00 committed by GitHub
parent 4db4426404
commit 37e8182bfe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 429 additions and 92 deletions

View File

@ -321,7 +321,6 @@ steps:
- python3 offline_inference/vision_language_pooling.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0
- VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/encoder_decoder.py
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
- python3 offline_inference/basic/classify.py
- python3 offline_inference/basic/embed.py
@ -644,7 +643,7 @@ steps:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pip freeze | grep -E 'torch'
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Models Test (Extended) 1
mirror_hardwares: [amdexperimental]
@ -818,7 +817,8 @@ steps:
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
- VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'
# test sequence parallel
- pytest -v -s distributed/test_sequence_parallel.py
# this test fails consistently.

View File

@ -5,6 +5,8 @@ Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART and mBART.
This script is refactored to allow model selection via command-line arguments.
NOTE: This example is not yet supported in V1.
"""
import argparse

View File

@ -5,6 +5,7 @@ This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
import os
import time
from collections.abc import Sequence
from dataclasses import asdict
@ -130,6 +131,8 @@ def run_mllama():
def run_whisper():
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
engine_args = EngineArgs(
model="openai/whisper-large-v3-turbo",
max_model_len=448,

View File

@ -63,6 +63,7 @@ def clear_cache():
current_platform.is_cpu(),
reason="CPU backend is not currently supported with encoder/decoder models"
)
@pytest.mark.skip(reason="bart not supported in V1")
def test_encoder_decoder_e2e(
hf_runner,
vllm_runner,

View File

@ -30,6 +30,7 @@ async def client(server):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="bart is not yet supported in V1")
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",

View File

@ -178,6 +178,7 @@ def run_test(
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.skip(reason="bart not supported in V1")
def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
@ -201,6 +202,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
@pytest.mark.skip(reason="bart not supported in V1")
def test_models_distributed(hf_runner, vllm_runner,
example_encoder_decoder_prompts,
distributed_executor_backend, model, dtype,

View File

@ -122,8 +122,7 @@ def run_test(
@pytest.mark.core_model
@pytest.mark.parametrize(
"model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@create_new_process_for_each_test()
def test_models(vllm_runner, model) -> None:
run_test(

View File

@ -31,6 +31,7 @@ from ...utils import dummy_hf_overrides
ARCH_TO_SKIP = {
"MolmoForCausalLM": "incompatible requirements",
"Florence2ForConditionalGeneration": "not supported in V1",
}
ARCH_NEEDS_EXTRAS = [
"InternVLChatModel",

View File

@ -68,6 +68,12 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# L4 supports FA3.
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
if model_arch == "Florence2ForConditionalGeneration":
# An encoder-decoder model that's V0-only. Just skip it
# since V0 is about to be removed.
pytest.skip("Skipping Florence2ForConditionalGeneration")
if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
LLM(
model_info.default,
tokenizer=model_info.tokenizer,

View File

@ -10,7 +10,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder
]

View File

@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional
import numpy as np
import torch
from transformers import CacheConfig
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
subclass_attention_backend)
from vllm.v1.kv_cache_interface import CrossAttentionSpec
logger = init_logger(__name__)
def _get_max_encoder_len(vllm_config: VllmConfig) -> int:
    """Look up the maximum encoder input length for this model config."""
    model_config = vllm_config.model_config
    return MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(model_config)
def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray,
                            block_table_tensor: torch.Tensor,
                            kv_cache_spec: "CrossAttentionSpec",
                            device: torch.device) -> torch.Tensor:
    """Compute flat KV-cache slot indices for all active encoder tokens.

    Only requests with a non-zero encoder sequence length contribute. For
    each such request, every encoder token position is routed through that
    request's block table into an absolute slot number
    (``block_id * block_size + offset``).
    """
    block_size = kv_cache_spec.block_size
    # Requests still holding encoder tokens. Most parallel requests will be
    # running the decoder, so this subset is expected to be small.
    active = np.nonzero(encoder_seq_lens)[0]
    per_request: list[torch.Tensor] = []
    for idx in active:
        seq_len = encoder_seq_lens[idx].item()
        # Blocks covering this request's encoder tokens; all are allocated.
        blocks_used = cdiv(seq_len, block_size)
        block_ids = block_table_tensor[idx][:blocks_used]
        # Token positions within the encoder sequence.
        positions = torch.arange(seq_len, dtype=torch.int64, device=device)
        slots = (block_ids[positions // block_size] * block_size
                 + positions % block_size)
        per_request.append(slots)
    if not per_request:
        return torch.empty(0, dtype=torch.int64, device=device)
    return torch.cat(per_request)
@functools.lru_cache
def create_cross_attention_backend(
        underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
    """Derive a cross-attention variant of ``underlying_attn_backend``.

    The subclassed builder rewrites the common metadata so that attention is
    non-causal and keyed on the model's fixed maximum encoder length, fills
    in the cross-attention slot mapping, and then delegates to the parent
    builder. Cached so each underlying backend is subclassed only once.
    """
    base_builder = underlying_attn_backend.get_builder_cls()

    class CrossAttentionBuilder(base_builder):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            metadata = copy(common_attn_metadata)
            # Decoder queries attend over the full encoder output.
            metadata.causal = False
            encoder_len = _get_max_encoder_len(self.vllm_config)
            metadata.max_seq_len = encoder_len
            num_reqs = metadata.num_reqs
            metadata.seq_lens = torch.full(
                (num_reqs, ),
                encoder_len,
                dtype=torch.int32,
                device=self.device,
            )
            metadata.seq_lens_cpu = torch.full(
                (num_reqs, ),
                encoder_len,
                dtype=torch.int32,
                device="cpu",
            )
            metadata.slot_mapping = _get_cross_slot_mapping(
                metadata.encoder_seq_lens, metadata.block_table_tensor,
                self.kv_cache_spec, self.device)
            return super().build(common_prefix_len, metadata, fast_build)

    return subclass_attention_backend(
        name_prefix="CrossAttention_",
        attention_backend_cls=underlying_attn_backend,
        builder_cls=CrossAttentionBuilder)
class CrossAttention(Attention):
    """
    Cross-attention for encoder-decoder models.
    Handles attention between decoder queries and encoder keys/values.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 cache_config: Optional[CacheConfig] = None,
                 attn_type: Optional[str] = None,
                 **kwargs):
        dtype = torch.get_default_dtype()

        # Fall back to defaults when no cache config is supplied.
        kv_cache_dtype = "auto"
        block_size = 16
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size

        if not envs.VLLM_USE_V1:
            # In V0, cross attention is handled inside the backends.
            attn_backend = None
        else:
            base_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
                                            block_size)
            attn_backend = create_cross_attention_backend(base_backend)

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_DECODER, (
                "CrossAttention only supports AttentionType.ENCODER_DECODER")

        super().__init__(num_heads=num_heads,
                         head_size=head_size,
                         scale=scale,
                         cache_config=cache_config,
                         attn_backend=attn_backend,
                         attn_type=AttentionType.ENCODER_DECODER,
                         **kwargs)

View File

@ -8,6 +8,7 @@ import enum
import hashlib
import inspect
import json
import os
import textwrap
import warnings
from collections.abc import Mapping
@ -41,6 +42,7 @@ from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
@ -3509,16 +3511,33 @@ class VllmConfig:
disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.pooler_config:
if self.model_config:
if self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")
elif not getattr(self.model_config.hf_config, "is_causal", True):
elif self.model_config.is_encoder_decoder:
self.scheduler_config.max_num_encoder_input_tokens = \
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
logger.debug(
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)",
self.scheduler_config.max_num_encoder_input_tokens)
self.scheduler_config.disable_chunked_mm_input = True
disable_chunked_prefill_reasons.append(
"Only models using causal attention supports chunked "
"prefill and prefix caching; disabling both.")
"Encoder-decoder models do not support chunked prefill nor"
" prefix caching; disabling both.")
if (self.model_config.architecture
== "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
!= "spawn"):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'.")
if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons:

View File

@ -600,7 +600,6 @@ class VoxtralEncoderModel(nn.Module):
self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "whisper_encoder"),
is_standalone_encoder=True,
init_in_fp32=True)
mel_filters = mel_filter_bank(
num_frequency_bins=1 + self.config.window_size // 2,

View File

@ -15,6 +15,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionType
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
VllmConfig)
from vllm.distributed import get_tensor_model_parallel_world_size
@ -43,7 +44,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription, SupportsV0Only)
SupportsTranscription)
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
make_layers)
@ -124,6 +125,34 @@ class WhisperAudioInputs(TensorSchema):
TensorShape("b", "nmb", "t")]
class WhisperEncoderAttention(MultiHeadAttention):
    """Multi-headed attention for Whisper encoder with 2D tensor support."""

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Accept (batch_size, seq_len, hidden_size) or (seq_len, hidden_size).

        2D inputs are temporarily promoted to a batch of one so the parent
        implementation can run, then squeezed back on the way out.
        """
        needs_batch_dim = query.dim() == 2
        if needs_batch_dim:
            query, key, value = (t.unsqueeze(0)
                                 for t in (query, key, value))
        attn_out = super().forward(query, key, value)
        return attn_out.squeeze(0) if needs_batch_dim else attn_out
class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int):
@ -144,7 +173,6 @@ class WhisperAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
standalone_encoder: bool = False,
):
super().__init__()
self.embed_dim = embed_dim
@ -180,14 +208,25 @@ class WhisperAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
if standalone_encoder:
self.attn = MultiHeadAttention(
if attn_type == AttentionType.ENCODER:
self.attn = WhisperEncoderAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
)
else:
elif self.attn_type == AttentionType.ENCODER_DECODER:
self.attn = CrossAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=self.attn_type,
)
else: # AttentionType.DECODER (regular decoder self-attention)
self.attn = Attention(
self.num_heads,
self.head_dim,
@ -332,11 +371,7 @@ class WhisperMLP(nn.Module):
class WhisperEncoderLayer(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
is_standalone_encoder: bool = False):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -350,7 +385,6 @@ class WhisperEncoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
standalone_encoder=is_standalone_encoder,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.mlp = WhisperMLP(
@ -446,12 +480,10 @@ class WhisperEncoder(nn.Module):
*,
vllm_config: VllmConfig,
prefix: str = "",
is_standalone_encoder: bool = False,
init_in_fp32: bool = False):
super().__init__()
config = vllm_config.model_config.hf_config
embed_dim = config.d_model
self.is_standalone_encoder = is_standalone_encoder
self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(embed_dim)
@ -469,9 +501,7 @@ class WhisperEncoder(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers,
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers",
is_standalone_encoder=
is_standalone_encoder),
prefix=f"{prefix}.layers"),
prefix=f"{prefix}.layers",
)
self.layer_norm = nn.LayerNorm(config.d_model)
@ -752,7 +782,7 @@ class WhisperMultiModalProcessor(
info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal, SupportsV0Only):
SupportsMultiModal):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
@ -880,19 +910,17 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
# TODO: This method does not obey the interface for SupportsMultiModal.
# Refactor this once encoder/decoder support is implemented in V1.
# Required as part of SupportsMultiModal interface.
audio_input = self._parse_and_validate_audio_input(**kwargs)
return self.model.get_encoder_outputs(audio_input["input_features"])
return [self.model.get_encoder_outputs(audio_input["input_features"])]
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
# TODO: This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. Refactor this once
# encoder/decoder support is implemented in V1.
# This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens.
return self.model.decoder.get_input_embeddings(input_ids)
def _parse_and_validate_audio_input(

View File

@ -157,6 +157,7 @@ def _remap_mistral_audio_args(config: dict) -> dict:
encoder_attention_heads=encoder_args["n_heads"],
vocab_size=encoder_args["vocab_size"],
max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default
)
}
if quant_config:

View File

@ -317,8 +317,8 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device) -> None:
self.kv_cache_spec = kv_cache_spec
self.vllm_config = vllm_config
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.scheduler_config = vllm_config.scheduler_config
# For reorder

View File

@ -177,12 +177,11 @@ class FlashAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config)

View File

@ -163,11 +163,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.device = device
self.vllm_config = vllm_config
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.cache_config = vllm_config.cache_config
self.model_config = vllm_config.model_config
self.kv_cache_spec = kv_cache_spec
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode (general shape)

View File

@ -516,10 +516,11 @@ class FlexAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config)

View File

@ -39,8 +39,8 @@ class LinearAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
def build(self,
common_prefix_len: int,

View File

@ -22,12 +22,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
self.device = device
self.vllm_config = vllm_config
self.layer_names = layer_names
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs,

View File

@ -236,11 +236,11 @@ class AiterFlashAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config)
@ -248,7 +248,6 @@ class AiterFlashAttentionMetadataBuilder(
self.parallel_config)
self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None

View File

@ -45,8 +45,8 @@ class ShortConvAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
def build(self,
common_prefix_len: int,

View File

@ -165,7 +165,8 @@ class TreeAttentionMetadataBuilder(
vllm_config: VllmConfig,
device: torch.device,
):
self.kv_cache_spec = kv_cache_spec
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size
spec_config = vllm_config.speculative_config

View File

@ -66,9 +66,9 @@ class TritonAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.device = device
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
model_config = vllm_config.model_config
self.num_heads_q = model_config.get_num_attention_heads(

View File

@ -72,6 +72,9 @@ class CommonAttentionMetadata:
logits_indices_padded: Optional[torch.Tensor] = None
num_logits_indices: Optional[int] = None
# Needed by CrossAttentionBuilder
encoder_seq_lens: Optional[np.ndarray] = None
@dataclass
class UbatchSlice:
@ -193,6 +196,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.kv_cache_spec = kv_cache_spec
self.layer_names = layer_names
self.vllm_config = vllm_config
self.device = device
@abstractmethod
def build(self,

View File

@ -206,8 +206,9 @@ class XFormersAttentionMetadataBuilder(
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert XFORMERS_AVAILABLE
self.kv_cache_spec = kv_cache_spec
self.block_size = kv_cache_spec.block_size
self._num_decodes = 0
self._num_decode_tokens = 0

View File

@ -144,8 +144,8 @@ class Scheduler(SchedulerInterface):
)
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
# projector if needed) for MM models as well as encoder-decoder
# transformers.
self.max_num_encoder_input_tokens = encoder_compute_budget
# NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0
@ -775,12 +775,16 @@ class Scheduler(SchedulerInterface):
# in the decoder's KV cache.
continue
# The same encoder input has already been scheduled in the current
# step.
if not self.is_encoder_decoder:
# We are not using the encoder cache for encoder-decoder models,
# yet.
if request.mm_hashes[i] in mm_hashes_to_schedule:
# The same encoder input has already been scheduled in the
# current step.
continue
if self.encoder_cache_manager.check_and_update_cache(request, i):
if self.encoder_cache_manager.check_and_update_cache(
request, i):
# The encoder input is already computed and cached from a
# previous step.
continue
@ -1047,7 +1051,13 @@ class Scheduler(SchedulerInterface):
mm_positions = request.mm_positions[input_id]
start_pos = mm_positions.offset
num_tokens = mm_positions.length
if start_pos + num_tokens <= request.num_computed_tokens:
if self.is_encoder_decoder and request.num_computed_tokens > 0:
# With Whisper, as soon as we've generated a single token,
# we know we're done with the encoder input. Cross Attention
# KVs have been calculated and cached already.
self.encoder_cache_manager.free_encoder_input(
request, input_id)
elif start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
self.encoder_cache_manager.free_encoder_input(

View File

@ -325,7 +325,6 @@ class Processor:
) -> tuple[Optional[str], EngineCoreRequest]:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
self._validate_lora(lora_request)
self._validate_params(params, lora_request)
if trace_headers is not None:
@ -384,10 +383,6 @@ class Processor:
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
# TODO: Impl encoder-decoder
if encoder_inputs is not None:
raise NotImplementedError
sampling_params = None
pooling_params = None
if isinstance(params, SamplingParams):

View File

@ -61,12 +61,16 @@ from vllm.v1.attention.backends.utils import (
create_fast_prefill_custom_backend,
reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsLists, LogprobsTensors,
ModelRunnerOutput, SamplerOutput)
@ -208,6 +212,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config)
if self.model_config.is_encoder_decoder:
# Maximum length of the encoder input, only for encoder-decoder
# models.
self.max_encoder_len = self.mm_registry.\
get_encdec_max_encoder_len(model_config)
else:
self.max_encoder_len = 0
# Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@ -265,7 +277,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# the block_sizes in the kv cache config.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
# We need to use the encoder length for encoder-decoder
# because of KV cache for cross-attention.
max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
@ -798,6 +812,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0])
def _get_encoder_seq_lens(
    self,
    scheduler_output: "SchedulerOutput",
    kv_cache_spec: KVCacheSpec,
    num_reqs: int,
) -> Optional[np.ndarray]:
    """Map batch request indices to encoder lengths for cross-attention.

    Returns None for any KV cache spec other than CrossAttentionSpec.
    Entries remain zero for requests whose encoder input is not scheduled
    in this batch.
    """
    if not isinstance(kv_cache_spec, CrossAttentionSpec):
        return None

    seq_lens = np.zeros(num_reqs, dtype=np.int32)
    req_to_index = self.input_batch.req_id_to_index
    for req_id in scheduler_output.scheduled_encoder_inputs:
        # Scheduled encoder inputs always span the full encoder length.
        seq_lens[req_to_index[req_id]] = self.max_encoder_len
    return seq_lens
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
@ -937,6 +969,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
encoder_seq_lens = self._get_encoder_seq_lens(
scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs)
if isinstance(kv_cache_group_spec.kv_cache_spec,
EncoderOnlyAttentionSpec):
@ -981,6 +1015,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
causal=True,
encoder_seq_lens=encoder_seq_lens,
)
if self.speculative_config and \
@ -1253,10 +1288,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded])
return logits_indices_padded
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
def _batch_mm_kwargs_from_scheduler(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
"""Batch multimodal kwargs from scheduled encoder inputs.
Args:
scheduler_output: The scheduler output containing scheduled encoder
inputs.
Returns:
A tuple of (mm_kwargs, mm_hashes_pos) where:
- mm_kwargs: List of multimodal kwargs items to be batched
- mm_hashes_pos: List of (mm_hash, position_info) tuples
"""
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
return [], []
# Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]()
# list of tuple (mm_hash, position_info)
@ -1270,6 +1319,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_hashes_pos.append(
(mm_hash, req_state.mm_positions[mm_input_id]))
return mm_kwargs, mm_hashes_pos
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
# Batch the multi-modal inputs using the helper method.
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
scheduler_output)
if not mm_kwargs:
return
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
# we process it separately to preserve item order.
@ -1360,6 +1419,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds.append(mm_embeds_item)
return mm_embeds
def _extract_encoder_inputs(
    self,
    scheduler_output: "SchedulerOutput",
) -> dict[str, torch.Tensor]:
    """Collect encoder kwargs for an encoder-decoder model forward pass.

    Batches the multimodal inputs scheduled for the encoder and merges
    each modality group's kwargs into a single dict so the model can
    receive them as keyword arguments (e.g. ``input_features=...``).

    Args:
        scheduler_output: The scheduler output containing the scheduled
            encoder inputs.

    Returns:
        A dict mapping kwarg names to tensors; empty when nothing is
        scheduled for the encoder.
    """
    # Reuse the shared batching helper; positions are not needed here.
    batched_kwargs, _ = self._batch_mm_kwargs_from_scheduler(
        scheduler_output)
    if not batched_kwargs:
        return {}

    features: dict[str, torch.Tensor] = {}
    # Each group covers one contiguous run of same-modality items;
    # fold every group's kwargs into a single dict for the model call.
    for _, _, group_kwargs in group_mm_kwargs_by_modality(
            batched_kwargs,
            device=self.device,
            pin_memory=self.pin_memory,
    ):
        features.update(group_kwargs)
    return features
def get_model(self) -> nn.Module:
# get raw model out of the cudagraph wrapper.
if isinstance(self.model, CUDAGraphWrapper):
@ -1631,7 +1719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
if self.supports_mm_inputs and get_pp_group().is_first_rank:
if (self.supports_mm_inputs and get_pp_group().is_first_rank
and not self.model_config.is_encoder_decoder):
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
@ -1673,6 +1762,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
if (self.model_config.is_encoder_decoder
and scheduler_output.scheduled_encoder_inputs):
encoder_inputs = self._extract_encoder_inputs(scheduler_output)
model_kwargs.update(encoder_inputs)
return (
num_scheduled_tokens,
num_input_tokens,
@ -2591,17 +2685,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens, remove_lora):
if self.supports_mm_inputs:
model_kwargs = self._init_model_kwargs(num_tokens)
if (self.supports_mm_inputs
and not self.model_config.is_encoder_decoder):
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
model_kwargs = {
**self._init_model_kwargs(num_tokens),
**model_kwargs,
**self._dummy_mm_kwargs(num_reqs),
}
else:
input_ids = self.input_ids.gpu[:num_tokens]
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_tokens)
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_tokens]
@ -2823,7 +2918,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_budget = self.mm_budget
assert mm_budget is not None
# TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
@ -3170,7 +3264,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"for more details.")
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
@ -3443,7 +3537,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
attn_spec = EncoderOnlyAttentionSpec(
attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
@ -3485,7 +3579,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
# TODO: Support other attention modules, e.g., cross-attention
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
@ -3513,12 +3606,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
kv_cache_spec[layer_name] = CrossAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")

View File

@ -12,6 +12,7 @@ from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
@ -269,6 +270,16 @@ def bind_kv_cache(
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
# TODO - analyze where runner_kv_caches is used and the right
# way to ensure it properly reflects multiple attention layers
# in the same decoder block.
if current_platform.is_cuda():
# We know that the GPU runner is not impacted by this
# case. Some test code depends on runner_kv_caches, but
# not in a way that's impacted by ignoring this.
pass
else:
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])