[Frontend] Abstract prompt and SpeechToTextConfig for transcriptions models (#20637)
Signed-off-by: NickLucche <nlucches@redhat.com>
parent 890323dc1b
commit 3c7d942da8
@@ -4958,3 +4958,34 @@ def get_layers_from_vllm_config(vllm_config: VllmConfig,
         vllm_config.compilation_config.static_forward_context.items()
         if isinstance(layer, layer_type)
     }
+
+
+@config
+@dataclass
+class SpeechToTextConfig:
+    """Configuration for speech-to-text models."""
+
+    sample_rate: float = 16_000
+    """Sample rate (Hz) to resample input audio to. Most speech models expect
+    16kHz audio input. The input audio will be automatically resampled to this
+    rate before processing."""
+
+    max_audio_clip_s: int = 30
+    """Maximum duration in seconds for a single audio clip without chunking.
+    Audio longer than this will be split into smaller chunks if
+    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
+
+    overlap_chunk_second: int = 1
+    """Overlap duration in seconds between consecutive audio chunks when
+    splitting long audio. This helps maintain context across chunk boundaries
+    and improves transcription quality at split points."""
+
+    min_energy_split_window_size: Optional[int] = 1600
+    """Window size in samples for finding low-energy (quiet) regions to split
+    audio chunks. The algorithm looks for the quietest moment within this
+    window to minimize cutting through speech. Default 1600 samples ≈ 100ms
+    at 16kHz. If None, no chunking will be done."""
+
+    @property
+    def allow_audio_chunking(self) -> bool:
+        return self.min_energy_split_window_size is not None
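A minimal usage sketch of the new dataclass (not part of the diff; the import path matches the `from vllm.config import ... SpeechToTextConfig` lines added to the model files below):

    from vllm.config import SpeechToTextConfig

    # Defaults declared above: 16 kHz audio, 30 s clips, 1 s chunk overlap.
    cfg = SpeechToTextConfig()
    assert cfg.sample_rate == 16_000
    assert cfg.allow_audio_chunking  # split window defaults to 1600 samples

    # Setting the split window to None disables chunking entirely.
    no_chunk = SpeechToTextConfig(min_energy_split_window_size=None)
    assert not no_chunk.allow_audio_chunking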
@@ -6,7 +6,6 @@ import math
 import time
 from collections.abc import AsyncGenerator
 from functools import cached_property
-from math import ceil
 from typing import Callable, Literal, Optional, TypeVar, Union, cast

 import numpy as np
@@ -28,7 +27,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_cls
 from vllm.model_executor.models import SupportsTranscription
 from vllm.outputs import RequestOutput
-from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import PlaceholderModule

 try:
@@ -44,9 +42,6 @@ logger = init_logger(__name__)
 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
-MAX_AUDIO_CLIP_SECONDS = 30
-OVERLAP_CHUNK_SECOND = 1
-MIN_ENERGY_WINDOW_SIZE = 1600  # 1600 ~ 100ms for 16000 Hz audio


 class OpenAISpeechToText(OpenAIServing):
@@ -71,36 +66,32 @@ class OpenAISpeechToText(OpenAIServing):

         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
-        processor = cached_get_processor(model_config.model)
-        self.max_audio_clip_s = processor.feature_extractor.chunk_length \
-            if hasattr(processor.feature_extractor, 'chunk_length') \
-                else MAX_AUDIO_CLIP_SECONDS
-        self.model_sr = processor.feature_extractor.sampling_rate
-        self.hop_length = processor.feature_extractor.hop_length
         self.task_type = task_type

+        self.asr_config = self.model_cls.get_speech_to_text_config(
+            model_config, task_type)
+
         if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
                 self.default_sampling_params)

     @cached_property
-    def model_cls(self):
-        return get_model_cls(self.model_config)
+    def model_cls(self) -> type[SupportsTranscription]:
+        model_cls = get_model_cls(self.model_config)
+        return cast(type[SupportsTranscription], model_cls)

     async def _preprocess_speech_to_text(
         self,
         request: SpeechToTextRequest,
         audio_data: bytes,
     ) -> tuple[list[PromptType], float]:
-        model_cls = cast(SupportsTranscription, self.model_cls)
-
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
         lang = request.language or "en"
-        model_cls.validate_language(lang)
+        self.model_cls.validate_language(lang)

         if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
             raise ValueError("Maximum file size exceeded.")
@@ -108,26 +99,23 @@ class OpenAISpeechToText(OpenAIServing):
         with io.BytesIO(audio_data) as bytes_:
             # NOTE resample to model SR here for efficiency. This is also a
             # pre-requisite for chunking, as it assumes Whisper SR.
-            y, sr = librosa.load(bytes_, sr=self.model_sr)
+            y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)

         duration = librosa.get_duration(y=y, sr=sr)
-        chunks = [y
-                  ] if duration < self.max_audio_clip_s else self._split_audio(
-                      y, int(sr))
+        do_split_audio = (self.asr_config.allow_audio_chunking
+                          and duration > self.asr_config.max_audio_clip_s)
+        chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
         prompts = []
         for chunk in chunks:
-            prompt = {
-                "encoder_prompt": {
-                    "prompt": "",
-                    "multi_modal_data": {
-                        "audio": (chunk, sr),
-                    },
-                },
-                "decoder_prompt":
-                model_cls.get_decoder_prompt(lang, self.task_type,
-                                             request.prompt)
-            }
-            prompts.append(cast(PromptType, prompt))
+            # The model has control over the construction, as long as it
+            # returns a valid PromptType.
+            prompt = self.model_cls.get_generation_prompt(
+                audio=chunk,
+                stt_config=self.asr_config,
+                language=lang,
+                task_type=self.task_type,
+                request_prompt=request.prompt)
+            prompts.append(prompt)
         return prompts, duration

     async def _create_speech_to_text(
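Note on the refactor above: the serving layer no longer hardcodes Whisper's encoder/decoder prompt dict; each chunk's prompt comes from the model's `get_generation_prompt`, which may return any valid `PromptType`. A hedged sketch of what that enables for a hypothetical decoder-only ASR model (the prompt format and special tokens below are illustrative assumptions, not part of this commit):

    @classmethod
    def get_generation_prompt(cls, audio, stt_config, language, task_type,
                              request_prompt):
        # A decoder-only model can return a plain text + multi-modal prompt
        # instead of Whisper's explicit encoder/decoder dictionary.
        return {
            "prompt": f"<|audio|><|{language}|><|{task_type}|>{request_prompt}",
            "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
        }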
@@ -196,7 +184,8 @@ class OpenAISpeechToText(OpenAIServing):

         self._log_inputs(
             request_id,
-            prompts[0]['decoder_prompt'],  # type: ignore
+            # It will not display special tokens like <|startoftranscript|>
+            request.prompt,
             params=sampling_params,
             lora_request=None,
             prompt_adapter_request=None)
@@ -261,17 +250,11 @@ class OpenAISpeechToText(OpenAIServing):
             async for res in result_generator:
                 # On first result.
                 if res.prompt_token_ids is not None:
-                    # Do not account the 4-tokens `<|startoftranscript|>..`
-                    # Could be negative when language token
-                    # is not specified.
-                    num_prompt_tokens = max(
-                        len(res.prompt_token_ids) - 4, 0)
-                    # NOTE(NickLucche) user can't pass encoder
-                    # prompts directly at least not to Whisper.
-                    # One indicator of the encoder amount of processing
-                    # is the log-mel spectogram length.
-                    num_prompt_tokens += ceil(
-                        audio_duration_s * self.model_sr / self.hop_length)
+                    num_prompt_tokens = len(res.prompt_token_ids)
+                    if audio_tokens := self.model_cls.get_num_audio_tokens(
+                            audio_duration_s, self.asr_config,
+                            self.model_config):
+                        num_prompt_tokens += audio_tokens

                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST
@@ -347,8 +330,8 @@ class OpenAISpeechToText(OpenAIServing):

     def _split_audio(self, audio_data: np.ndarray,
                      sample_rate: int) -> list[np.ndarray]:
-        chunk_size = sample_rate * self.max_audio_clip_s
-        overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
+        chunk_size = sample_rate * self.asr_config.max_audio_clip_s
+        overlap_size = sample_rate * self.asr_config.overlap_chunk_second
         chunks = []
         i = 0
         while i < audio_data.shape[-1]:
@@ -384,10 +367,10 @@ class OpenAISpeechToText(OpenAIServing):
         # Calculate RMS energy in small windows
         min_energy = math.inf
         quietest_idx = 0
-        for i in range(0,
-                       len(segment) - MIN_ENERGY_WINDOW_SIZE,
-                       MIN_ENERGY_WINDOW_SIZE):
-            window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
+        min_energy_window = self.asr_config.min_energy_split_window_size
+        assert min_energy_window is not None
+        for i in range(0, len(segment) - min_energy_window, min_energy_window):
+            window = segment[i:i + min_energy_window]
             energy = (window**2).mean()**0.5
             if energy < min_energy:
                 quietest_idx = i + start_idx
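The split-point search is unchanged except that the window size now comes from `asr_config.min_energy_split_window_size`. A self-contained sketch of the technique, assuming 16 kHz float audio (the helper name `find_quietest_split` is illustrative, not from the diff):

    import numpy as np

    def find_quietest_split(segment: np.ndarray, window: int = 1600) -> int:
        """Return the start index of the lowest-RMS-energy window in segment.

        At 16 kHz, window=1600 samples is ~100 ms, so long audio is cut at the
        quietest ~100 ms region instead of mid-word.
        """
        min_energy = float("inf")
        quietest_idx = 0
        for i in range(0, len(segment) - window, window):
            rms = float(np.sqrt(np.mean(segment[i:i + window] ** 2)))
            if rms < min_energy:
                min_energy = rms
                quietest_idx = i
        return quietest_idx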
@@ -5,11 +5,14 @@ from collections.abc import Iterable, MutableSequence
 from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                     Union, overload, runtime_checkable)

+import numpy as np
 import torch
 from torch import Tensor
 from typing_extensions import Self, TypeIs

+from vllm.config import ModelConfig, SpeechToTextConfig
 from vllm.inputs import TokensPrompt
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -692,9 +695,13 @@ class SupportsTranscription(Protocol):
     supports_transcription: ClassVar[Literal[True]] = True

     @classmethod
-    def get_decoder_prompt(cls, language: str, task_type: str,
-                           prompt: str) -> str:
-        """Get the decoder prompt for the ASR model."""
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig, language: str,
+                              task_type: str,
+                              request_prompt: str) -> PromptType:
+        """Get the prompt for the ASR model.
+        The model has control over the construction, as long as it
+        returns a valid PromptType."""
         ...

     @classmethod
@@ -702,6 +709,25 @@ class SupportsTranscription(Protocol):
         """Check if the model supports a specific ISO639_1 language."""
         ...

+    @classmethod
+    def get_speech_to_text_config(
+            cls, model_config: ModelConfig,
+            task_type: Literal["transcribe",
+                               "translate"]) -> SpeechToTextConfig:
+        """Get the speech to text config for the ASR model."""
+        ...
+
+    @classmethod
+    def get_num_audio_tokens(cls, audio_duration_s: float,
+                             stt_config: SpeechToTextConfig,
+                             model_config: ModelConfig) -> Optional[int]:
+        """
+        Map from audio duration to number of audio tokens produced by the ASR
+        model, without running a forward pass.
+        This is used for estimating the amount of processing for this audio.
+        """
+        return None
+

 @overload
 def supports_transcription(
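Taken together, the protocol now exposes `validate_language`, `get_generation_prompt`, `get_speech_to_text_config`, and `get_num_audio_tokens` as classmethod hooks. A minimal sketch of a model opting in, assuming a hypothetical `MyASRModel` with a made-up prompt format; only the method names and signatures come from the diff:

    from torch import nn

    from vllm.config import ModelConfig, SpeechToTextConfig
    from vllm.model_executor.models import SupportsTranscription


    class MyASRModel(nn.Module, SupportsTranscription):  # hypothetical model

        @classmethod
        def validate_language(cls, language: str) -> bool:
            # Toy check; a real model would consult its supported languages.
            if language != "en":
                raise ValueError(f"Unsupported language: {language!r}")
            return True

        @classmethod
        def get_speech_to_text_config(cls, model_config: ModelConfig,
                                      task_type: str) -> SpeechToTextConfig:
            # Fixed values for the sketch; Whisper reads these from its
            # HF processor instead (see the Whisper hunk below).
            return SpeechToTextConfig(sample_rate=16_000, max_audio_clip_s=30)

        @classmethod
        def get_generation_prompt(cls, audio, stt_config, language, task_type,
                                  request_prompt):
            # Any valid PromptType is acceptable here.
            return {
                "prompt": f"<|{language}|><|{task_type}|>{request_prompt}",
                "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
            }

        # get_num_audio_tokens is left to the protocol default, which returns
        # None (no audio-token estimate is added to usage accounting).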
@@ -3,8 +3,9 @@

 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union, cast

+import numpy as np
 import torch
 from torch import nn
 from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
@@ -12,8 +13,10 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
 from transformers.models.whisper.modeling_whisper import sinusoids

 from vllm.attention import Attention, AttentionType
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
+                         VllmConfig)
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -33,6 +36,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
                                         EncDecMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.transformers_utils.processor import cached_get_processor

 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription, SupportsV0Only)
@@ -785,11 +789,24 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                              f"or {list(ISO639_1_OTHER_LANGS.values())}")

     @classmethod
-    def get_decoder_prompt(cls, language: str, task_type: str,
-                           prompt: str) -> str:
-        return ((f"<|prev|>{prompt}" if prompt else "") +
-                f"<|startoftranscript|><|{language}|>" +
-                f"<|{task_type}|><|notimestamps|>")
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig, language: str,
+                              task_type: str,
+                              request_prompt: str) -> PromptType:
+        prompt = {
+            "encoder_prompt": {
+                # Whisper does not support encoder prompt.
+                "prompt": "",
+                "multi_modal_data": {
+                    "audio": (audio, stt_config.sample_rate),
+                },
+            },
+            "decoder_prompt":
+            ((f"<|prev|>{request_prompt}" if request_prompt else "") +
+             f"<|startoftranscript|><|{language}|>" +
+             f"<|{task_type}|><|notimestamps|>")
+        }
+        return cast(PromptType, prompt)

     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
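For a concrete example of the strings produced above: with language "en", task_type "transcribe" and request_prompt "Hello", the decoder prompt is "<|prev|>Hello<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"; with an empty request prompt the "<|prev|>" prefix is dropped, leaving "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>". The encoder prompt stays empty and the raw audio travels through multi_modal_data at the config's sample rate.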
@@ -798,6 +815,30 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,

             raise ValueError("Only audio modality is supported")

+    @classmethod
+    def get_speech_to_text_config(cls, model_config: ModelConfig,
+                                  task_type: str) -> SpeechToTextConfig:
+        processor = cached_get_processor(model_config.model)
+
+        return SpeechToTextConfig(
+            max_audio_clip_s=processor.feature_extractor.chunk_length,
+            sample_rate=processor.feature_extractor.sampling_rate,
+        )
+
+    @classmethod
+    def get_num_audio_tokens(cls, audio_duration_s: float,
+                             stt_config: SpeechToTextConfig,
+                             model_config: ModelConfig) -> Optional[int]:
+        processor = cached_get_processor(model_config.model)
+        hop_length = processor.feature_extractor.hop_length
+        assert hop_length is not None
+        # NOTE(NickLucche) user can't pass encoder
+        # prompts directly at least not to Whisper.
+        # One indicator of the encoder amount of processing
+        # is the log-mel spectogram length.
+        return math.ceil(audio_duration_s * stt_config.sample_rate /
+                         hop_length)
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
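As a worked example of the estimate above, assuming the typical Whisper feature-extractor values (sampling_rate=16000, hop_length=160, which this commit reads from the cached processor rather than hardcoding): a 30-second clip yields ceil(30 * 16000 / 160) = 3000 audio tokens, i.e. one log-mel frame every 10 ms.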