[Feature][Frontend]: Add support for stream_options in ChatCompletionRequest (#5135)

Itay Etelis 2024-06-07 06:29:24 +03:00 committed by GitHub
parent 15063741e3
commit baa15a9ec3
3 changed files with 149 additions and 10 deletions
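For context, stream_options here mirrors the OpenAI API: on a streaming request, stream_options={"include_usage": true} asks the server to append one final chunk that carries token usage and has an empty choices list. A minimal client-side sketch against a vLLM OpenAI-compatible server (the base URL, API key, and model name below are illustrative assumptions, and the stream_options argument requires a reasonably recent openai client):

import asyncio

import openai


async def main():
    # Illustrative endpoint; any vLLM OpenAI-compatible server works.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    stream = await client.chat.completions.create(
        model="my-served-model",  # assumption: whichever model the server serves
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    async for chunk in stream:
        if chunk.choices:
            print(chunk.choices[0].delta.content or "", end="")
        elif chunk.usage is not None:
            # Final chunk: no choices, only the aggregated token counts.
            print("\n", chunk.usage)


asyncio.run(main())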


@@ -1343,5 +1343,106 @@ async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
    assert embeddings.usage.total_tokens == 17


@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_stream_options(server, client: openai.AsyncOpenAI,
                              model_name: str):
    prompt = "What is the capital of France?"

    # Test stream=True, stream_options=None
    stream = await client.completions.create(
        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
        stream=True,
        stream_options=None,
    )
    chunks = []
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
    assert len(chunks) > 0
    assert "usage" not in chunk

    # Test stream=True, stream_options={"include_usage": False}
    stream = await client.completions.create(
        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
        stream=True,
        stream_options={"include_usage": False},
    )
    chunks = []
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
    assert len(chunks) > 0
    assert "usage" not in chunk

    # Test stream=True, stream_options={"include_usage": True}
    stream = await client.completions.create(
        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
        stream=True,
        stream_options={"include_usage": True},
    )
    chunks = []
    finish_reason_count = 0
    async for chunk in stream:
        if chunk.choices[0].finish_reason is None:
            assert chunk.usage is None
            chunks.append(chunk.choices[0].text)
        else:
            assert chunk.usage is None
            finish_reason_count += 1

    # The last message should have usage and no choices
    last_message = await stream.__anext__()
    assert last_message.usage is not None
    assert last_message.usage.prompt_tokens > 0
    assert last_message.usage.completion_tokens > 0
    assert last_message.usage.total_tokens == (
        last_message.usage.prompt_tokens +
        last_message.usage.completion_tokens)
    assert last_message.choices == []

    # Test stream=False, stream_options={"include_usage": None}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
            stream=False,
            stream_options={"include_usage": None},
        )

    # Test stream=False, stream_options={"include_usage": False}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
            stream=False,
            stream_options={"include_usage": False},
        )

    # Test stream=False, stream_options={"include_usage": True}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
            stream=False,
            stream_options={"include_usage": True},
        )


if __name__ == "__main__":
    pytest.main([__file__])


@@ -102,6 +102,10 @@ class ResponseFormat(OpenAIBaseModel):
    type: Literal["text", "json_object"]


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool]


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
@@ -140,6 +144,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
@@ -269,6 +274,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
            logits_processors=logits_processors,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):

@@ -247,6 +247,9 @@ class OpenAIServingChat(OpenAIServing):
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)
                    if (request.stream_options
                            and request.stream_options.include_usage):
                        chunk.usage = None
                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"
@@ -274,6 +277,9 @@ class OpenAIServingChat(OpenAIServing):
                            choices=[choice_data],
                            logprobs=None,
                            model=model_name)
                        if (request.stream_options and
                                request.stream_options.include_usage):
                            chunk.usage = None
                        data = chunk.model_dump_json(
                            exclude_unset=True)
                        yield f"data: {data}\n\n"
@@ -327,17 +333,14 @@ class OpenAIServingChat(OpenAIServing):
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            chunk.usage = None
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                    else:
                        # Send the finish response for each request.n only once
                        prompt_tokens = len(res.prompt_token_ids)
                        final_usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=previous_num_tokens[i],
                            total_tokens=prompt_tokens +
                            previous_num_tokens[i],
                        )
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
@@ -350,12 +353,33 @@ class OpenAIServingChat(OpenAIServing):
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if final_usage is not None:
                            chunk.usage = final_usage
                        data = chunk.model_dump_json(exclude_unset=True,
                                                     exclude_none=True)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            chunk.usage = None
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            if (request.stream_options
                    and request.stream_options.include_usage):
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=previous_num_tokens[i],
                    total_tokens=prompt_tokens +
                    previous_num_tokens[i],
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
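Wire-level effect of the serving change: with include_usage requested, every normal chunk is serialized with an explicit "usage": null, and after the last finish-reason chunk one extra chunk with an empty choices list and the aggregated token counts is emitted before the stream terminates. An illustrative sketch of the resulting SSE frames (ids and token counts are invented, and some fields are omitted for brevity):

data: {"id":"chatcmpl-xyz","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}],"usage":null}

data: {"id":"chatcmpl-xyz","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Paris"},"finish_reason":null}],"usage":null}

data: {"id":"chatcmpl-xyz","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":null}

data: {"id":"chatcmpl-xyz","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":7,"completion_tokens":1,"total_tokens":8}}

data: [DONE]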