fix: add warmup for audio preprocessing (#30706)

Signed-off-by: Nathan Price <nathan@abridge.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Nathan Price 2025-12-18 00:12:29 -06:00 committed by GitHub
parent ec965569d9
commit fc2ae6d617
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -35,7 +35,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRe
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models import SupportsTranscription
from vllm.model_executor.models import SupportsTranscription, supports_transcription
from vllm.outputs import RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
@ -112,6 +112,131 @@ class OpenAISpeechToText(OpenAIServing):
self.default_sampling_params,
)
# Warm up audio preprocessing to avoid first-request latency
self._warmup_audio_preprocessing()
# Warm up input processor with dummy audio
self._warmup_input_processor()
def _warmup_audio_preprocessing(self) -> None:
"""Warm up audio processing libraries to avoid first-request latency.
The first call to librosa functions (load, get_duration, mel-spectrogram)
triggers JIT compilation and library initialization which can take ~7s.
This method warms up these operations during server initialization.
"""
# Skip warmup if librosa is not installed (optional dependency)
if isinstance(librosa, PlaceholderModule):
return
# Skip warmup if model doesn't support transcription
if not supports_transcription(self.model_cls):
return
try:
warmup_start = time.perf_counter()
logger.info("Warming up audio preprocessing libraries...")
# Create a minimal dummy audio (1 second of silence at target sample rate)
dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)
# Warm up librosa.load by using librosa functions on the dummy data
# This initializes FFTW, numba JIT, and other audio processing libraries
_ = librosa.get_duration(y=dummy_audio, sr=self.asr_config.sample_rate)
# Warm up mel-spectrogram computation with model-specific parameters
from vllm.transformers_utils.processor import (
cached_processor_from_config,
)
processor = cached_processor_from_config(self.model_config)
feature_extractor = None
if hasattr(processor, "feature_extractor"):
feature_extractor = processor.feature_extractor
elif hasattr(processor, "audio_processor"):
# For models like GraniteSpeech that use audio_processor
audio_proc = processor.audio_processor
if hasattr(audio_proc, "feature_extractor"):
feature_extractor = audio_proc.feature_extractor
# If audio_processor doesn't have feature_extractor,
# skip mel-spectrogram warmup for these models
if feature_extractor is not None:
_ = librosa.feature.melspectrogram(
y=dummy_audio,
sr=self.asr_config.sample_rate,
n_mels=getattr(feature_extractor, "n_mels", 128),
n_fft=getattr(feature_extractor, "n_fft", 400),
hop_length=getattr(feature_extractor, "hop_length", 160),
)
warmup_elapsed = time.perf_counter() - warmup_start
logger.info("Audio preprocessing warmup completed in %.2fs", warmup_elapsed)
except Exception:
# Don't fail initialization if warmup fails - log exception and continue
logger.exception(
"Audio preprocessing warmup failed (non-fatal): %s. "
"First request may experience higher latency.",
)
def _warmup_input_processor(self) -> None:
"""Warm up input processor with dummy audio to avoid first-request latency.
The first call to input_processor.process_inputs() with multimodal audio
triggers multimodal processing initialization which can take ~2.5s.
This method processes a dummy audio request to warm up the pipeline.
"""
# Skip warmup if model doesn't support transcription
if not supports_transcription(self.model_cls):
return
# Only warm up if model supports transcription methods
if not hasattr(self.model_cls, "get_generation_prompt"):
return
try:
from vllm.sampling_params import SamplingParams
warmup_start = time.perf_counter()
logger.info("Warming up multimodal input processor...")
# Create minimal dummy audio (1 second of silence)
dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)
# Use the same method that _preprocess_speech_to_text uses
# to create the prompt
dummy_prompt = self.model_cls.get_generation_prompt(
audio=dummy_audio,
stt_config=self.asr_config,
model_config=self.model_config,
language="en",
task_type=self.task_type,
request_prompt="",
to_language=None,
)
# Create minimal sampling params
dummy_params = SamplingParams(
max_tokens=1,
temperature=0.0,
)
# Process the dummy input through the input processor
# This will trigger all the multimodal processing initialization
_ = self.input_processor.process_inputs(
request_id="warmup",
prompt=dummy_prompt,
params=dummy_params,
)
warmup_elapsed = time.perf_counter() - warmup_start
logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
except Exception:
# Don't fail initialization if warmup fails - log warning and continue
logger.exception(
"Input processor warmup failed (non-fatal): %s. "
"First request may experience higher latency."
)
@cached_property
def model_cls(self) -> type[SupportsTranscription]:
from vllm.model_executor.model_loader import get_model_cls