[Frontend] Abstract prompt and SpeechToTextConfig for transcription models (#20637)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-07-12 06:33:26 +02:00 committed by GitHub
parent 890323dc1b
commit 3c7d942da8
4 changed files with 141 additions and 60 deletions

View File

@@ -4958,3 +4958,34 @@ def get_layers_from_vllm_config(vllm_config: VllmConfig,
vllm_config.compilation_config.static_forward_context.items()
if isinstance(layer, layer_type)
}
@config
@dataclass
class SpeechToTextConfig:
"""Configuration for speech-to-text models."""
sample_rate: float = 16_000
"""Sample rate (Hz) to resample input audio to. Most speech models expect
16kHz audio input. The input audio will be automatically resampled to this
rate before processing."""
max_audio_clip_s: int = 30
"""Maximum duration in seconds for a single audio clip without chunking.
Audio longer than this will be split into smaller chunks if
`allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
overlap_chunk_second: int = 1
"""Overlap duration in seconds between consecutive audio chunks when
splitting long audio. This helps maintain context across chunk boundaries
and improves transcription quality at split points."""
min_energy_split_window_size: Optional[int] = 1600
"""Window size in samples for finding low-energy (quiet) regions to split
audio chunks. The algorithm looks for the quietest moment within this
window to minimize cutting through speech. Default 1600 samples ≈ 100ms
at 16kHz. If None, no chunking will be done."""
@property
def allow_audio_chunking(self) -> bool:
return self.min_energy_split_window_size is not None
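For context, a minimal sketch of how this config behaves with its defaults (the import path is the one used by the interface changes below; the values in the comments are the field defaults above):

from vllm.config import SpeechToTextConfig

cfg = SpeechToTextConfig()       # 16 kHz, 30 s clips, 1 s overlap, 1600-sample window
assert cfg.allow_audio_chunking  # window size is set, so long clips get split

no_chunk = SpeechToTextConfig(min_energy_split_window_size=None)
assert not no_chunk.allow_audio_chunking  # clips over max_audio_clip_s are rejected instead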

View File

@@ -6,7 +6,6 @@ import math
import time
from collections.abc import AsyncGenerator
from functools import cached_property
from math import ceil
from typing import Callable, Literal, Optional, TypeVar, Union, cast
import numpy as np
@@ -28,7 +27,6 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_cls
from vllm.model_executor.models import SupportsTranscription
from vllm.outputs import RequestOutput
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import PlaceholderModule
try:
@@ -44,9 +42,6 @@ logger = init_logger(__name__)
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB = 25
MAX_AUDIO_CLIP_SECONDS = 30
OVERLAP_CHUNK_SECOND = 1
MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio
class OpenAISpeechToText(OpenAIServing):
@@ -71,36 +66,32 @@ class OpenAISpeechToText(OpenAIServing):
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
processor = cached_get_processor(model_config.model)
self.max_audio_clip_s = processor.feature_extractor.chunk_length \
if hasattr(processor.feature_extractor, 'chunk_length') \
else MAX_AUDIO_CLIP_SECONDS
self.model_sr = processor.feature_extractor.sampling_rate
self.hop_length = processor.feature_extractor.hop_length
self.task_type = task_type
self.asr_config = self.model_cls.get_speech_to_text_config(
model_config, task_type)
if self.default_sampling_params:
logger.info(
"Overwriting default completion sampling param with: %s",
self.default_sampling_params)
@cached_property
def model_cls(self):
return get_model_cls(self.model_config)
def model_cls(self) -> type[SupportsTranscription]:
model_cls = get_model_cls(self.model_config)
return cast(type[SupportsTranscription], model_cls)
async def _preprocess_speech_to_text(
self,
request: SpeechToTextRequest,
audio_data: bytes,
) -> tuple[list[PromptType], float]:
model_cls = cast(SupportsTranscription, self.model_cls)
# Validate request
# TODO language should be optional and can be guessed.
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
lang = request.language or "en"
model_cls.validate_language(lang)
self.model_cls.validate_language(lang)
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
raise ValueError("Maximum file size exceeded.")
@@ -108,26 +99,23 @@ class OpenAISpeechToText(OpenAIServing):
with io.BytesIO(audio_data) as bytes_:
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = librosa.load(bytes_, sr=self.model_sr)
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
duration = librosa.get_duration(y=y, sr=sr)
chunks = [y
] if duration < self.max_audio_clip_s else self._split_audio(
y, int(sr))
do_split_audio = (self.asr_config.allow_audio_chunking
and duration > self.asr_config.max_audio_clip_s)
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
prompts = []
for chunk in chunks:
prompt = {
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": (chunk, sr),
},
},
"decoder_prompt":
model_cls.get_decoder_prompt(lang, self.task_type,
request.prompt)
}
prompts.append(cast(PromptType, prompt))
# The model has control over the construction, as long as it
# returns a valid PromptType.
prompt = self.model_cls.get_generation_prompt(
audio=chunk,
stt_config=self.asr_config,
language=lang,
task_type=self.task_type,
request_prompt=request.prompt)
prompts.append(prompt)
return prompts, duration
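As a worked example of the chunking gate above (the 45-second duration is just an illustrative input):

from vllm.config import SpeechToTextConfig

asr_config = SpeechToTextConfig()  # defaults: 16 kHz, 30 s max clip, chunking enabled
duration = 45.0                    # example clip length in seconds
do_split_audio = (asr_config.allow_audio_chunking
                  and duration > asr_config.max_audio_clip_s)
assert do_split_audio  # 45 s > 30 s, so the clip is split and one prompt is
                       # built per chunk via the model's get_generation_prompt()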
async def _create_speech_to_text(
@@ -196,7 +184,8 @@ class OpenAISpeechToText(OpenAIServing):
self._log_inputs(
request_id,
prompts[0]['decoder_prompt'], # type: ignore
# It will not display special tokens like <|startoftranscript|>
request.prompt,
params=sampling_params,
lora_request=None,
prompt_adapter_request=None)
@@ -261,17 +250,11 @@ class OpenAISpeechToText(OpenAIServing):
async for res in result_generator:
# On first result.
if res.prompt_token_ids is not None:
# Do not account the 4-tokens `<|startoftranscript|>..`
# Could be negative when language token
# is not specified.
num_prompt_tokens = max(
len(res.prompt_token_ids) - 4, 0)
# NOTE(NickLucche) user can't pass encoder
# prompts directly at least not to Whisper.
# One indicator of the encoder amount of processing
# is the log-mel spectogram length.
num_prompt_tokens += ceil(
audio_duration_s * self.model_sr / self.hop_length)
num_prompt_tokens = len(res.prompt_token_ids)
if audio_tokens := self.model_cls.get_num_audio_tokens(
audio_duration_s, self.asr_config,
self.model_config):
num_prompt_tokens += audio_tokens
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
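The usage accounting above now adds the model's own audio-token estimate to the tokenized prompt length instead of hard-coding Whisper's 4-token prefix and mel-frame math in the serving layer. Roughly, with illustrative numbers for a Whisper-style 30 s clip (see get_num_audio_tokens in the Whisper changes below):

num_text_tokens = 4      # e.g. <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
audio_tokens = 3_000     # the model's estimate for ~30 s of audio
num_prompt_tokens = num_text_tokens + audio_tokens  # 3_004 reported in the usage stats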
@@ -347,8 +330,8 @@ class OpenAISpeechToText(OpenAIServing):
def _split_audio(self, audio_data: np.ndarray,
sample_rate: int) -> list[np.ndarray]:
chunk_size = sample_rate * self.max_audio_clip_s
overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
chunk_size = sample_rate * self.asr_config.max_audio_clip_s
overlap_size = sample_rate * self.asr_config.overlap_chunk_second
chunks = []
i = 0
while i < audio_data.shape[-1]:
@@ -384,10 +367,10 @@ class OpenAISpeechToText(OpenAIServing):
# Calculate RMS energy in small windows
min_energy = math.inf
quietest_idx = 0
for i in range(0,
len(segment) - MIN_ENERGY_WINDOW_SIZE,
MIN_ENERGY_WINDOW_SIZE):
window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
min_energy_window = self.asr_config.min_energy_split_window_size
assert min_energy_window is not None
for i in range(0, len(segment) - min_energy_window, min_energy_window):
window = segment[i:i + min_energy_window]
energy = (window**2).mean()**0.5
if energy < min_energy:
quietest_idx = i + start_idx
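With the default SpeechToTextConfig the sizes involved in splitting work out as follows (a quick arithmetic sketch, not new behaviour):

sample_rate = 16_000
chunk_size = sample_rate * 30   # max_audio_clip_s -> 480_000 samples per chunk
overlap_size = sample_rate * 1  # overlap_chunk_second -> 16_000 samples of overlap
min_energy_window = 1_600       # min_energy_split_window_size -> ~100 ms windows
# The loop above scans the candidate segment in 1_600-sample windows and cuts at
# the quietest one, so chunk boundaries tend to land in pauses rather than words.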

View File

@@ -5,11 +5,14 @@ from collections.abc import Iterable, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable)
import numpy as np
import torch
from torch import Tensor
from typing_extensions import Self, TypeIs
from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@@ -692,9 +695,13 @@ class SupportsTranscription(Protocol):
supports_transcription: ClassVar[Literal[True]] = True
@classmethod
def get_decoder_prompt(cls, language: str, task_type: str,
prompt: str) -> str:
"""Get the decoder prompt for the ASR model."""
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig, language: str,
task_type: str,
request_prompt: str) -> PromptType:
"""Get the prompt for the ASR model.
The model has control over the construction, as long as it
returns a valid PromptType."""
...
@classmethod
@@ -702,6 +709,25 @@ class SupportsTranscription(Protocol):
"""Check if the model supports a specific ISO639_1 language."""
...
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig,
task_type: Literal["transcribe",
"translate"]) -> SpeechToTextConfig:
"""Get the speech to text config for the ASR model."""
...
@classmethod
def get_num_audio_tokens(cls, audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig) -> Optional[int]:
"""
Map from audio duration to number of audio tokens produced by the ASR
model, without running a forward pass.
This is used for estimating the amount of processing for this audio.
"""
return None
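Taken together, a model opting into transcription support implements these hooks roughly like this (a minimal sketch: MyASRModel, its prompt format and its 50-tokens-per-second rate are made up, and validate_language() plus the usual vLLM model plumbing are omitted):

import math
from typing import Optional, cast

import numpy as np
from torch import nn

from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType
from vllm.model_executor.models import SupportsTranscription


class MyASRModel(nn.Module, SupportsTranscription):

    @classmethod
    def get_speech_to_text_config(cls, model_config: ModelConfig,
                                  task_type: str) -> SpeechToTextConfig:
        # Advertise what the audio front-end expects.
        return SpeechToTextConfig(sample_rate=16_000, max_audio_clip_s=30)

    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig, language: str,
                              task_type: str,
                              request_prompt: str) -> PromptType:
        # Any valid PromptType works; here a decoder-only text prompt with the
        # audio attached as multi-modal data.
        return cast(
            PromptType, {
                "prompt": f"<asr:{language}:{task_type}> {request_prompt}",
                "multi_modal_data": {
                    "audio": (audio, stt_config.sample_rate),
                },
            })

    @classmethod
    def get_num_audio_tokens(cls, audio_duration_s: float,
                             stt_config: SpeechToTextConfig,
                             model_config: ModelConfig) -> Optional[int]:
        # Placeholder: pretend the encoder emits 50 audio tokens per second.
        return math.ceil(audio_duration_s * 50)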
@overload
def supports_transcription(

View File

@@ -3,8 +3,9 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, TypedDict, Union
from typing import Optional, TypedDict, Union, cast
import numpy as np
import torch
from torch import nn
from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
@@ -12,8 +13,10 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
VllmConfig)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -33,6 +36,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.processor import cached_get_processor
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription, SupportsV0Only)
@@ -785,11 +789,24 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
f"or {list(ISO639_1_OTHER_LANGS.values())}")
@classmethod
def get_decoder_prompt(cls, language: str, task_type: str,
prompt: str) -> str:
return ((f"<|prev|>{prompt}" if prompt else "") +
f"<|startoftranscript|><|{language}|>" +
f"<|{task_type}|><|notimestamps|>")
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig, language: str,
task_type: str,
request_prompt: str) -> PromptType:
prompt = {
"encoder_prompt": {
# Whisper does not support encoder prompt.
"prompt": "",
"multi_modal_data": {
"audio": (audio, stt_config.sample_rate),
},
},
"decoder_prompt":
((f"<|prev|>{request_prompt}" if request_prompt else "") +
f"<|startoftranscript|><|{language}|>" +
f"<|{task_type}|><|notimestamps|>")
}
return cast(PromptType, prompt)
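To make the format concrete: with language "en", task type "transcribe" and an example request prompt of "Hello", the decoder prompt built above is:

# "<|prev|>Hello<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
# With an empty request prompt the "<|prev|>" section is dropped:
# "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"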
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -798,6 +815,30 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
raise ValueError("Only audio modality is supported")
@classmethod
def get_speech_to_text_config(cls, model_config: ModelConfig,
task_type: str) -> SpeechToTextConfig:
processor = cached_get_processor(model_config.model)
return SpeechToTextConfig(
max_audio_clip_s=processor.feature_extractor.chunk_length,
sample_rate=processor.feature_extractor.sampling_rate,
)
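For standard Whisper checkpoints the Hugging Face feature extractor reports chunk_length=30 and sampling_rate=16000 (an assumption based on the usual WhisperFeatureExtractor defaults, not something this diff pins down), so the returned config is effectively:

SpeechToTextConfig(max_audio_clip_s=30, sample_rate=16_000)
# with the chunking-related fields left at the dataclass defaults from the first hunk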
@classmethod
def get_num_audio_tokens(cls, audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig) -> Optional[int]:
processor = cached_get_processor(model_config.model)
hop_length = processor.feature_extractor.hop_length
assert hop_length is not None
# NOTE(NickLucche) Users can't pass encoder prompts
# directly, at least not to Whisper. One indicator of
# the amount of encoder processing is the log-mel
# spectrogram length.
return math.ceil(audio_duration_s * stt_config.sample_rate /
hop_length)
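As a quick sanity check on this estimate (assuming Whisper's usual hop_length of 160 at a 16 kHz sample rate):

# ceil(30.0 * 16_000 / 160) = 3_000 audio tokens for a full 30 s clip
# ceil(11.2 * 16_000 / 160) = 1_120 audio tokens for an 11.2 s clip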
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config