diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py
index dab14f1d7d03f..e1d175d9c6e12 100644
--- a/tests/entrypoints/openai/test_transcription_validation.py
+++ b/tests/entrypoints/openai/test_transcription_validation.py
@@ -37,7 +37,6 @@ async def test_basic_audio(mary_had_lamb):
     model_name = "openai/whisper-large-v3-turbo"
     server_args = ["--enforce-eager"]
     # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
-    prompt = "THE FIRST WORDS I SPOKE"
     with RemoteOpenAIServer(model_name, server_args) as remote_server:
         client = remote_server.get_async_client()
         transcription = await client.audio.transcriptions.create(
@@ -48,16 +47,6 @@ async def test_basic_audio(mary_had_lamb):
             temperature=0.0)
         out = json.loads(transcription)['text']
         assert "Mary had a little lamb," in out
-        # This should "force" whisper to continue prompt in all caps
-        transcription_wprompt = await client.audio.transcriptions.create(
-            model=model_name,
-            file=mary_had_lamb,
-            language="en",
-            response_format="text",
-            prompt=prompt,
-            temperature=0.0)
-        out_capital = json.loads(transcription_wprompt)['text']
-        assert prompt not in out_capital
 
 
 @pytest.mark.asyncio
@@ -238,3 +227,31 @@ async def test_sampling_params(mary_had_lamb):
             extra_body=dict(seed=42))
 
     assert greedy_transcription.text != transcription.text
+
+
+@pytest.mark.asyncio
+async def test_audio_prompt(mary_had_lamb):
+    model_name = "openai/whisper-large-v3-turbo"
+    server_args = ["--enforce-eager"]
+    prompt = "This is a speech, recorded in a phonograph."
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        # Prompting must not cause the original transcription to be dropped.
+        prefix = "The first words I spoke in the original phonograph"
+        client = remote_server.get_async_client()
+        transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=mary_had_lamb,
+            language="en",
+            response_format="text",
+            temperature=0.0)
+        out = json.loads(transcription)['text']
+        assert prefix in out
+        transcription_wprompt = await client.audio.transcriptions.create(
+            model=model_name,
+            file=mary_had_lamb,
+            language="en",
+            response_format="text",
+            prompt=prompt,
+            temperature=0.0)
+        out_prompt = json.loads(transcription_wprompt)['text']
+        assert prefix in out_prompt
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 27b3c75513fbf..344d6fc8f452f 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -780,8 +780,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
     @classmethod
     def get_decoder_prompt(cls, language: str, task_type: str,
                            prompt: str) -> str:
-        return (f"<|startoftranscript|><|{language}|><|{task_type}|>"
-                f"<|notimestamps|>{prompt}")
+        return ((f"<|prev|>{prompt}" if prompt else "") +
+                f"<|startoftranscript|><|{language}|>" +
+                f"<|{task_type}|><|notimestamps|>")
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
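
Note for reviewers: the behavioral change is in `get_decoder_prompt`. Previously the user prompt was appended after `<|notimestamps|>`, so Whisper treated it as already-emitted transcript text to continue (and could echo or restyle it); now it is placed in a `<|prev|>` section before `<|startoftranscript|>`, where Whisper treats it as prior-context conditioning. A minimal standalone sketch of the new layout (token strings taken from the diff above; the example prompt value is illustrative):

```python
# Standalone sketch of the new decoder-prompt layout (illustrative only).
def get_decoder_prompt(language: str, task_type: str, prompt: str) -> str:
    # The user prompt is conditioning context: it goes in the <|prev|>
    # section, *before* <|startoftranscript|>, so the model treats it as
    # prior context rather than as text to continue verbatim.
    return ((f"<|prev|>{prompt}" if prompt else "") +
            f"<|startoftranscript|><|{language}|>" +
            f"<|{task_type}|><|notimestamps|>")

# Without a prompt, the layout is unchanged:
assert (get_decoder_prompt("en", "transcribe", "") ==
        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>")

# With a prompt, the prompt text precedes <|startoftranscript|>:
assert (get_decoder_prompt("en", "transcribe", "phonograph") ==
        "<|prev|>phonograph"
        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>")
```

This is also why the updated test asserts that the audio's own opening line still appears in the output when a prompt is supplied, rather than asserting on prompt echo as the deleted all-caps test did.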