Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-28 08:30:32 +08:00)
[Frontend] Abstract prompt and SpeechToTextConfig for transcriptions models (#20637)
Signed-off-by: NickLucche <nlucches@redhat.com>
parent 890323dc1b
commit 3c7d942da8
@@ -4958,3 +4958,34 @@ def get_layers_from_vllm_config(vllm_config: VllmConfig,
         vllm_config.compilation_config.static_forward_context.items()
         if isinstance(layer, layer_type)
     }
+
+
+@config
+@dataclass
+class SpeechToTextConfig:
+    """Configuration for speech-to-text models."""
+
+    sample_rate: float = 16_000
+    """Sample rate (Hz) to resample input audio to. Most speech models expect
+    16kHz audio input. The input audio will be automatically resampled to this
+    rate before processing."""
+
+    max_audio_clip_s: int = 30
+    """Maximum duration in seconds for a single audio clip without chunking.
+    Audio longer than this will be split into smaller chunks if
+    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
+
+    overlap_chunk_second: int = 1
+    """Overlap duration in seconds between consecutive audio chunks when
+    splitting long audio. This helps maintain context across chunk boundaries
+    and improves transcription quality at split points."""
+
+    min_energy_split_window_size: Optional[int] = 1600
+    """Window size in samples for finding low-energy (quiet) regions to split
+    audio chunks. The algorithm looks for the quietest moment within this
+    window to minimize cutting through speech. Default 1600 samples ≈ 100ms
+    at 16kHz. If None, no chunking will be done."""
+
+    @property
+    def allow_audio_chunking(self) -> bool:
+        return self.min_energy_split_window_size is not None
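Note: the new SpeechToTextConfig gathers the ASR serving knobs that were previously hard-coded constants in the OpenAI speech-to-text entrypoint. A minimal usage sketch, assuming only the dataclass as defined above (the variable names below are illustrative):

from vllm.config import SpeechToTextConfig

cfg = SpeechToTextConfig()          # Whisper-style defaults from the diff
assert cfg.sample_rate == 16_000
assert cfg.max_audio_clip_s == 30
assert cfg.allow_audio_chunking     # window size is set, so chunking is allowed

# Setting the split window to None turns chunking off entirely.
no_chunking = SpeechToTextConfig(min_energy_split_window_size=None)
assert not no_chunking.allow_audio_chunking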
@@ -6,7 +6,6 @@ import math
 import time
 from collections.abc import AsyncGenerator
 from functools import cached_property
-from math import ceil
 from typing import Callable, Literal, Optional, TypeVar, Union, cast
 
 import numpy as np
@@ -28,7 +27,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_cls
 from vllm.model_executor.models import SupportsTranscription
 from vllm.outputs import RequestOutput
-from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import PlaceholderModule
 
 try:
@@ -44,9 +42,6 @@ logger = init_logger(__name__)
 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
-MAX_AUDIO_CLIP_SECONDS = 30
-OVERLAP_CHUNK_SECOND = 1
-MIN_ENERGY_WINDOW_SIZE = 1600  # 1600 ~ 100ms for 16000 Hz audio
 
 
 class OpenAISpeechToText(OpenAIServing):
@@ -71,36 +66,32 @@ class OpenAISpeechToText(OpenAIServing):
 
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
-        processor = cached_get_processor(model_config.model)
-        self.max_audio_clip_s = processor.feature_extractor.chunk_length \
-            if hasattr(processor.feature_extractor, 'chunk_length') \
-            else MAX_AUDIO_CLIP_SECONDS
-        self.model_sr = processor.feature_extractor.sampling_rate
-        self.hop_length = processor.feature_extractor.hop_length
         self.task_type = task_type
 
+        self.asr_config = self.model_cls.get_speech_to_text_config(
+            model_config, task_type)
+
         if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
                 self.default_sampling_params)
 
     @cached_property
-    def model_cls(self):
-        return get_model_cls(self.model_config)
+    def model_cls(self) -> type[SupportsTranscription]:
+        model_cls = get_model_cls(self.model_config)
+        return cast(type[SupportsTranscription], model_cls)
 
     async def _preprocess_speech_to_text(
         self,
         request: SpeechToTextRequest,
         audio_data: bytes,
     ) -> tuple[list[PromptType], float]:
-        model_cls = cast(SupportsTranscription, self.model_cls)
-
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
         lang = request.language or "en"
-        model_cls.validate_language(lang)
+        self.model_cls.validate_language(lang)
 
         if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
             raise ValueError("Maximum file size exceeded.")
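Note: the serving class no longer reads the HF feature extractor itself; it resolves the model class and asks it for the config. A condensed sketch of that flow, reusing the imports shown in this file (the helper name resolve_asr_config is invented for illustration):

from typing import cast

from vllm.model_executor.model_loader import get_model_cls
from vllm.model_executor.models import SupportsTranscription

def resolve_asr_config(model_config, task_type: str):
    # Same two steps the new __init__ performs: resolve the model class,
    # then delegate config construction to it.
    model_cls = cast(type[SupportsTranscription], get_model_cls(model_config))
    return model_cls.get_speech_to_text_config(model_config, task_type)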
@@ -108,26 +99,23 @@ class OpenAISpeechToText(OpenAIServing):
         with io.BytesIO(audio_data) as bytes_:
             # NOTE resample to model SR here for efficiency. This is also a
             # pre-requisite for chunking, as it assumes Whisper SR.
-            y, sr = librosa.load(bytes_, sr=self.model_sr)
+            y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
 
         duration = librosa.get_duration(y=y, sr=sr)
-        chunks = [y
-                  ] if duration < self.max_audio_clip_s else self._split_audio(
-                      y, int(sr))
+        do_split_audio = (self.asr_config.allow_audio_chunking
+                          and duration > self.asr_config.max_audio_clip_s)
+        chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
         prompts = []
         for chunk in chunks:
-            prompt = {
-                "encoder_prompt": {
-                    "prompt": "",
-                    "multi_modal_data": {
-                        "audio": (chunk, sr),
-                    },
-                },
-                "decoder_prompt":
-                model_cls.get_decoder_prompt(lang, self.task_type,
-                                             request.prompt)
-            }
-            prompts.append(cast(PromptType, prompt))
+            # The model has control over the construction, as long as it
+            # returns a valid PromptType.
+            prompt = self.model_cls.get_generation_prompt(
+                audio=chunk,
+                stt_config=self.asr_config,
+                language=lang,
+                task_type=self.task_type,
+                request_prompt=request.prompt)
+            prompts.append(prompt)
         return prompts, duration
 
     async def _create_speech_to_text(
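Note: for intuition on the new split decision, a rough back-of-envelope with the default config values (numbers are illustrative; the real splitter also nudges boundaries toward quiet regions, so the exact chunk count can differ):

import math

sample_rate = 16_000
max_clip_s = 30
overlap_s = 1
duration_s = 95                     # hypothetical input clip
allow_audio_chunking = True         # min_energy_split_window_size is not None

do_split = allow_audio_chunking and duration_s > max_clip_s
# Each extra chunk advances roughly (max_clip_s - overlap_s) seconds of new audio.
approx_chunks = 1 + math.ceil((duration_s - max_clip_s) / (max_clip_s - overlap_s))
print(do_split, approx_chunks)      # True 4 -> one get_generation_prompt call per chunk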
@@ -196,7 +184,8 @@ class OpenAISpeechToText(OpenAIServing):
 
         self._log_inputs(
             request_id,
-            prompts[0]['decoder_prompt'],  # type: ignore
+            # It will not display special tokens like <|startoftranscript|>
+            request.prompt,
             params=sampling_params,
             lora_request=None,
             prompt_adapter_request=None)
@@ -261,17 +250,11 @@ class OpenAISpeechToText(OpenAIServing):
         async for res in result_generator:
             # On first result.
             if res.prompt_token_ids is not None:
-                # Do not account the 4-tokens `<|startoftranscript|>..`
-                # Could be negative when language token
-                # is not specified.
-                num_prompt_tokens = max(
-                    len(res.prompt_token_ids) - 4, 0)
-                # NOTE(NickLucche) user can't pass encoder
-                # prompts directly at least not to Whisper.
-                # One indicator of the encoder amount of processing
-                # is the log-mel spectogram length.
-                num_prompt_tokens += ceil(
-                    audio_duration_s * self.model_sr / self.hop_length)
+                num_prompt_tokens = len(res.prompt_token_ids)
+                if audio_tokens := self.model_cls.get_num_audio_tokens(
+                        audio_duration_s, self.asr_config,
+                        self.model_config):
+                    num_prompt_tokens += audio_tokens
 
             # We need to do it here, because if there are exceptions in
             # the result_generator, it needs to be sent as the FIRST
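Note: prompt accounting now counts the real prompt token ids and adds whatever the model estimates for the audio portion. A small sketch of the arithmetic with made-up values:

# Hypothetical values for one request:
prompt_token_ids = [50258, 50259, 50359, 50363]   # 4 decoder prompt tokens
audio_tokens = 1500                               # as returned by get_num_audio_tokens

num_prompt_tokens = len(prompt_token_ids)
if audio_tokens:                                  # None / 0 means "no estimate"
    num_prompt_tokens += audio_tokens
assert num_prompt_tokens == 1504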
@@ -347,8 +330,8 @@ class OpenAISpeechToText(OpenAIServing):
 
     def _split_audio(self, audio_data: np.ndarray,
                      sample_rate: int) -> list[np.ndarray]:
-        chunk_size = sample_rate * self.max_audio_clip_s
-        overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
+        chunk_size = sample_rate * self.asr_config.max_audio_clip_s
+        overlap_size = sample_rate * self.asr_config.overlap_chunk_second
         chunks = []
         i = 0
         while i < audio_data.shape[-1]:
@@ -384,10 +367,10 @@ class OpenAISpeechToText(OpenAIServing):
         # Calculate RMS energy in small windows
         min_energy = math.inf
         quietest_idx = 0
-        for i in range(0,
-                       len(segment) - MIN_ENERGY_WINDOW_SIZE,
-                       MIN_ENERGY_WINDOW_SIZE):
-            window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
+        min_energy_window = self.asr_config.min_energy_split_window_size
+        assert min_energy_window is not None
+        for i in range(0, len(segment) - min_energy_window, min_energy_window):
+            window = segment[i:i + min_energy_window]
             energy = (window**2).mean()**0.5
             if energy < min_energy:
                 quietest_idx = i + start_idx
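Note: the quiet-point search is a plain RMS scan over fixed-size windows. A standalone re-implementation of the same idea, written here only for illustration (the diff version additionally offsets the result by start_idx into the full signal):

import numpy as np

def quietest_window_start(segment: np.ndarray, window: int = 1600) -> int:
    """Return the start index of the lowest-RMS window inside `segment`."""
    best_idx = 0
    best_rms = float("inf")
    for i in range(0, len(segment) - window, window):
        rms = float(np.sqrt(np.mean(segment[i:i + window] ** 2)))
        if rms < best_rms:
            best_idx, best_rms = i, rms
    return best_idx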
@@ -5,11 +5,14 @@ from collections.abc import Iterable, MutableSequence
 from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                     Union, overload, runtime_checkable)
 
+import numpy as np
 import torch
 from torch import Tensor
 from typing_extensions import Self, TypeIs
 
+from vllm.config import ModelConfig, SpeechToTextConfig
 from vllm.inputs import TokensPrompt
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -692,9 +695,13 @@ class SupportsTranscription(Protocol):
     supports_transcription: ClassVar[Literal[True]] = True
 
     @classmethod
-    def get_decoder_prompt(cls, language: str, task_type: str,
-                           prompt: str) -> str:
-        """Get the decoder prompt for the ASR model."""
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig, language: str,
+                              task_type: str,
+                              request_prompt: str) -> PromptType:
+        """Get the prompt for the ASR model.
+        The model has control over the construction, as long as it
+        returns a valid PromptType."""
         ...
 
     @classmethod
@@ -702,6 +709,25 @@
         """Check if the model supports a specific ISO639_1 language."""
         ...
 
+    @classmethod
+    def get_speech_to_text_config(
+            cls, model_config: ModelConfig,
+            task_type: Literal["transcribe",
+                               "translate"]) -> SpeechToTextConfig:
+        """Get the speech to text config for the ASR model."""
+        ...
+
+    @classmethod
+    def get_num_audio_tokens(cls, audio_duration_s: float,
+                             stt_config: SpeechToTextConfig,
+                             model_config: ModelConfig) -> Optional[int]:
+        """
+        Map from audio duration to number of audio tokens produced by the ASR
+        model, without running a forward pass.
+        This is used for estimating the amount of processing for this audio.
+        """
+        return None
+
 
 @overload
 def supports_transcription(
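Note: a model opts into the refactored flow by implementing the classmethods above. A hypothetical minimal implementation follows; the class name, language list and prompt format are invented for illustration, only the hook signatures follow the protocol:

from typing import cast

import numpy as np
from torch import nn

from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType
from vllm.model_executor.models import SupportsTranscription


class ToyASRModel(nn.Module, SupportsTranscription):
    """Hypothetical model; only the transcription hooks are sketched."""

    @classmethod
    def validate_language(cls, language: str) -> bool:
        return language in ("en", "it")

    @classmethod
    def get_speech_to_text_config(cls, model_config: ModelConfig,
                                  task_type: str) -> SpeechToTextConfig:
        return SpeechToTextConfig(sample_rate=16_000, max_audio_clip_s=30)

    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig, language: str,
                              task_type: str,
                              request_prompt: str) -> PromptType:
        # Any valid PromptType works; here a plain text prompt plus audio.
        return cast(
            PromptType, {
                "prompt": f"<asr|{language}|{task_type}>{request_prompt}",
                "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
            })

    # get_num_audio_tokens is optional: the protocol default returns None,
    # in which case the serving layer skips the audio-token estimate.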
@@ -3,8 +3,9 @@
 
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union, cast
 
+import numpy as np
 import torch
 from torch import nn
 from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
@@ -12,8 +13,10 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
 from transformers.models.whisper.modeling_whisper import sinusoids
 
 from vllm.attention import Attention, AttentionType
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
+                         VllmConfig)
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -33,6 +36,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
                                         EncDecMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.transformers_utils.processor import cached_get_processor
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription, SupportsV0Only)
@@ -785,11 +789,24 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                              f"or {list(ISO639_1_OTHER_LANGS.values())}")
 
     @classmethod
-    def get_decoder_prompt(cls, language: str, task_type: str,
-                           prompt: str) -> str:
-        return ((f"<|prev|>{prompt}" if prompt else "") +
-                f"<|startoftranscript|><|{language}|>" +
-                f"<|{task_type}|><|notimestamps|>")
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig, language: str,
+                              task_type: str,
+                              request_prompt: str) -> PromptType:
+        prompt = {
+            "encoder_prompt": {
+                # Whisper does not support encoder prompt.
+                "prompt": "",
+                "multi_modal_data": {
+                    "audio": (audio, stt_config.sample_rate),
+                },
+            },
+            "decoder_prompt":
+            ((f"<|prev|>{request_prompt}" if request_prompt else "") +
+             f"<|startoftranscript|><|{language}|>" +
+             f"<|{task_type}|><|notimestamps|>")
+        }
+        return cast(PromptType, prompt)
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
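Note: a worked example of the decoder prompt assembled above, assuming language="en", task_type="transcribe" and request_prompt="Hello":

# Decoder prompt produced for those inputs:
decoder_prompt = "<|prev|>Hello<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
# Without a request prompt the "<|prev|>..." prefix is dropped:
#   "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"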
@@ -798,6 +815,30 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
 
         raise ValueError("Only audio modality is supported")
 
+    @classmethod
+    def get_speech_to_text_config(cls, model_config: ModelConfig,
+                                  task_type: str) -> SpeechToTextConfig:
+        processor = cached_get_processor(model_config.model)
+
+        return SpeechToTextConfig(
+            max_audio_clip_s=processor.feature_extractor.chunk_length,
+            sample_rate=processor.feature_extractor.sampling_rate,
+        )
+
+    @classmethod
+    def get_num_audio_tokens(cls, audio_duration_s: float,
+                             stt_config: SpeechToTextConfig,
+                             model_config: ModelConfig) -> Optional[int]:
+        processor = cached_get_processor(model_config.model)
+        hop_length = processor.feature_extractor.hop_length
+        assert hop_length is not None
+        # NOTE(NickLucche) user can't pass encoder
+        # prompts directly at least not to Whisper.
+        # One indicator of the encoder amount of processing
+        # is the log-mel spectogram length.
+        return math.ceil(audio_duration_s * stt_config.sample_rate /
+                         hop_length)
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
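Note: as a sanity check on this estimate, with Whisper's usual feature-extractor values (16 kHz sampling rate, hop length 160, both read from the HF processor at runtime) the count is simply the number of log-mel frames:

import math

# Assumed Whisper feature-extractor values; the real ones come from
# cached_get_processor(model_config.model).feature_extractor.
sample_rate = 16_000
hop_length = 160

assert math.ceil(30.0 * sample_rate / hop_length) == 3000   # full 30 s window
assert math.ceil(11.2 * sample_rate / hop_length) == 1120   # shorter clip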