Mark prompt logprobs as incompatible with prompt embeds at API level (#25077)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
This commit is contained in:
Andrew Sansom 2025-09-17 23:25:07 -05:00 committed by GitHub
parent 52bc9d5b3e
commit bec060fd99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 29 additions and 4 deletions

View File

@@ -228,3 +228,20 @@ async def test_completions_with_logprobs_and_prompt_embeds(
assert max(logprobs_arg,
1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) == 5
@pytest.mark.asyncio
async def test_prompt_logprobs_raises_error(
        client_with_prompt_embeds: openai.AsyncOpenAI):
    """Requesting prompt_logprobs together with prompt_embeds must be
    rejected at the API level with a 400 BadRequestError whose message
    contains "not compatible".

    NOTE(review): the server-side check compares ``prompt_logprobs is not
    None``, so the boolean ``True`` sent here is enough to trigger it —
    presumably pydantic coerces it to an int; confirm against the request
    model if this ever changes.
    """
    with pytest.raises(BadRequestError, match="not compatible"):
        # create_dummy_embeds() is a module-level helper (outside this
        # hunk) producing base64-encoded tensor bytes for prompt_embeds.
        encoded_embeds = create_dummy_embeds()
        # prompt="" because prompt_embeds replaces the text prompt; the
        # error must fire before any generation happens (max_tokens is
        # irrelevant to the assertion).
        await client_with_prompt_embeds.completions.create(
            model=MODEL_NAME,
            prompt="",
            max_tokens=5,
            temperature=0.0,
            extra_body={
                "prompt_embeds": encoded_embeds,
                "prompt_logprobs": True
            },
        )

View File

@@ -671,10 +671,13 @@ class LLMEngine:
arrival_time = time.time()
if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None
and not prompt.get("prompt_token_ids", None)):
seq_len = prompt["prompt_embeds"].shape[0]
prompt["prompt_token_ids"] = [0] * seq_len
and prompt.get("prompt_embeds", None) is not None):
if not prompt.get("prompt_token_ids", None):
seq_len = prompt["prompt_embeds"].shape[0]
prompt["prompt_token_ids"] = [0] * seq_len
if params.prompt_logprobs is not None:
raise ValueError(
"prompt_logprobs is not compatible with prompt embeds.")
processed_inputs = self.input_preprocessor.preprocess(
prompt,

View File

@@ -112,6 +112,11 @@ class OpenAIServingCompletion(OpenAIServing):
return self.create_error_response(
"Echo is unsupported with prompt embeds.")
if (request.prompt_logprobs is not None
and request.prompt_embeds is not None):
return self.create_error_response(
"prompt_logprobs is not compatible with prompt embeds.")
request_id = (
f"cmpl-"
f"{self._base_request_id(raw_request, request.request_id)}")