mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-14 19:09:07 +08:00
[Core] Whisper Enable Encoder Batching (#29421)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
90d6cf921f
commit
0efd9f867c
@ -539,6 +539,11 @@ class ModelConfig:
|
|||||||
|
|
||||||
self.original_max_model_len = self.max_model_len
|
self.original_max_model_len = self.max_model_len
|
||||||
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
|
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
|
||||||
|
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
self.mm_processor_cache_gb = 0
|
||||||
|
logger.info("Encoder-decoder model detected, disabling mm processor cache.")
|
||||||
|
|
||||||
# Init multimodal config if needed
|
# Init multimodal config if needed
|
||||||
if self._model_info.supports_multimodal:
|
if self._model_info.supports_multimodal:
|
||||||
if (
|
if (
|
||||||
|
|||||||
@ -750,27 +750,17 @@ class VllmConfig:
|
|||||||
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
|
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
|
||||||
self._set_compile_ranges()
|
self._set_compile_ranges()
|
||||||
|
|
||||||
if self.model_config and self.model_config.is_encoder_decoder:
|
if (
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
self.model_config
|
||||||
|
and self.model_config.architecture == "WhisperForConditionalGeneration"
|
||||||
self.scheduler_config.max_num_encoder_input_tokens = (
|
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
||||||
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
|
):
|
||||||
|
logger.warning(
|
||||||
|
"Whisper is known to have issues with "
|
||||||
|
"forked workers. If startup is hanging, "
|
||||||
|
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||||
|
"to 'spawn'."
|
||||||
)
|
)
|
||||||
logger.debug(
|
|
||||||
"Encoder-decoder model detected: setting "
|
|
||||||
"`max_num_encoder_input_tokens` to encoder length (%s)",
|
|
||||||
self.scheduler_config.max_num_encoder_input_tokens,
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
self.model_config.architecture == "WhisperForConditionalGeneration"
|
|
||||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"Whisper is known to have issues with "
|
|
||||||
"forked workers. If startup is hanging, "
|
|
||||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
|
||||||
"to 'spawn'."
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.kv_events_config is not None
|
self.kv_events_config is not None
|
||||||
|
|||||||
@ -522,6 +522,7 @@ class WhisperEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
|
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
|
||||||
hidden_states = []
|
hidden_states = []
|
||||||
|
input_is_batched = False
|
||||||
for features in input_features:
|
for features in input_features:
|
||||||
embeds = nn.functional.gelu(self.conv1(features))
|
embeds = nn.functional.gelu(self.conv1(features))
|
||||||
embeds = nn.functional.gelu(self.conv2(embeds))
|
embeds = nn.functional.gelu(self.conv2(embeds))
|
||||||
@ -530,7 +531,13 @@ class WhisperEncoder(nn.Module):
|
|||||||
embeds.dtype
|
embeds.dtype
|
||||||
)
|
)
|
||||||
hidden_states.append(embeds)
|
hidden_states.append(embeds)
|
||||||
hidden_states = torch.cat(hidden_states)
|
input_is_batched = embeds.ndim > 2
|
||||||
|
# Input to MHA must be B x T x D
|
||||||
|
if input_is_batched:
|
||||||
|
# Models using WhisperEncoder may handle batching internally.
|
||||||
|
hidden_states = torch.cat(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.stack(hidden_states, dim=0)
|
||||||
|
|
||||||
for encoder_layer in self.layers:
|
for encoder_layer in self.layers:
|
||||||
hidden_states = encoder_layer(hidden_states)
|
hidden_states = encoder_layer(hidden_states)
|
||||||
@ -603,8 +610,7 @@ class WhisperModel(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
encoder_outputs: list[torch.Tensor],
|
encoder_outputs: list[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert len(encoder_outputs) in (0, 1)
|
enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None
|
||||||
enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None
|
|
||||||
decoder_outputs = self.decoder(
|
decoder_outputs = self.decoder(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@ -913,7 +919,10 @@ class WhisperForConditionalGeneration(
|
|||||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||||
# Required as part of SupportsMultiModal interface.
|
# Required as part of SupportsMultiModal interface.
|
||||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||||
return [self.model.get_encoder_outputs(audio_input["input_features"])]
|
# Split concatenated encoder outputs into one tensor per audio input
|
||||||
|
enc_output = self.model.get_encoder_outputs(audio_input["input_features"])
|
||||||
|
# The assumption is we can only process whole mm items (audios)
|
||||||
|
return enc_output.unbind(dim=0)
|
||||||
|
|
||||||
def embed_input_ids(
|
def embed_input_ids(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -341,3 +341,56 @@ def compute_mm_encoder_budget(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return encoder_compute_budget, encoder_cache_size
|
return encoder_compute_budget, encoder_cache_size
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE (NickLucche): Temporary implementation for encoder-decoder models that only
|
||||||
|
# use the manager for scheduling purposes. Encoder-decoder models will eventually
|
||||||
|
# utilize the cache and this class will fold into EncoderCacheManager, as
|
||||||
|
# differences with MM models shrink.
|
||||||
|
class EncoderDecoderCacheManager(EncoderCacheManager):
|
||||||
|
def __init__(self, cache_size: int):
|
||||||
|
self.cache_size = cache_size
|
||||||
|
self.num_free_slots = cache_size
|
||||||
|
self.freed: list[str] = []
|
||||||
|
|
||||||
|
def check_and_update_cache(self, request: Request, input_id: int) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def can_allocate(
|
||||||
|
self,
|
||||||
|
request: Request,
|
||||||
|
input_id: int,
|
||||||
|
encoder_compute_budget: int,
|
||||||
|
num_tokens_to_schedule: int,
|
||||||
|
) -> bool:
|
||||||
|
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||||
|
# Not enough compute budget
|
||||||
|
if num_tokens > encoder_compute_budget:
|
||||||
|
return False
|
||||||
|
|
||||||
|
num_tokens += num_tokens_to_schedule
|
||||||
|
# Enough free slots
|
||||||
|
return num_tokens <= self.num_free_slots
|
||||||
|
|
||||||
|
def allocate(self, request: Request, input_id: int) -> None:
|
||||||
|
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
||||||
|
self.num_free_slots -= num_encoder_tokens
|
||||||
|
|
||||||
|
mm_hash = request.mm_features[input_id].identifier
|
||||||
|
self.freed.append(mm_hash)
|
||||||
|
|
||||||
|
def free(self, request: Request) -> None:
|
||||||
|
for input_id in range(len(request.mm_features)):
|
||||||
|
self.free_encoder_input(request, input_id)
|
||||||
|
|
||||||
|
def get_cached_input_ids(self, request: Request) -> set[int]:
|
||||||
|
return set(range(len(request.mm_features)))
|
||||||
|
|
||||||
|
def get_freed_mm_hashes(self) -> list[str]:
|
||||||
|
freed = self.freed
|
||||||
|
self.freed = []
|
||||||
|
return freed
|
||||||
|
|
||||||
|
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
||||||
|
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||||
|
self.num_free_slots += num_tokens
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||||
from vllm.v1.core.encoder_cache_manager import (
|
from vllm.v1.core.encoder_cache_manager import (
|
||||||
EncoderCacheManager,
|
EncoderCacheManager,
|
||||||
|
EncoderDecoderCacheManager,
|
||||||
compute_encoder_budget,
|
compute_encoder_budget,
|
||||||
)
|
)
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
||||||
@ -181,7 +182,11 @@ class Scheduler(SchedulerInterface):
|
|||||||
# NOTE: For the models without encoder (e.g., text-only models),
|
# NOTE: For the models without encoder (e.g., text-only models),
|
||||||
# the encoder cache will not be initialized because cache size is 0
|
# the encoder cache will not be initialized because cache size is 0
|
||||||
# for these models.
|
# for these models.
|
||||||
self.encoder_cache_manager = EncoderCacheManager(cache_size=encoder_cache_size)
|
self.encoder_cache_manager = (
|
||||||
|
EncoderDecoderCacheManager(cache_size=encoder_cache_size)
|
||||||
|
if self.is_encoder_decoder
|
||||||
|
else EncoderCacheManager(cache_size=encoder_cache_size)
|
||||||
|
)
|
||||||
|
|
||||||
speculative_config = vllm_config.speculative_config
|
speculative_config = vllm_config.speculative_config
|
||||||
self.use_eagle = False
|
self.use_eagle = False
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user