[Core] Whisper Enable Encoder Batching (#29421)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-12-11 22:06:51 +01:00 committed by GitHub
parent 90d6cf921f
commit 0efd9f867c
5 changed files with 87 additions and 25 deletions


@@ -539,6 +539,11 @@ class ModelConfig:
         self.original_max_model_len = self.max_model_len
         self.max_model_len = self.get_and_verify_max_len(self.max_model_len)

+        if self.is_encoder_decoder:
+            self.mm_processor_cache_gb = 0
+            logger.info(
+                "Encoder-decoder model detected, disabling mm processor cache."
+            )
         # Init multimodal config if needed
         if self._model_info.supports_multimodal:
             if (


@@ -750,27 +750,17 @@ class VllmConfig:
         # TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
         self._set_compile_ranges()

-        if self.model_config and self.model_config.is_encoder_decoder:
-            from vllm.multimodal import MULTIMODAL_REGISTRY
-
-            self.scheduler_config.max_num_encoder_input_tokens = (
-                MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
-            )
-            logger.debug(
-                "Encoder-decoder model detected: setting "
-                "`max_num_encoder_input_tokens` to encoder length (%s)",
-                self.scheduler_config.max_num_encoder_input_tokens,
-            )
-            if (
-                self.model_config.architecture == "WhisperForConditionalGeneration"
-                and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
-            ):
-                logger.warning(
-                    "Whisper is known to have issues with "
-                    "forked workers. If startup is hanging, "
-                    "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
-                    "to 'spawn'."
-                )
+        if (
+            self.model_config
+            and self.model_config.architecture == "WhisperForConditionalGeneration"
+            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
+        ):
+            logger.warning(
+                "Whisper is known to have issues with "
+                "forked workers. If startup is hanging, "
+                "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
+                "to 'spawn'."
+            )

         if (
             self.kv_events_config is not None
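For context, a minimal sketch of acting on the warning above: force the 'spawn' worker start method before the engine (and its worker processes) is created. The checkpoint name and the bare `LLM(...)` call are illustrative assumptions, not part of this commit.

```python
import os

# Must be set before vLLM creates its workers, as the warning above recommends.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM  # noqa: E402  (imported after the env var is set)

# Hypothetical checkpoint; any WhisperForConditionalGeneration model applies.
llm = LLM(model="openai/whisper-large-v3")
```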


@@ -522,6 +522,7 @@ class WhisperEncoder(nn.Module):

     def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
         hidden_states = []
+        input_is_batched = False
         for features in input_features:
             embeds = nn.functional.gelu(self.conv1(features))
             embeds = nn.functional.gelu(self.conv2(embeds))
@@ -530,7 +531,13 @@
                 embeds.dtype
             )
             hidden_states.append(embeds)
-        hidden_states = torch.cat(hidden_states)
+            input_is_batched = embeds.ndim > 2
+        # Input to MHA must be B x T x D
+        if input_is_batched:
+            # Models using WhisperEncoder may handle batching internally.
+            hidden_states = torch.cat(hidden_states)
+        else:
+            hidden_states = torch.stack(hidden_states, dim=0)

         for encoder_layer in self.layers:
             hidden_states = encoder_layer(hidden_states)
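To make the shape handling above concrete, here is a standalone sketch of the same branch: per-audio `(T, D)` embeddings are stacked into a `(B, T, D)` batch for attention, while inputs that a caller has already batched as `(B, T, D)` are concatenated along the batch dimension. The tensor sizes are illustrative assumptions (roughly mirroring Whisper's 1500 encoder positions), not values from this commit.

```python
import torch


def to_mha_input(hidden_states: list[torch.Tensor], input_is_batched: bool) -> torch.Tensor:
    # Mirrors the branch above: attention expects B x T x D.
    if input_is_batched:
        return torch.cat(hidden_states)       # list of (B_i, T, D) -> (sum(B_i), T, D)
    return torch.stack(hidden_states, dim=0)  # list of (T, D)      -> (B, T, D)


per_audio = [torch.randn(1500, 1280) for _ in range(3)]  # one embedding per audio
pre_batched = [torch.randn(2, 1500, 1280)]               # caller batched internally

assert to_mha_input(per_audio, input_is_batched=False).shape == (3, 1500, 1280)
assert to_mha_input(pre_batched, input_is_batched=True).shape == (2, 1500, 1280)
```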
@@ -603,8 +610,7 @@ class WhisperModel(nn.Module):
         positions: torch.Tensor,
         encoder_outputs: list[torch.Tensor],
     ) -> torch.Tensor:
-        assert len(encoder_outputs) in (0, 1)
-        enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None
+        enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None
         decoder_outputs = self.decoder(
             input_ids=input_ids,
             positions=positions,
@@ -913,7 +919,10 @@ class WhisperForConditionalGeneration(
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         # Required as part of SupportsMultiModal interface.
         audio_input = self._parse_and_validate_audio_input(**kwargs)
-        return [self.model.get_encoder_outputs(audio_input["input_features"])]
+        # Split concatenated encoder outputs into one tensor per audio input
+        enc_output = self.model.get_encoder_outputs(audio_input["input_features"])
+        # The assumption is we can only process whole mm items (audios)
+        return enc_output.unbind(dim=0)

     def embed_input_ids(
         self,
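A small sketch of what `unbind(dim=0)` does to the batched encoder output: it yields one `(T, D)` view per multimodal item (audio), which is the per-item granularity the scheduler expects, while `WhisperModel.forward` above re-concatenates the scheduled items along dim 0 for cross-attention. Shapes are illustrative assumptions.

```python
import torch

enc_output = torch.randn(3, 1500, 1280)  # (num_audios, T, D) from the batched encoder
per_item = enc_output.unbind(dim=0)      # tuple of 3 tensors, each (1500, 1280)

assert len(per_item) == 3
assert all(t.shape == (1500, 1280) for t in per_item)

# Re-concatenating the scheduled items flattens them back into one token sequence,
# mirroring `torch.cat(encoder_outputs, dim=0)` in WhisperModel.forward.
assert torch.cat(per_item, dim=0).shape == (3 * 1500, 1280)
```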


@@ -341,3 +341,56 @@ def compute_mm_encoder_budget(
     )
     return encoder_compute_budget, encoder_cache_size
+
+
+# NOTE (NickLucche): Temporary implementation for encoder-decoder models that only
+# use the manager for scheduling purposes. Encoder-decoder models will eventually
+# utilize the cache and this class will fold into EncoderCacheManager, as
+# differences with MM models shrink.
+class EncoderDecoderCacheManager(EncoderCacheManager):
+    def __init__(self, cache_size: int):
+        self.cache_size = cache_size
+        self.num_free_slots = cache_size
+        self.freed: list[str] = []
+
+    def check_and_update_cache(self, request: Request, input_id: int) -> bool:
+        return False
+
+    def can_allocate(
+        self,
+        request: Request,
+        input_id: int,
+        encoder_compute_budget: int,
+        num_tokens_to_schedule: int,
+    ) -> bool:
+        num_tokens = request.get_num_encoder_tokens(input_id)
+        # Not enough compute budget
+        if num_tokens > encoder_compute_budget:
+            return False
+        num_tokens += num_tokens_to_schedule
+        # Enough free slots
+        return num_tokens <= self.num_free_slots
+
+    def allocate(self, request: Request, input_id: int) -> None:
+        num_encoder_tokens = request.get_num_encoder_tokens(input_id)
+        self.num_free_slots -= num_encoder_tokens
+        mm_hash = request.mm_features[input_id].identifier
+        self.freed.append(mm_hash)
+
+    def free(self, request: Request) -> None:
+        for input_id in range(len(request.mm_features)):
+            self.free_encoder_input(request, input_id)
+
+    def get_cached_input_ids(self, request: Request) -> set[int]:
+        return set(range(len(request.mm_features)))
+
+    def get_freed_mm_hashes(self) -> list[str]:
+        freed = self.freed
+        self.freed = []
+        return freed
+
+    def free_encoder_input(self, request: Request, input_id: int) -> None:
+        num_tokens = request.get_num_encoder_tokens(input_id)
+        self.num_free_slots += num_tokens
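A hedged usage sketch of the new manager's bookkeeping: it never reports cache hits, only tracks free slots, and queues every allocated mm hash for immediate freeing. The request object below is a duck-typed stand-in exposing just the attributes the manager touches; the identifier and token count are illustrative assumptions.

```python
from types import SimpleNamespace

from vllm.v1.core.encoder_cache_manager import EncoderDecoderCacheManager

# Stand-in for a vLLM Request carrying a single audio item of 1500 encoder tokens.
request = SimpleNamespace(
    mm_features=[SimpleNamespace(identifier="audio-0")],
    get_num_encoder_tokens=lambda input_id: 1500,
)

manager = EncoderDecoderCacheManager(cache_size=1500)
assert not manager.check_and_update_cache(request, 0)  # never a cache hit
assert manager.can_allocate(
    request, 0, encoder_compute_budget=1500, num_tokens_to_schedule=0
)
manager.allocate(request, 0)                           # consumes all 1500 slots
assert manager.get_freed_mm_hashes() == ["audio-0"]    # hash is freed right away
manager.free_encoder_input(request, 0)                 # slots returned after encoding
assert manager.num_free_slots == 1500
```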


@@ -27,6 +27,7 @@ from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (
     EncoderCacheManager,
+    EncoderDecoderCacheManager,
     compute_encoder_budget,
 )
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
@@ -181,7 +182,11 @@ class Scheduler(SchedulerInterface):
         # NOTE: For the models without encoder (e.g., text-only models),
         # the encoder cache will not be initialized because cache size is 0
         # for these models.
-        self.encoder_cache_manager = EncoderCacheManager(cache_size=encoder_cache_size)
+        self.encoder_cache_manager = (
+            EncoderDecoderCacheManager(cache_size=encoder_cache_size)
+            if self.is_encoder_decoder
+            else EncoderCacheManager(cache_size=encoder_cache_size)
+        )

         speculative_config = vllm_config.speculative_config
         self.use_eagle = False