[CI] Speed up Whisper tests by reusing server (#22859)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent a344a1a7da
commit 8a87cd27d9
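In short: instead of every test launching its own RemoteOpenAIServer (a full vLLM engine startup per test), each test module now starts one module-scoped server and hands every test an async client bound to it. A minimal sketch of the pattern, assuming only pytest, pytest_asyncio, and the RemoteOpenAIServer helper these suites already import (fixture names are taken from the diff below):

    import pytest
    import pytest_asyncio

    from ...utils import RemoteOpenAIServer  # test helper already used by these suites

    MODEL_NAME = "openai/whisper-large-v3-turbo"
    SERVER_ARGS = ["--enforce-eager"]


    @pytest.fixture(scope="module")
    def server():
        # Pay the engine startup cost once per module, not once per test.
        with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
            yield remote_server


    @pytest_asyncio.fixture
    async def client(server):
        # Each test still gets its own client against the shared server.
        async with server.get_async_client() as async_client:
            yield async_client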
@@ -4,19 +4,20 @@
 # imports for guided decoding tests
 import io
 import json
-from unittest.mock import patch
 
 import librosa
 import numpy as np
 import openai
 import pytest
+import pytest_asyncio
 import soundfile as sf
-from openai._base_client import AsyncAPIClient
 
 from vllm.assets.audio import AudioAsset
 
 from ...utils import RemoteOpenAIServer
 
+MODEL_NAME = "openai/whisper-large-v3-turbo"
+SERVER_ARGS = ["--enforce-eager"]
 MISTRAL_FORMAT_ARGS = [
     "--tokenizer_mode", "mistral", "--config_format", "mistral",
     "--load_format", "mistral"
@@ -37,6 +38,18 @@ def winning_call():
     yield f
 
 
+@pytest.fixture(scope="module")
+def server():
+    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
@@ -60,54 +73,11 @@ async def test_basic_audio(mary_had_lamb, model_name):
     assert "Mary had a little lamb," in out
 
 
-@pytest.mark.asyncio
-async def test_bad_requests(mary_had_lamb):
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-
-        # invalid language
-        with pytest.raises(openai.BadRequestError):
-            await client.audio.transcriptions.create(model=model_name,
-                                                     file=mary_had_lamb,
-                                                     language="hh",
-                                                     temperature=0.0)
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"])
-async def test_long_audio_request(mary_had_lamb, model_name):
-    server_args = ["--enforce-eager"]
-
-    mary_had_lamb.seek(0)
-    audio, sr = librosa.load(mary_had_lamb)
-    # Add small silence after each audio for repeatability in the split process
-    audio = np.pad(audio, (0, 1600))
-    repeated_audio = np.tile(audio, 10)
-    # Repeated audio to buffer
-    buffer = io.BytesIO()
-    sf.write(buffer, repeated_audio, sr, format='WAV')
-    buffer.seek(0)
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        transcription = await client.audio.transcriptions.create(
-            model=model_name,
-            file=buffer,
-            language="en",
-            response_format="text",
-            temperature=0.0)
-        out = json.loads(transcription)['text']
-        counts = out.count("Mary had a little lamb")
-        assert counts == 10, counts
-
-
 @pytest.mark.asyncio
 async def test_non_asr_model(winning_call):
     # text to text model
     model_name = "JackFram/llama-68m"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
         client = remote_server.get_async_client()
         res = await client.audio.transcriptions.create(model=model_name,
                                                        file=winning_call,
@@ -120,157 +90,149 @@ async def test_non_asr_model(winning_call):
 
 
 @pytest.mark.asyncio
-async def test_completion_endpoints():
+async def test_bad_requests(mary_had_lamb, client):
+    # invalid language
+    with pytest.raises(openai.BadRequestError):
+        await client.audio.transcriptions.create(model=MODEL_NAME,
+                                                 file=mary_had_lamb,
+                                                 language="hh",
+                                                 temperature=0.0)
+
+
+@pytest.mark.asyncio
+async def test_long_audio_request(mary_had_lamb, client):
+    mary_had_lamb.seek(0)
+    audio, sr = librosa.load(mary_had_lamb)
+    # Add small silence after each audio for repeatability in the split process
+    audio = np.pad(audio, (0, 1600))
+    repeated_audio = np.tile(audio, 10)
+    # Repeated audio to buffer
+    buffer = io.BytesIO()
+    sf.write(buffer, repeated_audio, sr, format='WAV')
+    buffer.seek(0)
+    transcription = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=buffer,
+        language="en",
+        response_format="text",
+        temperature=0.0)
+    out = json.loads(transcription)['text']
+    counts = out.count("Mary had a little lamb")
+    assert counts == 10, counts
+
+
+@pytest.mark.asyncio
+async def test_completion_endpoints(client):
     # text to text model
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        res = await client.chat.completions.create(
-            model=model_name,
-            messages=[{
-                "role": "system",
-                "content": "You are a helpful assistant."
-            }])
-        err = res.error
-        assert err["code"] == 400
-        assert err[
-            "message"] == "The model does not support Chat Completions API"
-
-        res = await client.completions.create(model=model_name, prompt="Hello")
-        err = res.error
-        assert err["code"] == 400
-        assert err["message"] == "The model does not support Completions API"
+    res = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }])
+    err = res.error
+    assert err["code"] == 400
+    assert err["message"] == "The model does not support Chat Completions API"
+
+    res = await client.completions.create(model=MODEL_NAME, prompt="Hello")
+    err = res.error
+    assert err["code"] == 400
+    assert err["message"] == "The model does not support Completions API"
 
 
 @pytest.mark.asyncio
-async def test_streaming_response(winning_call):
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
+async def test_streaming_response(winning_call, client):
     transcription = ""
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        res_no_stream = await client.audio.transcriptions.create(
-            model=model_name,
-            file=winning_call,
-            response_format="json",
-            language="en",
-            temperature=0.0)
-        # Unfortunately this only works when the openai client is patched
-        # to use streaming mode, not exposed in the transcription api.
-        original_post = AsyncAPIClient.post
-
-        async def post_with_stream(*args, **kwargs):
-            kwargs['stream'] = True
-            return await original_post(*args, **kwargs)
-
-        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
-            client = remote_server.get_async_client()
-            res = await client.audio.transcriptions.create(
-                model=model_name,
-                file=winning_call,
-                language="en",
-                temperature=0.0,
-                extra_body=dict(stream=True),
-                timeout=30)
-            # Reconstruct from chunks and validate
-            async for chunk in res:
-                # just a chunk
-                text = chunk.choices[0]['delta']['content']
-                transcription += text
-
-        assert transcription == res_no_stream.text
+    res_no_stream = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=winning_call,
+        response_format="json",
+        language="en",
+        temperature=0.0)
+    res = await client.audio.transcriptions.create(model=MODEL_NAME,
+                                                   file=winning_call,
+                                                   language="en",
+                                                   temperature=0.0,
+                                                   stream=True,
+                                                   timeout=30)
+    # Reconstruct from chunks and validate
+    async for chunk in res:
+        text = chunk.choices[0]['delta']['content']
+        transcription += text
+
+    assert transcription == res_no_stream.text
 
 
 @pytest.mark.asyncio
-async def test_stream_options(winning_call):
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        original_post = AsyncAPIClient.post
-
-        async def post_with_stream(*args, **kwargs):
-            kwargs['stream'] = True
-            return await original_post(*args, **kwargs)
-
-        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
-            client = remote_server.get_async_client()
-            res = await client.audio.transcriptions.create(
-                model=model_name,
-                file=winning_call,
-                language="en",
-                temperature=0.0,
-                extra_body=dict(stream=True,
-                                stream_include_usage=True,
-                                stream_continuous_usage_stats=True),
-                timeout=30)
-            final = False
-            continuous = True
-            async for chunk in res:
-                if not len(chunk.choices):
-                    # final usage sent
-                    final = True
-                else:
-                    continuous = continuous and hasattr(chunk, 'usage')
-            assert final and continuous
+async def test_stream_options(winning_call, client):
+    res = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=winning_call,
+        language="en",
+        temperature=0.0,
+        stream=True,
+        extra_body=dict(stream_include_usage=True,
+                        stream_continuous_usage_stats=True),
+        timeout=30)
+    final = False
+    continuous = True
+    async for chunk in res:
+        if not len(chunk.choices):
+            # final usage sent
+            final = True
+        else:
+            continuous = continuous and hasattr(chunk, 'usage')
+    assert final and continuous
 
 
 @pytest.mark.asyncio
-async def test_sampling_params(mary_had_lamb):
+async def test_sampling_params(mary_had_lamb, client):
     """
     Compare sampling with params and greedy sampling to assert results
     are different when extreme sampling parameters values are picked.
     """
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        transcription = await client.audio.transcriptions.create(
-            model=model_name,
-            file=mary_had_lamb,
-            language="en",
-            temperature=0.8,
-            extra_body=dict(seed=42,
-                            repetition_penalty=1.9,
-                            top_k=12,
-                            top_p=0.4,
-                            min_p=0.5,
-                            frequency_penalty=1.8,
-                            presence_penalty=2.0))
-
-        greedy_transcription = await client.audio.transcriptions.create(
-            model=model_name,
-            file=mary_had_lamb,
-            language="en",
-            temperature=0.0,
-            extra_body=dict(seed=42))
-
-        assert greedy_transcription.text != transcription.text
+    transcription = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        temperature=0.8,
+        extra_body=dict(seed=42,
+                        repetition_penalty=1.9,
+                        top_k=12,
+                        top_p=0.4,
+                        min_p=0.5,
+                        frequency_penalty=1.8,
+                        presence_penalty=2.0))
+
+    greedy_transcription = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        temperature=0.0,
+        extra_body=dict(seed=42))
+
+    assert greedy_transcription.text != transcription.text
 
 
 @pytest.mark.asyncio
-async def test_audio_prompt(mary_had_lamb):
-    model_name = "openai/whisper-large-v3-turbo"
-    server_args = ["--enforce-eager"]
+async def test_audio_prompt(mary_had_lamb, client):
     prompt = "This is a speech, recorded in a phonograph."
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        #Prompts should not omit the part of original prompt while transcribing.
-        prefix = "The first words I spoke in the original phonograph"
-        client = remote_server.get_async_client()
-        transcription = await client.audio.transcriptions.create(
-            model=model_name,
-            file=mary_had_lamb,
-            language="en",
-            response_format="text",
-            temperature=0.0)
-        out = json.loads(transcription)['text']
-        assert prefix in out
-        transcription_wprompt = await client.audio.transcriptions.create(
-            model=model_name,
-            file=mary_had_lamb,
-            language="en",
-            response_format="text",
-            prompt=prompt,
-            temperature=0.0)
-        out_prompt = json.loads(transcription_wprompt)['text']
-        assert prefix in out_prompt
+    #Prompts should not omit the part of original prompt while transcribing.
+    prefix = "The first words I spoke in the original phonograph"
+    transcription = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        response_format="text",
+        temperature=0.0)
+    out = json.loads(transcription)['text']
+    assert prefix in out
+    transcription_wprompt = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        response_format="text",
+        prompt=prompt,
+        temperature=0.0)
+    out_prompt = json.loads(transcription_wprompt)['text']
+    assert prefix in out_prompt
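The translation test module below gets the same restructuring: shared server/client fixtures plus MODEL_NAME/SERVER_ARGS constants (here openai/whisper-small). One difference: the OpenAI client does not accept stream=True for translations the way it now does for transcriptions above, so the old AsyncAPIClient.post monkeypatch is replaced with direct httpx requests against the server's v1/audio/translations endpoint.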
@@ -4,18 +4,21 @@
 import io
 # imports for guided decoding tests
 import json
-from unittest.mock import patch
 
+import httpx
 import librosa
 import numpy as np
 import pytest
+import pytest_asyncio
 import soundfile as sf
-from openai._base_client import AsyncAPIClient
 
 from vllm.assets.audio import AudioAsset
 
 from ...utils import RemoteOpenAIServer
 
+MODEL_NAME = "openai/whisper-small"
+SERVER_ARGS = ["--enforce-eager"]
+
 
 @pytest.fixture
 def foscolo():
@@ -25,50 +28,23 @@ def foscolo():
     yield f
 
 
-# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
-@pytest.mark.asyncio
-async def test_basic_audio(foscolo):
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        translation = await client.audio.translations.create(
-            model=model_name,
-            file=foscolo,
-            response_format="text",
-            # TODO remove once language detection is implemented
-            extra_body=dict(language="it"),
-            temperature=0.0)
-        out = json.loads(translation)['text'].strip().lower()
-        assert "greek sea" in out
-
-
-@pytest.mark.asyncio
-async def test_audio_prompt(foscolo):
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-    # Condition whisper on starting text
-    prompt = "Nor have I ever"
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        transcription = await client.audio.translations.create(
-            model=model_name,
-            file=foscolo,
-            prompt=prompt,
-            extra_body=dict(language="it"),
-            response_format="text",
-            temperature=0.0)
-        out = json.loads(transcription)['text']
-        assert "Nor will I ever touch the sacred" not in out
-        assert prompt not in out
+@pytest.fixture(scope="module")
+def server():
+    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
 
 
 @pytest.mark.asyncio
 async def test_non_asr_model(foscolo):
     # text to text model
     model_name = "JackFram/llama-68m"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
         client = remote_server.get_async_client()
         res = await client.audio.translations.create(model=model_name,
                                                      file=foscolo,
@@ -78,81 +54,117 @@ async def test_non_asr_model(foscolo):
     assert err["message"] == "The model does not support Translations API"
 
 
+# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
 @pytest.mark.asyncio
-async def test_streaming_response(foscolo):
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
+async def test_basic_audio(foscolo, client):
+    translation = await client.audio.translations.create(
+        model=MODEL_NAME,
+        file=foscolo,
+        response_format="text",
+        # TODO remove once language detection is implemented
+        extra_body=dict(language="it"),
+        temperature=0.0)
+    out = json.loads(translation)['text'].strip().lower()
+    assert "greek sea" in out
+
+
+@pytest.mark.asyncio
+async def test_audio_prompt(foscolo, client):
+    # Condition whisper on starting text
+    prompt = "Nor have I ever"
+    transcription = await client.audio.translations.create(
+        model=MODEL_NAME,
+        file=foscolo,
+        prompt=prompt,
+        extra_body=dict(language="it"),
+        response_format="text",
+        temperature=0.0)
+    out = json.loads(transcription)['text']
+    assert "Nor will I ever touch the sacred" not in out
+    assert prompt not in out
+
+
+@pytest.mark.asyncio
+async def test_streaming_response(foscolo, client, server):
     translation = ""
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        res_no_stream = await client.audio.translations.create(
-            model=model_name,
-            file=foscolo,
-            response_format="json",
-            extra_body=dict(language="it"),
-            temperature=0.0)
-        # Unfortunately this only works when the openai client is patched
-        # to use streaming mode, not exposed in the translation api.
-        original_post = AsyncAPIClient.post
-
-        async def post_with_stream(*args, **kwargs):
-            kwargs['stream'] = True
-            return await original_post(*args, **kwargs)
-
-        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
-            client = remote_server.get_async_client()
-            res = await client.audio.translations.create(model=model_name,
-                                                         file=foscolo,
-                                                         temperature=0.0,
-                                                         extra_body=dict(
-                                                             stream=True,
-                                                             language="it"))
-            # Reconstruct from chunks and validate
-            async for chunk in res:
-                # just a chunk
-                text = chunk.choices[0]['delta']['content']
-                translation += text
-
-        assert translation == res_no_stream.text
+    res_no_stream = await client.audio.translations.create(
+        model=MODEL_NAME,
+        file=foscolo,
+        response_format="json",
+        extra_body=dict(language="it"),
+        temperature=0.0)
+    # Stream via HTTPX since OpenAI translation client doesn't expose streaming
+    url = server.url_for("v1/audio/translations")
+    headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
+    data = {
+        "model": MODEL_NAME,
+        "language": "it",
+        "stream": True,
+        "temperature": 0.0,
+    }
+    foscolo.seek(0)
+    async with httpx.AsyncClient() as http_client:
+        files = {"file": foscolo}
+        async with http_client.stream("POST",
+                                      url,
+                                      headers=headers,
+                                      data=data,
+                                      files=files) as response:
+            async for line in response.aiter_lines():
+                if not line:
+                    continue
+                if line.startswith("data: "):
+                    line = line[len("data: "):]
+                if line.strip() == "[DONE]":
+                    break
+                chunk = json.loads(line)
+                text = chunk["choices"][0].get("delta", {}).get("content")
+                translation += text or ""
+
+    assert translation == res_no_stream.text
 
 
 @pytest.mark.asyncio
-async def test_stream_options(foscolo):
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        original_post = AsyncAPIClient.post
-
-        async def post_with_stream(*args, **kwargs):
-            kwargs['stream'] = True
-            return await original_post(*args, **kwargs)
-
-        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
-            client = remote_server.get_async_client()
-            res = await client.audio.translations.create(
-                model=model_name,
-                file=foscolo,
-                temperature=0.0,
-                extra_body=dict(language="it",
-                                stream=True,
-                                stream_include_usage=True,
-                                stream_continuous_usage_stats=True))
-            final = False
-            continuous = True
-            async for chunk in res:
-                if not len(chunk.choices):
-                    # final usage sent
-                    final = True
-                else:
-                    continuous = continuous and hasattr(chunk, 'usage')
-            assert final and continuous
+async def test_stream_options(foscolo, client, server):
+    url = server.url_for("v1/audio/translations")
+    headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
+    data = {
+        "model": MODEL_NAME,
+        "language": "it",
+        "stream": True,
+        "stream_include_usage": True,
+        "stream_continuous_usage_stats": True,
+        "temperature": 0.0,
+    }
+    foscolo.seek(0)
+    final = False
+    continuous = True
+    async with httpx.AsyncClient() as http_client:
+        files = {"file": foscolo}
+        async with http_client.stream("POST",
+                                      url,
+                                      headers=headers,
+                                      data=data,
+                                      files=files) as response:
+            async for line in response.aiter_lines():
+                if not line:
+                    continue
+                if line.startswith("data: "):
+                    line = line[len("data: "):]
+                if line.strip() == "[DONE]":
+                    break
+                chunk = json.loads(line)
+                choices = chunk.get("choices", [])
+                if not choices:
+                    # final usage sent
+                    final = True
+                else:
+                    continuous = continuous and ("usage" in chunk)
+    assert final and continuous
 
 
 @pytest.mark.asyncio
-async def test_long_audio_request(foscolo):
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-
+async def test_long_audio_request(foscolo, client):
     foscolo.seek(0)
     audio, sr = librosa.load(foscolo)
     repeated_audio = np.tile(audio, 2)
@@ -160,13 +172,11 @@ async def test_long_audio_request(foscolo):
     buffer = io.BytesIO()
     sf.write(buffer, repeated_audio, sr, format='WAV')
     buffer.seek(0)
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        translation = await client.audio.translations.create(
-            model=model_name,
-            file=buffer,
-            extra_body=dict(language="it"),
-            response_format="text",
-            temperature=0.0)
-        out = json.loads(translation)['text'].strip().lower()
-        assert out.count("greek sea") == 2
+    translation = await client.audio.translations.create(
+        model=MODEL_NAME,
+        file=buffer,
+        extra_body=dict(language="it"),
+        response_format="text",
+        temperature=0.0)
+    out = json.loads(translation)['text'].strip().lower()
+    assert out.count("greek sea") == 2
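Both streaming tests now parse the server-sent events by hand. A self-contained sketch of that parsing, runnable on its own; the sample payload is illustrative, though the "data: " prefix and "[DONE]" sentinel match what the tests above expect from the server:

    import json


    def parse_sse_line(line: str):
        """Decode one SSE line; return None for blanks and the [DONE] sentinel."""
        if not line:
            return None
        if line.startswith("data: "):
            line = line[len("data: "):]
        if line.strip() == "[DONE]":
            return None
        return json.loads(line)


    # Illustrative event in the shape the tests above consume.
    chunk = parse_sse_line('data: {"choices": [{"delta": {"content": "Hello"}}]}')
    assert chunk["choices"][0]["delta"]["content"] == "Hello"
    assert parse_sse_line("data: [DONE]") is None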