diff --git a/vllm/config.py b/vllm/config.py
index 9938dcf07a7fc..cfd7b9e336704 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4958,3 +4958,34 @@ def get_layers_from_vllm_config(vllm_config: VllmConfig,
         vllm_config.compilation_config.static_forward_context.items()
         if isinstance(layer, layer_type)
     }
+
+
+@config
+@dataclass
+class SpeechToTextConfig:
+    """Configuration for speech-to-text models."""
+
+    sample_rate: float = 16_000
+    """Sample rate (Hz) to resample input audio to. Most speech models expect
+    16kHz audio input. The input audio will be automatically resampled to this
+    rate before processing."""
+
+    max_audio_clip_s: int = 30
+    """Maximum duration in seconds for a single audio clip without chunking.
+    Audio longer than this will be split into smaller chunks if
+    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
+
+    overlap_chunk_second: int = 1
+    """Overlap duration in seconds between consecutive audio chunks when
+    splitting long audio. This helps maintain context across chunk boundaries
+    and improves transcription quality at split points."""
+
+    min_energy_split_window_size: Optional[int] = 1600
+    """Window size in samples for finding low-energy (quiet) regions to split
+    audio chunks. The algorithm looks for the quietest moment within this
+    window to minimize cutting through speech. Default 1600 samples ≈ 100ms
+    at 16kHz. If None, no chunking will be done."""
+
+    @property
+    def allow_audio_chunking(self) -> bool:
+        return self.min_energy_split_window_size is not None
\ No newline at end of file
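
Illustrative sketch (not part of this diff): how the new `SpeechToTextConfig` is meant to be consumed once the patch is applied. Every field has a default, and audio chunking is governed entirely by whether `min_energy_split_window_size` is set.

# Illustrative only; assumes this patch is applied so the class is importable.
from vllm.config import SpeechToTextConfig

cfg = SpeechToTextConfig()          # Whisper-style defaults: 16 kHz, 30 s clips
assert cfg.allow_audio_chunking     # 1600-sample window => chunking enabled

no_chunk = SpeechToTextConfig(min_energy_split_window_size=None)
assert not no_chunk.allow_audio_chunking   # None disables splitting entirely
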
diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py
index 0ab029e5305bd..c70355b2ae43a 100644
--- a/vllm/entrypoints/openai/speech_to_text.py
+++ b/vllm/entrypoints/openai/speech_to_text.py
@@ -6,7 +6,6 @@ import math
 import time
 from collections.abc import AsyncGenerator
 from functools import cached_property
-from math import ceil
 from typing import Callable, Literal, Optional, TypeVar, Union, cast
 
 import numpy as np
@@ -28,7 +27,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_cls
 from vllm.model_executor.models import SupportsTranscription
 from vllm.outputs import RequestOutput
-from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import PlaceholderModule
 
 try:
@@ -44,9 +42,6 @@ logger = init_logger(__name__)
 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
-MAX_AUDIO_CLIP_SECONDS = 30
-OVERLAP_CHUNK_SECOND = 1
-MIN_ENERGY_WINDOW_SIZE = 1600  # 1600 ~ 100ms for 16000 Hz audio
 
 
 class OpenAISpeechToText(OpenAIServing):
@@ -71,36 +66,32 @@ class OpenAISpeechToText(OpenAIServing):
 
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
-        processor = cached_get_processor(model_config.model)
-        self.max_audio_clip_s = processor.feature_extractor.chunk_length \
-            if hasattr(processor.feature_extractor, 'chunk_length') \
-            else MAX_AUDIO_CLIP_SECONDS
-        self.model_sr = processor.feature_extractor.sampling_rate
-        self.hop_length = processor.feature_extractor.hop_length
         self.task_type = task_type
 
+        self.asr_config = self.model_cls.get_speech_to_text_config(
+            model_config, task_type)
+
         if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
                 self.default_sampling_params)
 
     @cached_property
-    def model_cls(self):
-        return get_model_cls(self.model_config)
+    def model_cls(self) -> type[SupportsTranscription]:
+        model_cls = get_model_cls(self.model_config)
+        return cast(type[SupportsTranscription], model_cls)
 
     async def _preprocess_speech_to_text(
         self,
         request: SpeechToTextRequest,
        audio_data: bytes,
     ) -> tuple[list[PromptType], float]:
-        model_cls = cast(SupportsTranscription, self.model_cls)
-
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
         lang = request.language or "en"
-        model_cls.validate_language(lang)
+        self.model_cls.validate_language(lang)
 
         if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
             raise ValueError("Maximum file size exceeded.")
@@ -108,26 +99,23 @@ class OpenAISpeechToText(OpenAIServing):
         with io.BytesIO(audio_data) as bytes_:
             # NOTE resample to model SR here for efficiency. This is also a
             # pre-requisite for chunking, as it assumes Whisper SR.
-            y, sr = librosa.load(bytes_, sr=self.model_sr)
+            y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
 
         duration = librosa.get_duration(y=y, sr=sr)
-        chunks = [y
-                  ] if duration < self.max_audio_clip_s else self._split_audio(
-                      y, int(sr))
+        do_split_audio = (self.asr_config.allow_audio_chunking
+                          and duration > self.asr_config.max_audio_clip_s)
+        chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
         prompts = []
         for chunk in chunks:
-            prompt = {
-                "encoder_prompt": {
-                    "prompt": "",
-                    "multi_modal_data": {
-                        "audio": (chunk, sr),
-                    },
-                },
-                "decoder_prompt":
-                model_cls.get_decoder_prompt(lang, self.task_type,
-                                             request.prompt)
-            }
-            prompts.append(cast(PromptType, prompt))
+            # The model has control over the construction, as long as it
+            # returns a valid PromptType.
+            prompt = self.model_cls.get_generation_prompt(
+                audio=chunk,
+                stt_config=self.asr_config,
+                language=lang,
+                task_type=self.task_type,
+                request_prompt=request.prompt)
+            prompts.append(prompt)
         return prompts, duration
 
     async def _create_speech_to_text(
@@ -196,7 +184,8 @@ class OpenAISpeechToText(OpenAIServing):
 
         self._log_inputs(
             request_id,
-            prompts[0]['decoder_prompt'],  # type: ignore
+            # It will not display special tokens like <|startoftranscript|>
+            request.prompt,
             params=sampling_params,
             lora_request=None,
             prompt_adapter_request=None)
@@ -261,17 +250,11 @@ class OpenAISpeechToText(OpenAIServing):
            async for res in result_generator:
                 # On first result.
                 if res.prompt_token_ids is not None:
-                    # Do not account the 4-tokens `<|startoftranscript|>..`
-                    # Could be negative when language token
-                    # is not specified.
-                    num_prompt_tokens = max(
-                        len(res.prompt_token_ids) - 4, 0)
-                    # NOTE(NickLucche) user can't pass encoder
-                    # prompts directly at least not to Whisper.
-                    # One indicator of the encoder amount of processing
-                    # is the log-mel spectogram length.
-                    num_prompt_tokens += ceil(
-                        audio_duration_s * self.model_sr / self.hop_length)
+                    num_prompt_tokens = len(res.prompt_token_ids)
+                    if audio_tokens := self.model_cls.get_num_audio_tokens(
+                            audio_duration_s, self.asr_config,
+                            self.model_config):
+                        num_prompt_tokens += audio_tokens
 
                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST
@@ -347,8 +330,8 @@ class OpenAISpeechToText(OpenAIServing):
 
     def _split_audio(self, audio_data: np.ndarray,
                      sample_rate: int) -> list[np.ndarray]:
-        chunk_size = sample_rate * self.max_audio_clip_s
-        overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
+        chunk_size = sample_rate * self.asr_config.max_audio_clip_s
+        overlap_size = sample_rate * self.asr_config.overlap_chunk_second
         chunks = []
         i = 0
         while i < audio_data.shape[-1]:
@@ -384,10 +367,10 @@ class OpenAISpeechToText(OpenAIServing):
         # Calculate RMS energy in small windows
         min_energy = math.inf
         quietest_idx = 0
-        for i in range(0,
-                       len(segment) - MIN_ENERGY_WINDOW_SIZE,
-                       MIN_ENERGY_WINDOW_SIZE):
-            window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
+        min_energy_window = self.asr_config.min_energy_split_window_size
+        assert min_energy_window is not None
+        for i in range(0, len(segment) - min_energy_window, min_energy_window):
+            window = segment[i:i + min_energy_window]
             energy = (window**2).mean()**0.5
             if energy < min_energy:
                 quietest_idx = i + start_idx
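
For intuition on the values the serving layer now reads from `asr_config`, here is the back-of-the-envelope arithmetic (illustrative, not part of this diff). The real `_split_audio` additionally searches for a low-energy split point, so actual chunk boundaries shift slightly.

# Rough sizing with the default SpeechToTextConfig values.
sample_rate = 16_000           # SpeechToTextConfig.sample_rate
max_audio_clip_s = 30          # SpeechToTextConfig.max_audio_clip_s
overlap_chunk_second = 1       # SpeechToTextConfig.overlap_chunk_second

chunk_size = sample_rate * max_audio_clip_s        # 480_000 samples per chunk
overlap_size = sample_rate * overlap_chunk_second  # 16_000 samples of overlap
min_energy_window = 1600                           # ~100 ms split-search window
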
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 503147367106c..99669a233634b 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -5,11 +5,14 @@ from collections.abc import Iterable, MutableSequence
 from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                     Union, overload, runtime_checkable)
 
+import numpy as np
 import torch
 from torch import Tensor
 from typing_extensions import Self, TypeIs
 
+from vllm.config import ModelConfig, SpeechToTextConfig
 from vllm.inputs import TokensPrompt
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -692,9 +695,13 @@ class SupportsTranscription(Protocol):
     supports_transcription: ClassVar[Literal[True]] = True
 
     @classmethod
-    def get_decoder_prompt(cls, language: str, task_type: str,
-                           prompt: str) -> str:
-        """Get the decoder prompt for the ASR model."""
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig, language: str,
+                              task_type: str,
+                              request_prompt: str) -> PromptType:
+        """Get the prompt for the ASR model.
+        The model has control over the construction, as long as it
+        returns a valid PromptType."""
         ...
 
     @classmethod
@@ -702,6 +709,25 @@ class SupportsTranscription(Protocol):
         """Check if the model supports a specific ISO639_1 language."""
         ...
 
+    @classmethod
+    def get_speech_to_text_config(
+            cls, model_config: ModelConfig,
+            task_type: Literal["transcribe",
+                               "translate"]) -> SpeechToTextConfig:
+        """Get the speech to text config for the ASR model."""
+        ...
+
+    @classmethod
+    def get_num_audio_tokens(cls, audio_duration_s: float,
+                             stt_config: SpeechToTextConfig,
+                             model_config: ModelConfig) -> Optional[int]:
+        """
+        Map from audio duration to number of audio tokens produced by the ASR
+        model, without running a forward pass.
+        This is used for estimating the amount of processing for this audio.
+        """
+        return None
+
 
 @overload
 def supports_transcription(
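
To make the widened protocol concrete, here is a hypothetical decoder-only ASR model wired up against the new classmethods. The class name, prompt format, and choices below are invented for illustration; only the method signatures come from this diff.

# Hypothetical model, illustrative only; not part of the patch.
from typing import ClassVar, Literal, Optional

import numpy as np

from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType


class MyDecoderOnlyASRModel:
    supports_transcription: ClassVar[Literal[True]] = True

    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig, language: str,
                              task_type: str,
                              request_prompt: str) -> PromptType:
        # A decoder-only model can return a plain multimodal text prompt
        # instead of Whisper's encoder/decoder dict.
        return {
            "prompt": f"<asr lang={language} task={task_type}>{request_prompt}",
            "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
        }

    @classmethod
    def get_speech_to_text_config(
            cls, model_config: ModelConfig,
            task_type: Literal["transcribe",
                               "translate"]) -> SpeechToTextConfig:
        # Disable chunking for this hypothetical model; over-long audio
        # is then rejected by the serving layer.
        return SpeechToTextConfig(min_energy_split_window_size=None)

    @classmethod
    def get_num_audio_tokens(cls, audio_duration_s: float,
                             stt_config: SpeechToTextConfig,
                             model_config: ModelConfig) -> Optional[int]:
        # No cheap estimate available; the server then only counts
        # prompt_token_ids (see speech_to_text.py above).
        return None
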
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index ee1cfd7d71374..1a7982e48e4b1 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -3,8 +3,9 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union, cast
 
+import numpy as np
 import torch
 from torch import nn
 from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
@@ -12,8 +13,10 @@ from transformers.models.whisper.modeling_whisper import sinusoids
 
 from vllm.attention import Attention, AttentionType
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
+                         VllmConfig)
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -33,6 +36,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
                                         EncDecMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.transformers_utils.processor import cached_get_processor
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription, SupportsV0Only)
@@ -785,11 +789,24 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                 f"or {list(ISO639_1_OTHER_LANGS.values())}")
 
     @classmethod
-    def get_decoder_prompt(cls, language: str, task_type: str,
-                           prompt: str) -> str:
-        return ((f"<|prev|>{prompt}" if prompt else "") +
-                f"<|startoftranscript|><|{language}|>" +
-                f"<|{task_type}|><|notimestamps|>")
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig, language: str,
+                              task_type: str,
+                              request_prompt: str) -> PromptType:
+        prompt = {
+            "encoder_prompt": {
+                # Whisper does not support encoder prompt.
+                "prompt": "",
+                "multi_modal_data": {
+                    "audio": (audio, stt_config.sample_rate),
+                },
+            },
+            "decoder_prompt":
+            ((f"<|prev|>{request_prompt}" if request_prompt else "") +
+             f"<|startoftranscript|><|{language}|>" +
+             f"<|{task_type}|><|notimestamps|>")
+        }
+        return cast(PromptType, prompt)
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
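
For reference, a worked example (not part of this diff) of what the decoder prompt built above resolves to for an English transcription request with no prior prompt:

# Mirrors the f-string in get_generation_prompt above; illustrative only.
language, task_type, request_prompt = "en", "transcribe", ""
decoder_prompt = ((f"<|prev|>{request_prompt}" if request_prompt else "") +
                  f"<|startoftranscript|><|{language}|>" +
                  f"<|{task_type}|><|notimestamps|>")
assert decoder_prompt == "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
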
@@ -798,6 +815,30 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
 
         raise ValueError("Only audio modality is supported")
 
+    @classmethod
+    def get_speech_to_text_config(cls, model_config: ModelConfig,
+                                  task_type: str) -> SpeechToTextConfig:
+        processor = cached_get_processor(model_config.model)
+
+        return SpeechToTextConfig(
+            max_audio_clip_s=processor.feature_extractor.chunk_length,
+            sample_rate=processor.feature_extractor.sampling_rate,
+        )
+
+    @classmethod
+    def get_num_audio_tokens(cls, audio_duration_s: float,
+                             stt_config: SpeechToTextConfig,
+                             model_config: ModelConfig) -> Optional[int]:
+        processor = cached_get_processor(model_config.model)
+        hop_length = processor.feature_extractor.hop_length
+        assert hop_length is not None
+        # NOTE(NickLucche) user can't pass encoder
+        # prompts directly, at least not to Whisper.
+        # One indicator of the encoder amount of processing
+        # is the log-mel spectrogram length.
+        return math.ceil(audio_duration_s * stt_config.sample_rate /
+                         hop_length)
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
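
Sanity check of the estimate above (illustrative, not a test in this patch): with the stock Whisper feature extractor settings, `sampling_rate=16000` and `hop_length=160`, a 30-second clip maps to 3000 log-mel frames.

import math

audio_duration_s = 30.0
sample_rate = 16_000   # Whisper feature extractor sampling_rate
hop_length = 160       # Whisper feature extractor hop_length

num_audio_tokens = math.ceil(audio_duration_s * sample_rate / hop_length)
assert num_audio_tokens == 3000   # ~100 frames per second of audio
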