fix: add warmup for audio preprocessing (#30706)

Signed-off-by: Nathan Price <nathan@abridge.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

parent ec965569d9
commit fc2ae6d617
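For context, the cold-start cost this commit targets can be reproduced outside vLLM. A minimal standalone sketch (not part of this commit) that times librosa's first and second mel-spectrogram calls; the gap between the two is the numba JIT and library initialization that the warmup absorbs:

# Standalone sketch (not part of this commit): measure librosa's
# first-call initialization cost that the warmup is meant to absorb.
import time

import librosa
import numpy as np

audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz

for attempt in ("cold", "warm"):
    start = time.perf_counter()
    librosa.feature.melspectrogram(
        y=audio, sr=16000, n_mels=128, n_fft=400, hop_length=160
    )
    print(f"{attempt}: {time.perf_counter() - start:.2f}s")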
@@ -35,7 +35,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRe
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
-from vllm.model_executor.models import SupportsTranscription
+from vllm.model_executor.models import SupportsTranscription, supports_transcription
 from vllm.outputs import RequestOutput
 from vllm.tokenizers import get_tokenizer
 from vllm.utils.import_utils import PlaceholderModule
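The newly imported `supports_transcription` is the runtime capability check used in the hunk below. A rough sketch of the pattern, assuming a Protocol-style interface like vLLM's `SupportsTranscription` (the names and bodies here are illustrative, not vLLM's actual implementation):

# Illustrative sketch of the capability-check pattern behind
# supports_transcription; vLLM's real helper lives in
# vllm.model_executor.models and is more thorough.
from typing import Protocol, runtime_checkable

@runtime_checkable
class TranscriptionCapable(Protocol):
    @classmethod
    def get_generation_prompt(cls, **kwargs) -> object: ...

def supports_transcription_sketch(model_cls: type) -> bool:
    # A class "supports transcription" when it implements the interface.
    return issubclass(model_cls, TranscriptionCapable)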
@@ -112,6 +112,131 @@ class OpenAISpeechToText(OpenAIServing):
             self.default_sampling_params,
         )
 
+        # Warm up audio preprocessing to avoid first-request latency
+        self._warmup_audio_preprocessing()
+        # Warm up input processor with dummy audio
+        self._warmup_input_processor()
+
+    def _warmup_audio_preprocessing(self) -> None:
+        """Warm up audio processing libraries to avoid first-request latency.
+
+        The first call to librosa functions (load, get_duration, mel-spectrogram)
+        triggers JIT compilation and library initialization which can take ~7s.
+        This method warms up these operations during server initialization.
+        """
+        # Skip warmup if librosa is not installed (optional dependency)
+        if isinstance(librosa, PlaceholderModule):
+            return
+
+        # Skip warmup if model doesn't support transcription
+        if not supports_transcription(self.model_cls):
+            return
+
+        try:
+            warmup_start = time.perf_counter()
+            logger.info("Warming up audio preprocessing libraries...")
+
+            # Create a minimal dummy audio (1 second of silence at target sample rate)
+            dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)
+
+            # Warm up librosa by running it on the dummy data. This
+            # initializes FFTW, numba JIT, and other audio processing libraries.
+            _ = librosa.get_duration(y=dummy_audio, sr=self.asr_config.sample_rate)
+
+            # Warm up mel-spectrogram computation with model-specific parameters
+            from vllm.transformers_utils.processor import (
+                cached_processor_from_config,
+            )
+
+            processor = cached_processor_from_config(self.model_config)
+            feature_extractor = None
+            if hasattr(processor, "feature_extractor"):
+                feature_extractor = processor.feature_extractor
+            elif hasattr(processor, "audio_processor"):
+                # For models like GraniteSpeech that use audio_processor
+                audio_proc = processor.audio_processor
+                if hasattr(audio_proc, "feature_extractor"):
+                    feature_extractor = audio_proc.feature_extractor
+                # If audio_processor doesn't have a feature_extractor,
+                # skip mel-spectrogram warmup for these models
+
+            if feature_extractor is not None:
+                _ = librosa.feature.melspectrogram(
+                    y=dummy_audio,
+                    sr=self.asr_config.sample_rate,
+                    n_mels=getattr(feature_extractor, "n_mels", 128),
+                    n_fft=getattr(feature_extractor, "n_fft", 400),
+                    hop_length=getattr(feature_extractor, "hop_length", 160),
+                )
+
+            warmup_elapsed = time.perf_counter() - warmup_start
+            logger.info("Audio preprocessing warmup completed in %.2fs", warmup_elapsed)
+        except Exception:
+            # Don't fail initialization if warmup fails - log the exception and continue
+            logger.exception(
+                "Audio preprocessing warmup failed (non-fatal). "
+                "First request may experience higher latency."
+            )
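The `isinstance(librosa, PlaceholderModule)` guard above relies on vLLM's optional-import pattern: when librosa is not installed, the module-level name is bound to a placeholder stub rather than the real module. A simplified stand-in for that pattern (vLLM's actual `PlaceholderModule` lives in `vllm.utils.import_utils`):

# Simplified stand-in for the optional-import pattern used above;
# vLLM's real PlaceholderModule is in vllm.utils.import_utils.
class PlaceholderModuleSketch:
    """Stub that defers the ImportError until first real use."""

    def __init__(self, name: str) -> None:
        self._name = name

    def __getattr__(self, attr: str):
        raise ImportError(f"{self._name} is required but not installed")

try:
    import librosa  # optional dependency
except ImportError:
    librosa = PlaceholderModuleSketch("librosa")

# Callers can then cheaply test availability, as the warmup does:
# if isinstance(librosa, PlaceholderModuleSketch): return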
+
+    def _warmup_input_processor(self) -> None:
+        """Warm up input processor with dummy audio to avoid first-request latency.
+
+        The first call to input_processor.process_inputs() with multimodal audio
+        triggers multimodal processing initialization which can take ~2.5s.
+        This method processes a dummy audio request to warm up the pipeline.
+        """
+        # Skip warmup if model doesn't support transcription
+        if not supports_transcription(self.model_cls):
+            return
+
+        # Only warm up if model supports transcription methods
+        if not hasattr(self.model_cls, "get_generation_prompt"):
+            return
+
+        try:
+            from vllm.sampling_params import SamplingParams
+
+            warmup_start = time.perf_counter()
+            logger.info("Warming up multimodal input processor...")
+
+            # Create minimal dummy audio (1 second of silence)
+            dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)
+
+            # Use the same method that _preprocess_speech_to_text uses
+            # to create the prompt
+            dummy_prompt = self.model_cls.get_generation_prompt(
+                audio=dummy_audio,
+                stt_config=self.asr_config,
+                model_config=self.model_config,
+                language="en",
+                task_type=self.task_type,
+                request_prompt="",
+                to_language=None,
+            )
+
+            # Create minimal sampling params
+            dummy_params = SamplingParams(
+                max_tokens=1,
+                temperature=0.0,
+            )
+
+            # Process the dummy input through the input processor.
+            # This will trigger all the multimodal processing initialization.
+            _ = self.input_processor.process_inputs(
+                request_id="warmup",
+                prompt=dummy_prompt,
+                params=dummy_params,
+            )
+
+            warmup_elapsed = time.perf_counter() - warmup_start
+            logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
+        except Exception:
+            # Don't fail initialization if warmup fails - log the exception and continue
+            logger.exception(
+                "Input processor warmup failed (non-fatal). "
+                "First request may experience higher latency."
+            )
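The value returned by `get_generation_prompt` is model-specific. For a Whisper-style model it is roughly a prompt dict carrying the raw audio as multimodal data, along the lines of this hedged sketch (field names follow vLLM's PromptType conventions, but the exact shape varies per model):

# Hedged illustration of the kind of prompt dict a transcription model
# builds; the exact structure is model-specific (this mimics a
# Whisper-style encoder/decoder prompt).
import numpy as np

sample_rate = 16000
dummy_audio = np.zeros(sample_rate, dtype=np.float32)

dummy_prompt = {
    "encoder_prompt": {
        "prompt": "",
        "multi_modal_data": {"audio": (dummy_audio, sample_rate)},
    },
    "decoder_prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
}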
+
     @cached_property
     def model_cls(self) -> type[SupportsTranscription]:
         from vllm.model_executor.model_loader import get_model_cls
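With this change, the first transcription request should no longer pay the librosa and input-processor initialization cost. A quick way to check end to end, assuming a vLLM server running a transcription-capable model on localhost:8000 with the OpenAI-compatible API (the model name and audio file here are placeholders):

# Quick first-request latency check against a running vLLM server
# (assumed at localhost:8000 serving a transcription-capable model).
import time

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample.wav", "rb") as f:  # any short audio clip
    start = time.perf_counter()
    result = client.audio.transcriptions.create(
        model="openai/whisper-large-v3",  # placeholder model name
        file=f,
    )
    print(f"first request: {time.perf_counter() - start:.2f}s -> {result.text}")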