[Bug][Frontend] Fix structure of transcription's decoder_prompt (#18809)

Signed-off-by: sangbumlikeagod <oironese@naver.com>
This commit is contained in:
sangbumlikeagod 2025-07-04 20:28:07 +09:00 committed by GitHub
parent 0e3fe896e2
commit 9e5452ee34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 13 deletions

View File

@ -37,7 +37,6 @@ async def test_basic_audio(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
server_args = ["--enforce-eager"]
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
prompt = "THE FIRST WORDS I SPOKE"
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
@ -48,16 +47,6 @@ async def test_basic_audio(mary_had_lamb):
temperature=0.0)
out = json.loads(transcription)['text']
assert "Mary had a little lamb," in out
# This should "force" whisper to continue prompt in all caps
transcription_wprompt = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
response_format="text",
prompt=prompt,
temperature=0.0)
out_capital = json.loads(transcription_wprompt)['text']
assert prompt not in out_capital
@pytest.mark.asyncio
@ -238,3 +227,31 @@ async def test_sampling_params(mary_had_lamb):
extra_body=dict(seed=42))
assert greedy_transcription.text != transcription.text
@pytest.mark.asyncio
async def test_audio_prompt(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
server_args = ["--enforce-eager"]
prompt = "This is a speech, recorded in a phonograph."
with RemoteOpenAIServer(model_name, server_args) as remote_server:
#Prompts should not omit the part of original prompt while transcribing.
prefix = "The first words I spoke in the original phonograph"
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert prefix in out
transcription_wprompt = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
response_format="text",
prompt=prompt,
temperature=0.0)
out_prompt = json.loads(transcription_wprompt)['text']
assert prefix in out_prompt

View File

@ -780,8 +780,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
@classmethod
def get_decoder_prompt(cls, language: str, task_type: str,
prompt: str) -> str:
return (f"<|startoftranscript|><|{language}|><|{task_type}|>"
f"<|notimestamps|>{prompt}")
return ((f"<|prev|>{prompt}" if prompt else "") +
f"<|startoftranscript|><|{language}|>" +
f"<|{task_type}|><|notimestamps|>")
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: