mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 13:29:37 +08:00
[Core] Whisper Enable Encoder Batching (#29421)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
90d6cf921f
commit
0efd9f867c
@ -539,6 +539,11 @@ class ModelConfig:
|
||||
|
||||
self.original_max_model_len = self.max_model_len
|
||||
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.mm_processor_cache_gb = 0
|
||||
logger.info("Encoder-decoder model detected, disabling mm processor cache.")
|
||||
|
||||
# Init multimodal config if needed
|
||||
if self._model_info.supports_multimodal:
|
||||
if (
|
||||
|
||||
@ -750,27 +750,17 @@ class VllmConfig:
|
||||
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
|
||||
self._set_compile_ranges()
|
||||
|
||||
if self.model_config and self.model_config.is_encoder_decoder:
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
self.scheduler_config.max_num_encoder_input_tokens = (
|
||||
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
|
||||
if (
|
||||
self.model_config
|
||||
and self.model_config.architecture == "WhisperForConditionalGeneration"
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
||||
):
|
||||
logger.warning(
|
||||
"Whisper is known to have issues with "
|
||||
"forked workers. If startup is hanging, "
|
||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||
"to 'spawn'."
|
||||
)
|
||||
logger.debug(
|
||||
"Encoder-decoder model detected: setting "
|
||||
"`max_num_encoder_input_tokens` to encoder length (%s)",
|
||||
self.scheduler_config.max_num_encoder_input_tokens,
|
||||
)
|
||||
if (
|
||||
self.model_config.architecture == "WhisperForConditionalGeneration"
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
||||
):
|
||||
logger.warning(
|
||||
"Whisper is known to have issues with "
|
||||
"forked workers. If startup is hanging, "
|
||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||
"to 'spawn'."
|
||||
)
|
||||
|
||||
if (
|
||||
self.kv_events_config is not None
|
||||
|
||||
@ -522,6 +522,7 @@ class WhisperEncoder(nn.Module):
|
||||
|
||||
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
|
||||
hidden_states = []
|
||||
input_is_batched = False
|
||||
for features in input_features:
|
||||
embeds = nn.functional.gelu(self.conv1(features))
|
||||
embeds = nn.functional.gelu(self.conv2(embeds))
|
||||
@ -530,7 +531,13 @@ class WhisperEncoder(nn.Module):
|
||||
embeds.dtype
|
||||
)
|
||||
hidden_states.append(embeds)
|
||||
hidden_states = torch.cat(hidden_states)
|
||||
input_is_batched = embeds.ndim > 2
|
||||
# Input to MHA must be B x T x D
|
||||
if input_is_batched:
|
||||
# Models using WhisperEncoder may handle batching internally.
|
||||
hidden_states = torch.cat(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.stack(hidden_states, dim=0)
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
@ -603,8 +610,7 @@ class WhisperModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
encoder_outputs: list[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
assert len(encoder_outputs) in (0, 1)
|
||||
enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None
|
||||
enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
@ -913,7 +919,10 @@ class WhisperForConditionalGeneration(
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
# Required as part of SupportsMultiModal interface.
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
return [self.model.get_encoder_outputs(audio_input["input_features"])]
|
||||
# Split concatenated encoder outputs into one tensor per audio input
|
||||
enc_output = self.model.get_encoder_outputs(audio_input["input_features"])
|
||||
# The assumption is we can only process whole mm items (audios)
|
||||
return enc_output.unbind(dim=0)
|
||||
|
||||
def embed_input_ids(
|
||||
self,
|
||||
|
||||
@ -341,3 +341,56 @@ def compute_mm_encoder_budget(
|
||||
)
|
||||
|
||||
return encoder_compute_budget, encoder_cache_size
|
||||
|
||||
|
||||
# NOTE (NickLucche): Temporary implementation for encoder-decoder models that only
|
||||
# use the manager for scheduling purposes. Encoder-decoder models will eventually
|
||||
# utilize the cache and this class will fold into EncoderCacheManager, as
|
||||
# differences with MM models shrink.
|
||||
class EncoderDecoderCacheManager(EncoderCacheManager):
|
||||
def __init__(self, cache_size: int):
|
||||
self.cache_size = cache_size
|
||||
self.num_free_slots = cache_size
|
||||
self.freed: list[str] = []
|
||||
|
||||
def check_and_update_cache(self, request: Request, input_id: int) -> bool:
|
||||
return False
|
||||
|
||||
def can_allocate(
|
||||
self,
|
||||
request: Request,
|
||||
input_id: int,
|
||||
encoder_compute_budget: int,
|
||||
num_tokens_to_schedule: int,
|
||||
) -> bool:
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
# Not enough compute budget
|
||||
if num_tokens > encoder_compute_budget:
|
||||
return False
|
||||
|
||||
num_tokens += num_tokens_to_schedule
|
||||
# Enough free slots
|
||||
return num_tokens <= self.num_free_slots
|
||||
|
||||
def allocate(self, request: Request, input_id: int) -> None:
|
||||
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
||||
self.num_free_slots -= num_encoder_tokens
|
||||
|
||||
mm_hash = request.mm_features[input_id].identifier
|
||||
self.freed.append(mm_hash)
|
||||
|
||||
def free(self, request: Request) -> None:
|
||||
for input_id in range(len(request.mm_features)):
|
||||
self.free_encoder_input(request, input_id)
|
||||
|
||||
def get_cached_input_ids(self, request: Request) -> set[int]:
|
||||
return set(range(len(request.mm_features)))
|
||||
|
||||
def get_freed_mm_hashes(self) -> list[str]:
|
||||
freed = self.freed
|
||||
self.freed = []
|
||||
return freed
|
||||
|
||||
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
self.num_free_slots += num_tokens
|
||||
|
||||
@ -27,6 +27,7 @@ from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.v1.core.encoder_cache_manager import (
|
||||
EncoderCacheManager,
|
||||
EncoderDecoderCacheManager,
|
||||
compute_encoder_budget,
|
||||
)
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
||||
@ -181,7 +182,11 @@ class Scheduler(SchedulerInterface):
|
||||
# NOTE: For the models without encoder (e.g., text-only models),
|
||||
# the encoder cache will not be initialized because cache size is 0
|
||||
# for these models.
|
||||
self.encoder_cache_manager = EncoderCacheManager(cache_size=encoder_cache_size)
|
||||
self.encoder_cache_manager = (
|
||||
EncoderDecoderCacheManager(cache_size=encoder_cache_size)
|
||||
if self.is_encoder_decoder
|
||||
else EncoderCacheManager(cache_size=encoder_cache_size)
|
||||
)
|
||||
|
||||
speculative_config = vllm_config.speculative_config
|
||||
self.use_eagle = False
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user