mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 05:25:20 +08:00
[Bug][Frontend] Fix structure of transcription's decoder_prompt (#18809)
Signed-off-by: sangbumlikeagod <oironese@naver.com>
This commit is contained in:
parent
0e3fe896e2
commit
9e5452ee34
@ -37,7 +37,6 @@ async def test_basic_audio(mary_had_lamb):
|
|||||||
model_name = "openai/whisper-large-v3-turbo"
|
model_name = "openai/whisper-large-v3-turbo"
|
||||||
server_args = ["--enforce-eager"]
|
server_args = ["--enforce-eager"]
|
||||||
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
|
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
|
||||||
prompt = "THE FIRST WORDS I SPOKE"
|
|
||||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||||
client = remote_server.get_async_client()
|
client = remote_server.get_async_client()
|
||||||
transcription = await client.audio.transcriptions.create(
|
transcription = await client.audio.transcriptions.create(
|
||||||
@ -48,16 +47,6 @@ async def test_basic_audio(mary_had_lamb):
|
|||||||
temperature=0.0)
|
temperature=0.0)
|
||||||
out = json.loads(transcription)['text']
|
out = json.loads(transcription)['text']
|
||||||
assert "Mary had a little lamb," in out
|
assert "Mary had a little lamb," in out
|
||||||
# This should "force" whisper to continue prompt in all caps
|
|
||||||
transcription_wprompt = await client.audio.transcriptions.create(
|
|
||||||
model=model_name,
|
|
||||||
file=mary_had_lamb,
|
|
||||||
language="en",
|
|
||||||
response_format="text",
|
|
||||||
prompt=prompt,
|
|
||||||
temperature=0.0)
|
|
||||||
out_capital = json.loads(transcription_wprompt)['text']
|
|
||||||
assert prompt not in out_capital
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -238,3 +227,31 @@ async def test_sampling_params(mary_had_lamb):
|
|||||||
extra_body=dict(seed=42))
|
extra_body=dict(seed=42))
|
||||||
|
|
||||||
assert greedy_transcription.text != transcription.text
|
assert greedy_transcription.text != transcription.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_audio_prompt(mary_had_lamb):
    """A conditioning `prompt` must not truncate the actual transcription.

    Runs the same audio through the transcription endpoint twice — once
    without and once with a `prompt` — and checks that the opening words
    of the recording survive in both outputs.
    """
    model_name = "openai/whisper-large-v3-turbo"
    server_args = ["--enforce-eager"]
    prompt = "This is a speech, recorded in a phonograph."
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # The transcript's original opening must be present whether or not
        # a conditioning prompt is supplied.
        prefix = "The first words I spoke in the original phonograph"
        client = remote_server.get_async_client()

        # Baseline: no prompt.
        transcription = await client.audio.transcriptions.create(
            model=model_name,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            temperature=0.0)
        out = json.loads(transcription)['text']
        assert prefix in out

        # With prompt: the prefix must still be transcribed.
        transcription_wprompt = await client.audio.transcriptions.create(
            model=model_name,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            prompt=prompt,
            temperature=0.0)
        out_prompt = json.loads(transcription_wprompt)['text']
        assert prefix in out_prompt
||||||
|
|||||||
@ -780,8 +780,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
|||||||
@classmethod
def get_decoder_prompt(cls, language: str, task_type: str,
                       prompt: str) -> str:
    """Build the decoder prompt string for Whisper generation.

    A user-supplied ``prompt`` is injected through the ``<|prev|>``
    token *before* ``<|startoftranscript|>`` so the model treats it as
    prior context, per Whisper's prompting convention. (The previous
    version appended the prompt after ``<|notimestamps|>``, which made
    the model transcribe/echo the prompt instead of conditioning on it.)

    Args:
        language: Language token body, e.g. ``"en"``.
        task_type: Whisper task token body, e.g. ``"transcribe"``.
        prompt: Optional conditioning text; empty string means no prompt.

    Returns:
        The fully formatted decoder prompt.
    """
    prev_section = f"<|prev|>{prompt}" if prompt else ""
    return (f"{prev_section}<|startoftranscript|><|{language}|>"
            f"<|{task_type}|><|notimestamps|>")
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user