[CI] Speed up Whisper tests by reusing server (#22859)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin, 2025-08-15 16:56:31 -04:00 (committed by GitHub)
parent a344a1a7da
commit 8a87cd27d9
2 changed files with 260 additions and 288 deletions

tests/entrypoints/openai/test_transcription_validation.py

@@ -4,19 +4,20 @@
# imports for guided decoding tests
import io
import json
from unittest.mock import patch
import librosa
import numpy as np
import openai
import pytest
import pytest_asyncio
import soundfile as sf
from openai._base_client import AsyncAPIClient
from vllm.assets.audio import AudioAsset
from ...utils import RemoteOpenAIServer
MODEL_NAME = "openai/whisper-large-v3-turbo"
SERVER_ARGS = ["--enforce-eager"]
MISTRAL_FORMAT_ARGS = [
"--tokenizer_mode", "mistral", "--config_format", "mistral",
"--load_format", "mistral"
@@ -37,6 +38,18 @@ def winning_call():
yield f
@pytest.fixture(scope="module")
def server():
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
@@ -60,54 +73,11 @@ async def test_basic_audio(mary_had_lamb, model_name):
assert "Mary had a little lamb," in out
@pytest.mark.asyncio
async def test_bad_requests(mary_had_lamb):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
# invalid language
with pytest.raises(openai.BadRequestError):
await client.audio.transcriptions.create(model=model_name,
file=mary_had_lamb,
language="hh",
temperature=0.0)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"])
async def test_long_audio_request(mary_had_lamb, model_name):
server_args = ["--enforce-eager"]
mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb)
# Add small silence after each audio for repeatability in the split process
audio = np.pad(audio, (0, 1600))
repeated_audio = np.tile(audio, 10)
# Repeated audio to buffer
buffer = io.BytesIO()
sf.write(buffer, repeated_audio, sr, format='WAV')
buffer.seek(0)
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
model=model_name,
file=buffer,
language="en",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
counts = out.count("Mary had a little lamb")
assert counts == 10, counts
@pytest.mark.asyncio
async def test_non_asr_model(winning_call):
# text to text model
model_name = "JackFram/llama-68m"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
client = remote_server.get_async_client()
res = await client.audio.transcriptions.create(model=model_name,
file=winning_call,
@@ -120,157 +90,149 @@ async def test_non_asr_model(winning_call):
@pytest.mark.asyncio
async def test_completion_endpoints():
async def test_bad_requests(mary_had_lamb, client):
# invalid language
with pytest.raises(openai.BadRequestError):
await client.audio.transcriptions.create(model=MODEL_NAME,
file=mary_had_lamb,
language="hh",
temperature=0.0)
@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb, client):
mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb)
# Add small silence after each audio for repeatability in the split process
audio = np.pad(audio, (0, 1600))
repeated_audio = np.tile(audio, 10)
# Repeated audio to buffer
buffer = io.BytesIO()
sf.write(buffer, repeated_audio, sr, format='WAV')
buffer.seek(0)
transcription = await client.audio.transcriptions.create(
model=MODEL_NAME,
file=buffer,
language="en",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
counts = out.count("Mary had a little lamb")
assert counts == 10, counts
@pytest.mark.asyncio
async def test_completion_endpoints(client):
# text to text model
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
res = await client.chat.completions.create(
model=model_name,
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}])
err = res.error
assert err["code"] == 400
assert err[
"message"] == "The model does not support Chat Completions API"
res = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}])
err = res.error
assert err["code"] == 400
assert err["message"] == "The model does not support Chat Completions API"
res = await client.completions.create(model=model_name, prompt="Hello")
err = res.error
assert err["code"] == 400
assert err["message"] == "The model does not support Completions API"
res = await client.completions.create(model=MODEL_NAME, prompt="Hello")
err = res.error
assert err["code"] == 400
assert err["message"] == "The model does not support Completions API"
@pytest.mark.asyncio
async def test_streaming_response(winning_call):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
async def test_streaming_response(winning_call, client):
transcription = ""
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
res_no_stream = await client.audio.transcriptions.create(
model=model_name,
file=winning_call,
response_format="json",
language="en",
temperature=0.0)
# Unfortunately this only works when the openai client is patched
# to use streaming mode, not exposed in the transcription api.
original_post = AsyncAPIClient.post
res_no_stream = await client.audio.transcriptions.create(
model=MODEL_NAME,
file=winning_call,
response_format="json",
language="en",
temperature=0.0)
res = await client.audio.transcriptions.create(model=MODEL_NAME,
file=winning_call,
language="en",
temperature=0.0,
stream=True,
timeout=30)
# Reconstruct from chunks and validate
async for chunk in res:
text = chunk.choices[0]['delta']['content']
transcription += text
async def post_with_stream(*args, **kwargs):
kwargs['stream'] = True
return await original_post(*args, **kwargs)
with patch.object(AsyncAPIClient, "post", new=post_with_stream):
client = remote_server.get_async_client()
res = await client.audio.transcriptions.create(
model=model_name,
file=winning_call,
language="en",
temperature=0.0,
extra_body=dict(stream=True),
timeout=30)
# Reconstruct from chunks and validate
async for chunk in res:
# just a chunk
text = chunk.choices[0]['delta']['content']
transcription += text
assert transcription == res_no_stream.text
assert transcription == res_no_stream.text
@pytest.mark.asyncio
async def test_stream_options(winning_call):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
original_post = AsyncAPIClient.post
async def post_with_stream(*args, **kwargs):
kwargs['stream'] = True
return await original_post(*args, **kwargs)
with patch.object(AsyncAPIClient, "post", new=post_with_stream):
client = remote_server.get_async_client()
res = await client.audio.transcriptions.create(
model=model_name,
file=winning_call,
language="en",
temperature=0.0,
extra_body=dict(stream=True,
stream_include_usage=True,
stream_continuous_usage_stats=True),
timeout=30)
final = False
continuous = True
async for chunk in res:
if not len(chunk.choices):
# final usage sent
final = True
else:
continuous = continuous and hasattr(chunk, 'usage')
assert final and continuous
async def test_stream_options(winning_call, client):
res = await client.audio.transcriptions.create(
model=MODEL_NAME,
file=winning_call,
language="en",
temperature=0.0,
stream=True,
extra_body=dict(stream_include_usage=True,
stream_continuous_usage_stats=True),
timeout=30)
final = False
continuous = True
async for chunk in res:
if not len(chunk.choices):
# final usage sent
final = True
else:
continuous = continuous and hasattr(chunk, 'usage')
assert final and continuous
@pytest.mark.asyncio
async def test_sampling_params(mary_had_lamb):
async def test_sampling_params(mary_had_lamb, client):
"""
Compare sampling with params and greedy sampling to assert results
are different when extreme sampling parameters values are picked.
"""
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
temperature=0.8,
extra_body=dict(seed=42,
repetition_penalty=1.9,
top_k=12,
top_p=0.4,
min_p=0.5,
frequency_penalty=1.8,
presence_penalty=2.0))
transcription = await client.audio.transcriptions.create(
model=MODEL_NAME,
file=mary_had_lamb,
language="en",
temperature=0.8,
extra_body=dict(seed=42,
repetition_penalty=1.9,
top_k=12,
top_p=0.4,
min_p=0.5,
frequency_penalty=1.8,
presence_penalty=2.0))
greedy_transcription = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
temperature=0.0,
extra_body=dict(seed=42))
greedy_transcription = await client.audio.transcriptions.create(
model=MODEL_NAME,
file=mary_had_lamb,
language="en",
temperature=0.0,
extra_body=dict(seed=42))
assert greedy_transcription.text != transcription.text
assert greedy_transcription.text != transcription.text
@pytest.mark.asyncio
async def test_audio_prompt(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
server_args = ["--enforce-eager"]
async def test_audio_prompt(mary_had_lamb, client):
prompt = "This is a speech, recorded in a phonograph."
with RemoteOpenAIServer(model_name, server_args) as remote_server:
# Supplying a prompt should not cause parts of the original transcription to be omitted.
prefix = "The first words I spoke in the original phonograph"
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert prefix in out
transcription_wprompt = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
response_format="text",
prompt=prompt,
temperature=0.0)
out_prompt = json.loads(transcription_wprompt)['text']
assert prefix in out_prompt
# Supplying a prompt should not cause parts of the original transcription to be omitted.
prefix = "The first words I spoke in the original phonograph"
transcription = await client.audio.transcriptions.create(
model=MODEL_NAME,
file=mary_had_lamb,
language="en",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert prefix in out
transcription_wprompt = await client.audio.transcriptions.create(
model=MODEL_NAME,
file=mary_had_lamb,
language="en",
response_format="text",
prompt=prompt,
temperature=0.0)
out_prompt = json.loads(transcription_wprompt)['text']
assert prefix in out_prompt

tests/entrypoints/openai/test_translation_validation.py

@@ -4,18 +4,21 @@
import io
# imports for guided decoding tests
import json
from unittest.mock import patch
import httpx
import librosa
import numpy as np
import pytest
import pytest_asyncio
import soundfile as sf
from openai._base_client import AsyncAPIClient
from vllm.assets.audio import AudioAsset
from ...utils import RemoteOpenAIServer
MODEL_NAME = "openai/whisper-small"
SERVER_ARGS = ["--enforce-eager"]
@pytest.fixture
def foscolo():
@@ -25,50 +28,23 @@ def foscolo():
yield f
# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
@pytest.mark.asyncio
async def test_basic_audio(foscolo):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
translation = await client.audio.translations.create(
model=model_name,
file=foscolo,
response_format="text",
# TODO remove once language detection is implemented
extra_body=dict(language="it"),
temperature=0.0)
out = json.loads(translation)['text'].strip().lower()
assert "greek sea" in out
@pytest.fixture(scope="module")
def server():
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
yield remote_server
@pytest.mark.asyncio
async def test_audio_prompt(foscolo):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
# Condition whisper on starting text
prompt = "Nor have I ever"
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.translations.create(
model=model_name,
file=foscolo,
prompt=prompt,
extra_body=dict(language="it"),
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert "Nor will I ever touch the sacred" not in out
assert prompt not in out
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
async def test_non_asr_model(foscolo):
# text to text model
model_name = "JackFram/llama-68m"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
client = remote_server.get_async_client()
res = await client.audio.translations.create(model=model_name,
file=foscolo,
@@ -78,81 +54,117 @@ async def test_non_asr_model(foscolo):
assert err["message"] == "The model does not support Translations API"
# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
@pytest.mark.asyncio
async def test_streaming_response(foscolo):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
async def test_basic_audio(foscolo, client):
translation = await client.audio.translations.create(
model=MODEL_NAME,
file=foscolo,
response_format="text",
# TODO remove once language detection is implemented
extra_body=dict(language="it"),
temperature=0.0)
out = json.loads(translation)['text'].strip().lower()
assert "greek sea" in out
@pytest.mark.asyncio
async def test_audio_prompt(foscolo, client):
# Condition whisper on starting text
prompt = "Nor have I ever"
transcription = await client.audio.translations.create(
model=MODEL_NAME,
file=foscolo,
prompt=prompt,
extra_body=dict(language="it"),
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert "Nor will I ever touch the sacred" not in out
assert prompt not in out
@pytest.mark.asyncio
async def test_streaming_response(foscolo, client, server):
translation = ""
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
res_no_stream = await client.audio.translations.create(
model=model_name,
file=foscolo,
response_format="json",
extra_body=dict(language="it"),
temperature=0.0)
# Unfortunately this only works when the openai client is patched
# to use streaming mode, not exposed in the translation api.
original_post = AsyncAPIClient.post
res_no_stream = await client.audio.translations.create(
model=MODEL_NAME,
file=foscolo,
response_format="json",
extra_body=dict(language="it"),
temperature=0.0)
# Stream via HTTPX since OpenAI translation client doesn't expose streaming
url = server.url_for("v1/audio/translations")
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
data = {
"model": MODEL_NAME,
"language": "it",
"stream": True,
"temperature": 0.0,
}
foscolo.seek(0)
async with httpx.AsyncClient() as http_client:
files = {"file": foscolo}
async with http_client.stream("POST",
url,
headers=headers,
data=data,
files=files) as response:
async for line in response.aiter_lines():
if not line:
continue
if line.startswith("data: "):
line = line[len("data: "):]
if line.strip() == "[DONE]":
break
chunk = json.loads(line)
text = chunk["choices"][0].get("delta", {}).get("content")
translation += text or ""
async def post_with_stream(*args, **kwargs):
kwargs['stream'] = True
return await original_post(*args, **kwargs)
with patch.object(AsyncAPIClient, "post", new=post_with_stream):
client = remote_server.get_async_client()
res = await client.audio.translations.create(model=model_name,
file=foscolo,
temperature=0.0,
extra_body=dict(
stream=True,
language="it"))
# Reconstruct from chunks and validate
async for chunk in res:
# just a chunk
text = chunk.choices[0]['delta']['content']
translation += text
assert translation == res_no_stream.text
assert translation == res_no_stream.text
@pytest.mark.asyncio
async def test_stream_options(foscolo):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
original_post = AsyncAPIClient.post
async def post_with_stream(*args, **kwargs):
kwargs['stream'] = True
return await original_post(*args, **kwargs)
with patch.object(AsyncAPIClient, "post", new=post_with_stream):
client = remote_server.get_async_client()
res = await client.audio.translations.create(
model=model_name,
file=foscolo,
temperature=0.0,
extra_body=dict(language="it",
stream=True,
stream_include_usage=True,
stream_continuous_usage_stats=True))
final = False
continuous = True
async for chunk in res:
if not len(chunk.choices):
async def test_stream_options(foscolo, client, server):
url = server.url_for("v1/audio/translations")
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
data = {
"model": MODEL_NAME,
"language": "it",
"stream": True,
"stream_include_usage": True,
"stream_continuous_usage_stats": True,
"temperature": 0.0,
}
foscolo.seek(0)
final = False
continuous = True
async with httpx.AsyncClient() as http_client:
files = {"file": foscolo}
async with http_client.stream("POST",
url,
headers=headers,
data=data,
files=files) as response:
async for line in response.aiter_lines():
if not line:
continue
if line.startswith("data: "):
line = line[len("data: "):]
if line.strip() == "[DONE]":
break
chunk = json.loads(line)
choices = chunk.get("choices", [])
if not choices:
# final usage sent
final = True
else:
continuous = continuous and hasattr(chunk, 'usage')
assert final and continuous
continuous = continuous and ("usage" in chunk)
assert final and continuous
@pytest.mark.asyncio
async def test_long_audio_request(foscolo):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
async def test_long_audio_request(foscolo, client):
foscolo.seek(0)
audio, sr = librosa.load(foscolo)
repeated_audio = np.tile(audio, 2)
@@ -160,13 +172,11 @@ async def test_long_audio_request(foscolo):
buffer = io.BytesIO()
sf.write(buffer, repeated_audio, sr, format='WAV')
buffer.seek(0)
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
translation = await client.audio.translations.create(
model=model_name,
file=buffer,
extra_body=dict(language="it"),
response_format="text",
temperature=0.0)
out = json.loads(translation)['text'].strip().lower()
assert out.count("greek sea") == 2
translation = await client.audio.translations.create(
model=MODEL_NAME,
file=buffer,
extra_body=dict(language="it"),
response_format="text",
temperature=0.0)
out = json.loads(translation)['text'].strip().lower()
assert out.count("greek sea") == 2