Mark prompt logprobs as incompatible with prompt embeds at API level (#25077)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
parent 52bc9d5b3e
commit bec060fd99
```diff
@@ -228,3 +228,20 @@ async def test_completions_with_logprobs_and_prompt_embeds(
         assert max(logprobs_arg,
                    1) <= len(top_logprobs) <= logprobs_arg + 1
         assert len(logprobs.tokens) == 5
+
+
+@pytest.mark.asyncio
+async def test_prompt_logprobs_raises_error(
+        client_with_prompt_embeds: openai.AsyncOpenAI):
+    with pytest.raises(BadRequestError, match="not compatible"):
+        encoded_embeds = create_dummy_embeds()
+        await client_with_prompt_embeds.completions.create(
+            model=MODEL_NAME,
+            prompt="",
+            max_tokens=5,
+            temperature=0.0,
+            extra_body={
+                "prompt_embeds": encoded_embeds,
+                "prompt_logprobs": True
+            },
+        )
```
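For context, the `create_dummy_embeds` helper the new test calls is defined earlier in the same test file and is not part of this hunk. A minimal sketch of how such a helper could work, assuming the convention of shipping prompt embeds over the HTTP API as a base64-encoded `torch.save` payload; the tensor shape and the `hidden_size` parameter here are assumptions, not taken from this diff:

```python
import base64
import io

import torch


def create_dummy_embeds(num_tokens: int = 5, hidden_size: int = 4096) -> str:
    """Build a random [num_tokens, hidden_size] tensor and return it as a
    base64-encoded torch.save payload, the wire format assumed for the
    "prompt_embeds" field in extra_body. hidden_size must match the served
    model's hidden dimension."""
    dummy_embeds = torch.randn(num_tokens, hidden_size)
    buffer = io.BytesIO()
    torch.save(dummy_embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
```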
```diff
@@ -671,10 +671,13 @@ class LLMEngine:
             arrival_time = time.time()
 
         if (isinstance(prompt, dict)
-                and prompt.get("prompt_embeds", None) is not None
-                and not prompt.get("prompt_token_ids", None)):
-            seq_len = prompt["prompt_embeds"].shape[0]
-            prompt["prompt_token_ids"] = [0] * seq_len
+                and prompt.get("prompt_embeds", None) is not None):
+            if not prompt.get("prompt_token_ids", None):
+                seq_len = prompt["prompt_embeds"].shape[0]
+                prompt["prompt_token_ids"] = [0] * seq_len
+            if params.prompt_logprobs is not None:
+                raise ValueError(
+                    "prompt_logprobs is not compatible with prompt embeds.")
 
         processed_inputs = self.input_preprocessor.preprocess(
             prompt,
```
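To make the engine-side change easier to follow, here is a distilled, standalone sketch of the same logic; the function name and signature are hypothetical (the real code lives inside `LLMEngine.add_request`). The likely rationale for the hard error: when only embeds are supplied, the engine pads `prompt_token_ids` with placeholder zeros, so any prompt logprobs would be scored against meaningless token ids.

```python
import torch


def validate_and_pad(prompt: dict, prompt_logprobs: "int | None") -> dict:
    """Hypothetical standalone version of the engine-side check above."""
    if prompt.get("prompt_embeds") is not None:
        # Embeds-only prompts get placeholder token ids so downstream code
        # that expects token ids keeps working.
        if not prompt.get("prompt_token_ids"):
            seq_len = prompt["prompt_embeds"].shape[0]
            prompt["prompt_token_ids"] = [0] * seq_len
        # Prompt logprobs require real prompt token ids to score, which the
        # placeholder zeros above cannot provide.
        if prompt_logprobs is not None:
            raise ValueError(
                "prompt_logprobs is not compatible with prompt embeds.")
    return prompt


# Example: this raises the new ValueError.
try:
    validate_and_pad({"prompt_embeds": torch.randn(5, 4096)}, prompt_logprobs=1)
except ValueError as e:
    print(e)
```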
```diff
@@ -112,6 +112,11 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(
                 "Echo is unsupported with prompt embeds.")
 
+        if (request.prompt_logprobs is not None
+                and request.prompt_embeds is not None):
+            return self.create_error_response(
+                "prompt_logprobs is not compatible with prompt embeds.")
+
         request_id = (
             f"cmpl-"
             f"{self._base_request_id(raw_request, request.request_id)}")
```
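From the client's point of view, this server-side check surfaces as an HTTP 400 (`openai.BadRequestError`), which is what the new test asserts. A sketch of a request that would now be rejected, assuming a locally served model; the base URL, model name, and embeds payload below are placeholders:

```python
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

try:
    client.completions.create(
        model="my-model",  # placeholder model name
        prompt="",
        max_tokens=5,
        extra_body={
            "prompt_embeds": "<base64-encoded tensor>",  # placeholder payload
            "prompt_logprobs": True,
        },
    )
except openai.BadRequestError as e:
    # Server rejects the combination with:
    # "prompt_logprobs is not compatible with prompt embeds."
    print(e)
```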