[v1] Add Whisper model support (encoder-decoder) (#21088)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Russell Bryant 2025-09-10 16:53:35 -04:00 committed by GitHub
parent 4db4426404
commit 37e8182bfe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 429 additions and 92 deletions

View File

@ -321,7 +321,6 @@ steps:
- python3 offline_inference/vision_language_pooling.py --seed 0 - python3 offline_inference/vision_language_pooling.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0
- VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/encoder_decoder.py
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
- python3 offline_inference/basic/classify.py - python3 offline_inference/basic/classify.py
- python3 offline_inference/basic/embed.py - python3 offline_inference/basic/embed.py
@ -644,7 +643,7 @@ steps:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pip freeze | grep -E 'torch' - pip freeze | grep -E 'torch'
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Models Test (Extended) 1 - label: Multi-Modal Models Test (Extended) 1
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
@ -818,7 +817,8 @@ steps:
# Avoid importing model tests that cause CUDA reinitialization error # Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)' - pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
- VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'
# test sequence parallel # test sequence parallel
- pytest -v -s distributed/test_sequence_parallel.py - pytest -v -s distributed/test_sequence_parallel.py
# this test fails consistently. # this test fails consistently.

View File

@ -5,6 +5,8 @@ Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART and mBART. encoder/decoder models, specifically BART and mBART.
This script is refactored to allow model selection via command-line arguments. This script is refactored to allow model selection via command-line arguments.
NOTE: This example is not yet supported in V1.
""" """
import argparse import argparse

View File

@ -5,6 +5,7 @@ This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation. the explicit/implicit prompt format on enc-dec LMMs for text generation.
""" """
import os
import time import time
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import asdict from dataclasses import asdict
@ -130,6 +131,8 @@ def run_mllama():
def run_whisper(): def run_whisper():
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
engine_args = EngineArgs( engine_args = EngineArgs(
model="openai/whisper-large-v3-turbo", model="openai/whisper-large-v3-turbo",
max_model_len=448, max_model_len=448,

View File

@ -63,6 +63,7 @@ def clear_cache():
current_platform.is_cpu(), current_platform.is_cpu(),
reason="CPU backend is not currently supported with encoder/decoder models" reason="CPU backend is not currently supported with encoder/decoder models"
) )
@pytest.mark.skip(reason="bart not supported in V1")
def test_encoder_decoder_e2e( def test_encoder_decoder_e2e(
hf_runner, hf_runner,
vllm_runner, vllm_runner,

View File

@ -30,6 +30,7 @@ async def client(server):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="bart is not yet supported in V1")
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name, completion = await client.completions.create(model=model_name,
prompt="Hello, my name is", prompt="Hello, my name is",

View File

@ -178,6 +178,7 @@ def run_test(
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.skip(reason="bart not supported in V1")
def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
@ -201,6 +202,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
@pytest.mark.skip(reason="bart not supported in V1")
def test_models_distributed(hf_runner, vllm_runner, def test_models_distributed(hf_runner, vllm_runner,
example_encoder_decoder_prompts, example_encoder_decoder_prompts,
distributed_executor_backend, model, dtype, distributed_executor_backend, model, dtype,

View File

@ -122,8 +122,7 @@ def run_test(
@pytest.mark.core_model @pytest.mark.core_model
@pytest.mark.parametrize( @pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
"model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"])
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_models(vllm_runner, model) -> None: def test_models(vllm_runner, model) -> None:
run_test( run_test(

View File

@ -31,6 +31,7 @@ from ...utils import dummy_hf_overrides
ARCH_TO_SKIP = { ARCH_TO_SKIP = {
"MolmoForCausalLM": "incompatible requirements", "MolmoForCausalLM": "incompatible requirements",
"Florence2ForConditionalGeneration": "not supported in V1",
} }
ARCH_NEEDS_EXTRAS = [ ARCH_NEEDS_EXTRAS = [
"InternVLChatModel", "InternVLChatModel",

View File

@ -68,6 +68,12 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# L4 supports FA3. # L4 supports FA3.
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1") m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
if model_arch == "Florence2ForConditionalGeneration":
# An encoder-decoder model that's V0-only. Just skip it
# since V0 is about to be removed.
pytest.skip("Skipping Florence2ForConditionalGeneration")
if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
LLM( LLM(
model_info.default, model_info.default,
tokenizer=model_info.tokenizer, tokenizer=model_info.tokenizer,

View File

@ -10,7 +10,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1 = [ UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder "facebook/bart-large-cnn", # encoder decoder
] ]

View File

@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional
import numpy as np
import torch
from transformers import CacheConfig
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
subclass_attention_backend)
from vllm.v1.kv_cache_interface import CrossAttentionSpec
logger = init_logger(__name__)
def _get_max_encoder_len(vllm_config: VllmConfig) -> int:
return MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
vllm_config.model_config)
def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray,
block_table_tensor: torch.Tensor,
kv_cache_spec: CrossAttentionSpec,
device: torch.device) -> torch.Tensor:
"""Get cross-attention slot mappings."""
block_size = kv_cache_spec.block_size
slot_mappings = []
# Find indices with non-zero encoder sequence lengths
# The majority of parallel requests will be running the
# decoder, so this list should be relatively small.
active_indices = np.nonzero(encoder_seq_lens)[0]
for req_index in active_indices:
encoder_seq_len = encoder_seq_lens[req_index].item()
# Calculate the number of blocks needed for this request
num_blocks_needed = cdiv(encoder_seq_len, block_size)
# Get the block IDs for this request from the tensor
req_block_ids = block_table_tensor[req_index]
# Get only the blocks we need (first num_blocks_needed blocks)
needed_block_ids = req_block_ids[:num_blocks_needed]
# All needed blocks are allocated
i_values = torch.arange(encoder_seq_len,
dtype=torch.int64,
device=device)
block_indices = i_values // block_size
block_offsets = i_values % block_size
block_numbers = needed_block_ids[block_indices]
slot_mapping = block_numbers * block_size + block_offsets
slot_mappings.append(slot_mapping)
if slot_mappings:
return torch.cat(slot_mappings)
else:
return torch.empty(0, dtype=torch.int64, device=device)
@functools.lru_cache
def create_cross_attention_backend(
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
prefix = "CrossAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class CrossAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata)
new_metadata.causal = False
max_encoder_len = _get_max_encoder_len(self.vllm_config)
new_metadata.max_seq_len = max_encoder_len
new_metadata.seq_lens = torch.full(
(new_metadata.num_reqs, ),
max_encoder_len,
dtype=torch.int32,
device=self.device,
)
new_metadata.seq_lens_cpu = torch.full(
(new_metadata.num_reqs, ),
max_encoder_len,
dtype=torch.int32,
device="cpu",
)
new_metadata.slot_mapping = _get_cross_slot_mapping(
new_metadata.encoder_seq_lens, new_metadata.block_table_tensor,
self.kv_cache_spec, self.device)
return super().build(common_prefix_len, new_metadata, fast_build)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=CrossAttentionBuilder)
return attn_backend
class CrossAttention(Attention):
"""
Cross-attention for encoder-decoder models.
Handles attention between decoder queries and encoder keys/values.
"""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
cache_config: Optional[CacheConfig] = None,
attn_type: Optional[str] = None,
**kwargs):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(head_size, dtype,
kv_cache_dtype,
block_size)
attn_backend = create_cross_attention_backend(
underlying_attn_backend)
else:
# in v0 cross attention is handled inside the backends
attn_backend = None
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, (
"CrossAttention only supports AttentionType.ENCODER_DECODER")
super().__init__(num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_DECODER,
**kwargs)

View File

@ -8,6 +8,7 @@ import enum
import hashlib import hashlib
import inspect import inspect
import json import json
import os
import textwrap import textwrap
import warnings import warnings
from collections.abc import Mapping from collections.abc import Mapping
@ -41,6 +42,7 @@ from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config from vllm.config.utils import ConfigType, config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config, ConfigFormat, get_config, get_hf_image_processor_config,
@ -3509,16 +3511,33 @@ class VllmConfig:
disable_chunked_prefill_reasons: list[str] = [] disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.pooler_config: if self.model_config:
pooling_type = self.model_config.pooler_config.pooling_type if self.model_config.pooler_config:
if pooling_type is None or pooling_type.lower() != "last": pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")
elif self.model_config.is_encoder_decoder:
self.scheduler_config.max_num_encoder_input_tokens = \
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
logger.debug(
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)",
self.scheduler_config.max_num_encoder_input_tokens)
self.scheduler_config.disable_chunked_mm_input = True
disable_chunked_prefill_reasons.append( disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked " "Encoder-decoder models do not support chunked prefill nor"
"prefill and prefix caching; disabling both.") " prefix caching; disabling both.")
elif not getattr(self.model_config.hf_config, "is_causal", True): if (self.model_config.architecture
disable_chunked_prefill_reasons.append( == "WhisperForConditionalGeneration"
"Only models using causal attention supports chunked " and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
"prefill and prefix caching; disabling both.") != "spawn"):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'.")
if disable_chunked_prefill_reasons: if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons: for reason in disable_chunked_prefill_reasons:

View File

@ -600,7 +600,6 @@ class VoxtralEncoderModel(nn.Module):
self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config, self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config,
prefix=maybe_prefix( prefix=maybe_prefix(
prefix, "whisper_encoder"), prefix, "whisper_encoder"),
is_standalone_encoder=True,
init_in_fp32=True) init_in_fp32=True)
mel_filters = mel_filter_bank( mel_filters = mel_filter_bank(
num_frequency_bins=1 + self.config.window_size // 2, num_frequency_bins=1 + self.config.window_size // 2,

View File

@ -15,6 +15,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionType from vllm.attention import Attention, AttentionType
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
VllmConfig) VllmConfig)
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
@ -43,7 +44,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription, SupportsV0Only) SupportsTranscription)
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
make_layers) make_layers)
@ -124,6 +125,34 @@ class WhisperAudioInputs(TensorSchema):
TensorShape("b", "nmb", "t")] TensorShape("b", "nmb", "t")]
class WhisperEncoderAttention(MultiHeadAttention):
"""Multi-headed attention for Whisper encoder with 2D tensor support."""
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""
Input shape: batch_size x seq_len x hidden_size
or seq_len x hidden_size
"""
is_2d = query.dim() == 2
if is_2d:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
# Call the parent forward method
out = super().forward(query, key, value)
if is_2d:
out = out.squeeze(0)
return out
class WhisperPositionalEmbedding(nn.Embedding): class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int): def __init__(self, num_positions: int, embedding_dim: int):
@ -144,7 +173,6 @@ class WhisperAttention(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
standalone_encoder: bool = False,
): ):
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
@ -180,14 +208,25 @@ class WhisperAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
) )
if standalone_encoder: if attn_type == AttentionType.ENCODER:
self.attn = MultiHeadAttention( self.attn = WhisperEncoderAttention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
) )
else: elif self.attn_type == AttentionType.ENCODER_DECODER:
self.attn = CrossAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=self.attn_type,
)
else: # AttentionType.DECODER (regular decoder self-attention)
self.attn = Attention( self.attn = Attention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
@ -332,11 +371,7 @@ class WhisperMLP(nn.Module):
class WhisperEncoderLayer(nn.Module): class WhisperEncoderLayer(nn.Module):
def __init__(self, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
*,
vllm_config: VllmConfig,
prefix: str = "",
is_standalone_encoder: bool = False):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
@ -350,7 +385,6 @@ class WhisperEncoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
standalone_encoder=is_standalone_encoder,
) )
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.mlp = WhisperMLP( self.mlp = WhisperMLP(
@ -446,12 +480,10 @@ class WhisperEncoder(nn.Module):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
is_standalone_encoder: bool = False,
init_in_fp32: bool = False): init_in_fp32: bool = False):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
embed_dim = config.d_model embed_dim = config.d_model
self.is_standalone_encoder = is_standalone_encoder
self.num_mel_bins = config.num_mel_bins self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(embed_dim) self.embed_scale = (math.sqrt(embed_dim)
@ -469,9 +501,7 @@ class WhisperEncoder(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers, config.encoder_layers,
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers"),
is_standalone_encoder=
is_standalone_encoder),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.layer_norm = nn.LayerNorm(config.d_model) self.layer_norm = nn.LayerNorm(config.d_model)
@ -752,7 +782,7 @@ class WhisperMultiModalProcessor(
info=WhisperProcessingInfo, info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder) dummy_inputs=WhisperDummyInputsBuilder)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal, SupportsV0Only): SupportsMultiModal):
packed_modules_mapping = { packed_modules_mapping = {
"self_attn.qkv_proj": [ "self_attn.qkv_proj": [
"self_attn.q_proj", "self_attn.q_proj",
@ -880,19 +910,17 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def get_multimodal_embeddings(self, def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings: **kwargs: object) -> MultiModalEmbeddings:
# TODO: This method does not obey the interface for SupportsMultiModal. # Required as part of SupportsMultiModal interface.
# Refactor this once encoder/decoder support is implemented in V1.
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
return self.model.get_encoder_outputs(audio_input["input_features"]) return [self.model.get_encoder_outputs(audio_input["input_features"])]
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO: This method just returns the decoder sequence embeddings since # This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. Refactor this once # Whisper does not have encoder text tokens.
# encoder/decoder support is implemented in V1.
return self.model.decoder.get_input_embeddings(input_ids) return self.model.decoder.get_input_embeddings(input_ids)
def _parse_and_validate_audio_input( def _parse_and_validate_audio_input(

View File

@ -157,6 +157,7 @@ def _remap_mistral_audio_args(config: dict) -> dict:
encoder_attention_heads=encoder_args["n_heads"], encoder_attention_heads=encoder_args["n_heads"],
vocab_size=encoder_args["vocab_size"], vocab_size=encoder_args["vocab_size"],
max_source_positions=encoder_args["max_source_positions"], max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default
) )
} }
if quant_config: if quant_config:

View File

@ -317,8 +317,8 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device) -> None: vllm_config: VllmConfig, device: torch.device) -> None:
self.kv_cache_spec = kv_cache_spec super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
# For reorder # For reorder

View File

@ -177,12 +177,11 @@ class FlashAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads( self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config) self.parallel_config)

View File

@ -163,11 +163,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
self.device = device super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.kv_cache_spec = kv_cache_spec
self._workspace_buffer = None self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode (general shape) self._decode_wrapper = None # Wrapper for decode (general shape)

View File

@ -516,10 +516,11 @@ class FlexAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads( self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config) self.parallel_config)

View File

@ -39,8 +39,8 @@ class LinearAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec) assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
def build(self, def build(self,
common_prefix_len: int, common_prefix_len: int,

View File

@ -22,12 +22,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec) super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.kv_cache_spec = kv_cache_spec
self.device = device
self.vllm_config = vllm_config
self.layer_names = layer_names
assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.decode_cudagraph_max_bs = min( self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs, self.vllm_config.scheduler_config.max_num_seqs,
@ -52,4 +49,4 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
m.max_query_len = 1 # decode-only m.max_query_len = 1 # decode-only
return self.build(0, m) return self.build(0, m)

View File

@ -236,11 +236,11 @@ class AiterFlashAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads( self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config) self.parallel_config)
@ -248,7 +248,6 @@ class AiterFlashAttentionMetadataBuilder(
self.parallel_config) self.parallel_config)
self.headdim = self.model_config.get_head_size() self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
# Sliding window size to be used with the AOT scheduler will be # Sliding window size to be used with the AOT scheduler will be
# populated on first build() call. # populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None self.aot_sliding_window: Optional[tuple[int, int]] = None

View File

@ -45,8 +45,8 @@ class ShortConvAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec) assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
def build(self, def build(self,
common_prefix_len: int, common_prefix_len: int,

View File

@ -165,7 +165,8 @@ class TreeAttentionMetadataBuilder(
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
): ):
self.kv_cache_spec = kv_cache_spec super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
spec_config = vllm_config.speculative_config spec_config = vllm_config.speculative_config

View File

@ -66,9 +66,9 @@ class TritonAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
self.device = device super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
model_config = vllm_config.model_config model_config = vllm_config.model_config
self.num_heads_q = model_config.get_num_attention_heads( self.num_heads_q = model_config.get_num_attention_heads(

View File

@ -72,6 +72,9 @@ class CommonAttentionMetadata:
logits_indices_padded: Optional[torch.Tensor] = None logits_indices_padded: Optional[torch.Tensor] = None
num_logits_indices: Optional[int] = None num_logits_indices: Optional[int] = None
# Needed by CrossAttentionBuilder
encoder_seq_lens: Optional[np.ndarray] = None
@dataclass @dataclass
class UbatchSlice: class UbatchSlice:
@ -193,6 +196,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.layer_names = layer_names
self.vllm_config = vllm_config
self.device = device
@abstractmethod @abstractmethod
def build(self, def build(self,

View File

@ -206,8 +206,9 @@ class XFormersAttentionMetadataBuilder(
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
): ):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert XFORMERS_AVAILABLE assert XFORMERS_AVAILABLE
self.kv_cache_spec = kv_cache_spec
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
self._num_decodes = 0 self._num_decodes = 0
self._num_decode_tokens = 0 self._num_decode_tokens = 0

View File

@ -144,8 +144,8 @@ class Scheduler(SchedulerInterface):
) )
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also # projector if needed) for MM models as well as encoder-decoder
# has the Transformer architecture (e.g., ViT). # transformers.
self.max_num_encoder_input_tokens = encoder_compute_budget self.max_num_encoder_input_tokens = encoder_compute_budget
# NOTE: For the models without encoder (e.g., text-only models), # NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0 # the encoder cache will not be initialized because cache size is 0
@ -775,15 +775,19 @@ class Scheduler(SchedulerInterface):
# in the decoder's KV cache. # in the decoder's KV cache.
continue continue
# The same encoder input has already been scheduled in the current if not self.is_encoder_decoder:
# step. # We are not using the encoder cache for encoder-decoder models,
if request.mm_hashes[i] in mm_hashes_to_schedule: # yet.
continue if request.mm_hashes[i] in mm_hashes_to_schedule:
# The same encoder input has already been scheduled in the
# current step.
continue
if self.encoder_cache_manager.check_and_update_cache(request, i): if self.encoder_cache_manager.check_and_update_cache(
# The encoder input is already computed and cached from a request, i):
# previous step. # The encoder input is already computed and cached from a
continue # previous step.
continue
# If no encoder input chunking is allowed, we do not want to # If no encoder input chunking is allowed, we do not want to
# partially schedule a multimodal item. If the scheduled range would # partially schedule a multimodal item. If the scheduled range would
@ -1047,7 +1051,13 @@ class Scheduler(SchedulerInterface):
mm_positions = request.mm_positions[input_id] mm_positions = request.mm_positions[input_id]
start_pos = mm_positions.offset start_pos = mm_positions.offset
num_tokens = mm_positions.length num_tokens = mm_positions.length
if start_pos + num_tokens <= request.num_computed_tokens: if self.is_encoder_decoder and request.num_computed_tokens > 0:
# With Whisper, as soon as we've generated a single token,
# we know we're done with the encoder input. Cross Attention
# KVs have been calculated and cached already.
self.encoder_cache_manager.free_encoder_input(
request, input_id)
elif start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored # The encoder output is already processed and stored
# in the decoder's KV cache. # in the decoder's KV cache.
self.encoder_cache_manager.free_encoder_input( self.encoder_cache_manager.free_encoder_input(

View File

@ -325,7 +325,6 @@ class Processor:
) -> tuple[Optional[str], EngineCoreRequest]: ) -> tuple[Optional[str], EngineCoreRequest]:
# TODO(woosuk): Support pooling models. # TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
self._validate_lora(lora_request) self._validate_lora(lora_request)
self._validate_params(params, lora_request) self._validate_params(params, lora_request)
if trace_headers is not None: if trace_headers is not None:
@ -384,10 +383,6 @@ class Processor:
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
# TODO: Impl encoder-decoder
if encoder_inputs is not None:
raise NotImplementedError
sampling_params = None sampling_params = None
pooling_params = None pooling_params = None
if isinstance(params, SamplingParams): if isinstance(params, SamplingParams):

View File

@ -61,12 +61,16 @@ from vllm.v1.attention.backends.utils import (
create_fast_prefill_custom_backend, create_fast_prefill_custom_backend,
reorder_batch_to_split_decodes_and_prefills) reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
from vllm.v1.kv_cache_interface import (AttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec, ChunkedLocalAttentionSpec,
CrossAttentionSpec,
EncoderOnlyAttentionSpec, EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig, FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec, KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec) MambaSpec, SlidingWindowSpec)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsLists, LogprobsTensors, DraftTokenIds, LogprobsLists, LogprobsTensors,
ModelRunnerOutput, SamplerOutput) ModelRunnerOutput, SamplerOutput)
@ -208,6 +212,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config) model_config)
if self.model_config.is_encoder_decoder:
# Maximum length of the encoder input, only for encoder-decoder
# models.
self.max_encoder_len = self.mm_registry.\
get_encdec_max_encoder_len(model_config)
else:
self.max_encoder_len = 0
# Sampler # Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@ -265,7 +277,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# the block_sizes in the kv cache config. # the block_sizes in the kv cache config.
self.input_batch = InputBatch( self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len, # We need to use the encoder length for encoder-decoer
# because of KV cache for cross-attention.
max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens, max_num_batched_tokens=self.max_num_tokens,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
@ -798,6 +812,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
src=self.input_batch.prev_sampled_token_ids[ src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0]) prev_common_req_indices_tensor, 0])
def _get_encoder_seq_lens(
self,
scheduler_output: "SchedulerOutput",
kv_cache_spec: KVCacheSpec,
num_reqs: int,
) -> Optional[np.ndarray]:
if not isinstance(kv_cache_spec, CrossAttentionSpec):
return None
# Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch
encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
for req_id in scheduler_output.scheduled_encoder_inputs:
req_index = self.input_batch.req_id_to_index[req_id]
encoder_seq_lens[req_index] = self.max_encoder_len
return encoder_seq_lens
def _prepare_inputs( def _prepare_inputs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
@ -937,6 +969,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same group share the same metadata. # in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate( for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups): self.kv_cache_config.kv_cache_groups):
encoder_seq_lens = self._get_encoder_seq_lens(
scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs)
if isinstance(kv_cache_group_spec.kv_cache_spec, if isinstance(kv_cache_group_spec.kv_cache_spec,
EncoderOnlyAttentionSpec): EncoderOnlyAttentionSpec):
@ -981,6 +1015,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits_indices_padded=logits_indices_padded, logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0), num_logits_indices=logits_indices.size(0),
causal=True, causal=True,
encoder_seq_lens=encoder_seq_lens,
) )
if self.speculative_config and \ if self.speculative_config and \
@ -1253,10 +1288,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded])
return logits_indices_padded return logits_indices_padded
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): def _batch_mm_kwargs_from_scheduler(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
"""Batch multimodal kwargs from scheduled encoder inputs.
Args:
scheduler_output: The scheduler output containing scheduled encoder
inputs.
Returns:
A tuple of (mm_kwargs, req_ids_pos) where:
- mm_kwargs: List of multimodal kwargs items to be batched
- mm_hashes_pos: List of (mm_hash, position_info) tuples
"""
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs: if not scheduled_encoder_inputs:
return return [], []
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]() mm_kwargs = list[MultiModalKwargsItem]()
# list of tuple (mm_hash, position_info) # list of tuple (mm_hash, position_info)
@ -1270,6 +1319,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_hashes_pos.append( mm_hashes_pos.append(
(mm_hash, req_state.mm_positions[mm_input_id])) (mm_hash, req_state.mm_positions[mm_input_id]))
return mm_kwargs, mm_hashes_pos
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
# Batch the multi-modal inputs using the helper method.
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
scheduler_output)
if not mm_kwargs:
return
# Batch mm inputs as much as we can: if a request in the batch has # Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one, # multiple modalities or a different modality than the previous one,
# we process it separately to preserve item order. # we process it separately to preserve item order.
@ -1360,6 +1419,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds.append(mm_embeds_item) mm_embeds.append(mm_embeds_item)
return mm_embeds return mm_embeds
def _extract_encoder_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> dict[str, torch.Tensor]:
"""Extract encoder inputs for encoder-decoder models.
This method extracts multimodal input features from scheduled encoder
inputs and formats them for the encoder-decoder model forward pass.
"""
# Batch the multi-modal inputs using the helper method.
mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
if not mm_kwargs:
return {}
# Group MM kwargs by modality and extract features
encoder_features = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
# Add the grouped features to encoder_features dict
# This allows the model to receive them as kwargs (e.g.,
# input_features=...)
encoder_features.update(mm_kwargs_group)
return encoder_features
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
# get raw model out of the cudagraph wrapper. # get raw model out of the cudagraph wrapper.
if isinstance(self.model, CUDAGraphWrapper): if isinstance(self.model, CUDAGraphWrapper):
@ -1631,7 +1719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# _prepare_inputs may reorder the batch, so we must gather multi # _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order # modal outputs after that to ensure the correct order
if self.supports_mm_inputs and get_pp_group().is_first_rank: if (self.supports_mm_inputs and get_pp_group().is_first_rank
and not self.model_config.is_encoder_decoder):
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output) self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output)
@ -1673,6 +1762,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True) num_input_tokens, intermediate_tensors, True)
if (self.model_config.is_encoder_decoder
and scheduler_output.scheduled_encoder_inputs):
encoder_inputs = self._extract_encoder_inputs(scheduler_output)
model_kwargs.update(encoder_inputs)
return ( return (
num_scheduled_tokens, num_scheduled_tokens,
num_input_tokens, num_input_tokens,
@ -2591,17 +2685,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_dummy_run_with_lora(self.lora_config, with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens, remove_lora): num_scheduled_tokens, remove_lora):
if self.supports_mm_inputs: model_kwargs = self._init_model_kwargs(num_tokens)
if (self.supports_mm_inputs
and not self.model_config.is_encoder_decoder):
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens] inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
model_kwargs = { model_kwargs = {
**self._init_model_kwargs(num_tokens), **model_kwargs,
**self._dummy_mm_kwargs(num_reqs), **self._dummy_mm_kwargs(num_reqs),
} }
else: else:
input_ids = self.input_ids.gpu[:num_tokens] input_ids = self.input_ids.gpu[:num_tokens]
inputs_embeds = None inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_tokens)
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_tokens] positions = self.mrope_positions.gpu[:, :num_tokens]
@ -2823,7 +2918,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_budget = self.mm_budget mm_budget = self.mm_budget
assert mm_budget is not None assert mm_budget is not None
# TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0: if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text # NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when # modality with the max possible input tokens even when
@ -3170,7 +3264,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"for more details.") "for more details.")
self.input_batch = InputBatch( self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len, max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens, max_num_batched_tokens=self.max_num_tokens,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
@ -3443,7 +3537,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items(): for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY: if attn_module.attn_type == AttentionType.ENCODER_ONLY:
attn_spec = EncoderOnlyAttentionSpec( attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
@ -3485,7 +3579,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue continue
# TODO: Support other attention modules, e.g., cross-attention
# TODO(lucas): move the attention specs into the model layers like # TODO(lucas): move the attention specs into the model layers like
# the attention backends # the attention backends
if attn_module.attn_type == AttentionType.DECODER: if attn_module.attn_type == AttentionType.DECODER:
@ -3513,12 +3606,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
use_mla=use_mla) use_mla=use_mla)
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
kv_cache_spec[layer_name] = CrossAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
elif attn_module.attn_type in (AttentionType.ENCODER, elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY): AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache. # encoder-only attention does not need KV cache.
continue continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else: else:
raise ValueError( raise ValueError(
f"Unknown attention type: {attn_module.attn_type}") f"Unknown attention type: {attn_module.attn_type}")

View File

@ -12,6 +12,7 @@ from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec from vllm.v1.kv_cache_interface import KVCacheGroupSpec
@ -269,7 +270,17 @@ def bind_kv_cache(
# One typical case is encoder-decoder model, e.g., bart. # One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer # The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index. # has different layer_name but the same layer_index.
raise NotImplementedError
# TODO - analyze where runner_kv_caches is used and the right
# way to ensure it properly reflects multiple attention layers
# in the same decoder block.
if current_platform.is_cuda():
# We know that the GPU runner is not impacted by this
# case. Some test code depends on runner_kv_caches, but
# not in a way that's impacted by ignoring this.
pass
else:
raise NotImplementedError
layer_name = layer_names[0] layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name]) runner_kv_caches.append(kv_caches[layer_name])