mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 13:24:31 +08:00
[Bugfix] Fix include prompt in stream response when echo=true (#15233)
Signed-off-by: Yuan Fang <yuanfang@alauda.io>
This commit is contained in:
parent
6d42ce8315
commit
e28533a16f
@ -779,3 +779,57 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
|
|||||||
prompt="Give an example string that fits this regex",
|
prompt="Give an example string that fits this regex",
|
||||||
extra_body=dict(guided_regex=sample_regex,
|
extra_body=dict(guided_regex=sample_regex,
|
||||||
guided_json=sample_json_schema))
|
guided_json=sample_json_schema))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,stream,echo",
    [
        (MODEL_NAME, False, False),
        (MODEL_NAME, False, True),
        (MODEL_NAME, True, False),
        (MODEL_NAME, True, True)  # should not raise BadRequestError
    ],
)
async def test_echo_stream_completion(client: openai.AsyncOpenAI,
                                      model_name: str, stream: bool,
                                      echo: bool):
    """Check that ``echo=True`` returns the prompt in both streaming and
    non-streaming completions, and that ``echo=False`` omits it."""
    saying: str = "Hello, my name is"
    result = await client.completions.create(model=model_name,
                                             prompt=saying,
                                             max_tokens=10,
                                             temperature=0.0,
                                             echo=echo,
                                             stream=stream)

    # max_tokens=10 is always hit, so the server should finish with "length".
    stop_reason = "length"

    if not stream:
        completion = result
        assert completion.id is not None
        assert completion.choices is not None and len(completion.choices) == 1

        choice = completion.choices[0]
        assert len(choice.text) >= 5
        assert choice.finish_reason == stop_reason

        # With echo the prompt text must appear in the output; without it,
        # the output must not contain the prompt.
        if echo:
            assert choice.text is not None and saying in choice.text
        else:
            assert choice.text is not None and saying not in choice.text

    else:
        # Accumulate streamed text and remember the last finish_reason seen.
        chunks: list[str] = []
        final_finish_reason = None
        async for chunk in result:
            if chunk.choices and chunk.choices[0].text:
                chunks.append(chunk.choices[0].text)
            if chunk.choices and chunk.choices[0].finish_reason:
                final_finish_reason = chunk.choices[0].finish_reason

        assert final_finish_reason == stop_reason
        content = "".join(chunks)
        if echo:
            assert content is not None and saying in content
        else:
            assert content is not None and saying not in content
|
||||||
|
|||||||
@ -25,10 +25,13 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
|||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
RequestResponseMetadata,
|
RequestResponseMetadata,
|
||||||
UsageInfo)
|
UsageInfo)
|
||||||
# yapf: enable
|
from vllm.entrypoints.openai.serving_engine import (
|
||||||
|
EmbedsPrompt as ServingEngineEmbedsPrompt)
|
||||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||||
|
TextTokensPrompt,
|
||||||
clamp_prompt_logprobs,
|
clamp_prompt_logprobs,
|
||||||
is_text_tokens_prompt)
|
is_text_tokens_prompt)
|
||||||
|
# yapf: enable
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
|
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
|
||||||
is_tokens_prompt)
|
is_tokens_prompt)
|
||||||
@ -223,6 +226,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
if stream:
|
if stream:
|
||||||
return self.completion_stream_generator(
|
return self.completion_stream_generator(
|
||||||
request,
|
request,
|
||||||
|
request_prompts,
|
||||||
result_generator,
|
result_generator,
|
||||||
request_id,
|
request_id,
|
||||||
created_time,
|
created_time,
|
||||||
@ -285,6 +289,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
async def completion_stream_generator(
|
async def completion_stream_generator(
|
||||||
self,
|
self,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
|
request_prompts: list[Union[TextTokensPrompt,
|
||||||
|
ServingEngineEmbedsPrompt]],
|
||||||
result_generator: AsyncIterator[tuple[int, RequestOutput]],
|
result_generator: AsyncIterator[tuple[int, RequestOutput]],
|
||||||
request_id: str,
|
request_id: str,
|
||||||
created_time: int,
|
created_time: int,
|
||||||
@ -313,7 +319,15 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
async for prompt_idx, res in result_generator:
|
async for prompt_idx, res in result_generator:
|
||||||
prompt_token_ids = res.prompt_token_ids
|
prompt_token_ids = res.prompt_token_ids
|
||||||
prompt_logprobs = res.prompt_logprobs
|
prompt_logprobs = res.prompt_logprobs
|
||||||
prompt_text = res.prompt
|
|
||||||
|
if res.prompt is not None:
|
||||||
|
prompt_text = res.prompt
|
||||||
|
else:
|
||||||
|
request_prompt = request_prompts[prompt_idx]
|
||||||
|
if is_text_tokens_prompt(request_prompt):
|
||||||
|
prompt_text = request_prompt["prompt"]
|
||||||
|
else:
|
||||||
|
prompt_text = None
|
||||||
|
|
||||||
# Prompt details are excluded from later streamed outputs
|
# Prompt details are excluded from later streamed outputs
|
||||||
if prompt_token_ids is not None:
|
if prompt_token_ids is not None:
|
||||||
@ -336,14 +350,13 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
delta_token_ids = prompt_token_ids
|
delta_token_ids = prompt_token_ids
|
||||||
out_logprobs = prompt_logprobs
|
out_logprobs = prompt_logprobs
|
||||||
else:
|
else:
|
||||||
assert prompt_logprobs is not None
|
|
||||||
# echo the prompt and first token
|
# echo the prompt and first token
|
||||||
delta_text = prompt_text + output.text
|
delta_text = prompt_text + output.text
|
||||||
delta_token_ids = [
|
delta_token_ids = [
|
||||||
*prompt_token_ids, *output.token_ids
|
*prompt_token_ids, *output.token_ids
|
||||||
]
|
]
|
||||||
out_logprobs = [
|
out_logprobs = [
|
||||||
*prompt_logprobs,
|
*(prompt_logprobs or []),
|
||||||
*(output.logprobs or []),
|
*(output.logprobs or []),
|
||||||
]
|
]
|
||||||
has_echoed[i] = True
|
has_echoed[i] = True
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user