[Core] Add engine option to return only deltas or final output (#7381)

This commit is contained in:
Nick Hill 2024-09-12 20:02:00 +01:00 committed by GitHub
parent a6c0f3658d
commit 551ce01078
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 371 additions and 137 deletions

View File

@ -50,6 +50,7 @@ steps:
- tests/worker - tests/worker
commands: commands:
- pytest -v -s async_engine # Async Engine - pytest -v -s async_engine # Async Engine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
- pytest -v -s test_inputs.py - pytest -v -s test_inputs.py
- pytest -v -s multimodal - pytest -v -s multimodal
- pytest -v -s test_utils.py # Utils - pytest -v -s test_utils.py # Utils

View File

@ -1,7 +1,10 @@
import asyncio import asyncio
import os
import uuid
from asyncio import CancelledError from asyncio import CancelledError
from copy import copy
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import List, Optional
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@ -11,6 +14,7 @@ from vllm import SamplingParams
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.outputs import RequestOutput as RealRequestOutput from vllm.outputs import RequestOutput as RealRequestOutput
from vllm.sampling_params import RequestOutputKind
from ..conftest import cleanup from ..conftest import cleanup
from ..utils import wait_for_gpu_memory_to_clear from ..utils import wait_for_gpu_memory_to_clear
@ -122,8 +126,17 @@ def start_engine():
timeout_s=60, timeout_s=60,
) )
num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")
return AsyncLLMEngine.from_engine_args( return AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True)) AsyncEngineArgs(model="facebook/opt-125m",
enforce_eager=True,
num_scheduler_steps=num_scheduler_steps))
def uid() -> str:
return str(uuid.uuid4())
@pytest_asyncio.fixture(scope="module") @pytest_asyncio.fixture(scope="module")
@ -148,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool:
@pytest.mark.asyncio(scope="module") @pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine): async def test_asyncio_run(async_engine):
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
async def run(prompt: str): async def run(prompt: str):
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, temperature=0,
max_tokens=32, max_tokens=32,
min_tokens=32,
) )
output_count = 0
final_output = None
async for output in async_engine.generate(prompt, async for output in async_engine.generate(prompt,
sampling_params, sampling_params,
request_id=prompt): request_id=uid()):
output_count += 1
final_output = output final_output = output
return final_output return final_output, output_count
results = await asyncio.gather( results = await asyncio.gather(
run("test0"), run("test0"),
run("test1"), run("test0"),
) )
assert len(results) == 2 assert len(results) == 2
first, second = results
# remove nondeterministic fields for comparison
first[0].metrics = None
second[0].metrics = None
first[0].request_id = None
second[0].request_id = None
assert str(first) == str(second)
output_count = results[0][1]
if num_scheduler_steps == 1:
assert output_count == 32
else:
assert 1 < output_count < 32
@pytest.mark.asyncio(scope="module")
async def test_output_kinds(async_engine):
"""Test that output_kind works as expected and that
results are equivalent across different kinds."""
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
sampling_params = SamplingParams(
temperature=0,
max_tokens=32,
min_tokens=32,
)
async def run(prompt: str, kind: RequestOutputKind):
params = copy(sampling_params)
params.output_kind = kind
output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
output_count += 1
final_output = output
assert final_output is not None
return (final_output.prompt_token_ids,
final_output.outputs[0].token_ids,
final_output.outputs[0].text, output_count)
async def run_deltas(prompt: str):
params = copy(sampling_params)
params.output_kind = RequestOutputKind.DELTA
prompt_tokens = None
output_tokens: List[int] = []
output_text = ""
output_count = 0
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
token_ids = output.outputs[0].token_ids
text = output.outputs[0].text
# Ensure we get prompt ids iff we haven't yet received output tokens
if output_tokens:
assert 1 <= len(token_ids) <= num_scheduler_steps
assert text
assert not output.prompt_token_ids
else:
assert output.prompt_token_ids
prompt_tokens = output.prompt_token_ids
output_tokens.extend(token_ids)
output_text += text
output_count += 1
return prompt_tokens, output_tokens, output_text, output_count
results = await asyncio.gather(
run("common input prompt", RequestOutputKind.CUMULATIVE),
run("common input prompt", RequestOutputKind.FINAL_ONLY),
run_deltas("common input prompt"))
# Make sure outputs are the same
prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results)
assert len(prompt_set) == 1
text_set = set(text for _, _, text, _ in results)
assert len(text_set) == 1
tokens_set = set(tuple(ids) for _, ids, _, _ in results)
assert len(tokens_set) == 1
cumulative, final, deltas = results
# output message counts
assert cumulative[3] == deltas[3]
if num_scheduler_steps == 1:
assert cumulative[3] == 32
else:
assert 1 < cumulative[3] < 32
assert final[3] == 1
@pytest.mark.asyncio(scope="module") @pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine): async def test_cancellation(async_engine):
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, temperature=0,
min_tokens=10, min_tokens=13,
max_tokens=10, max_tokens=13,
) )
stop_at = 5 if num_scheduler_steps == 1 else 1
request_id = uid()
i = 0 i = 0
with pytest.raises(CancelledError): with pytest.raises(CancelledError):
async for output in async_engine.generate("test2", async for output in async_engine.generate("test2",
sampling_params, sampling_params,
request_id="test2"): request_id=request_id):
assert not output.finished assert not output.finished
i += 1 i += 1
if i == 5: if i == stop_at:
await async_engine.abort("test2") await async_engine.abort(request_id)
assert i == 5 assert i == stop_at
@pytest.mark.asyncio(scope="module") @pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine): async def test_delayed_generator(async_engine):
scheduler_config = await async_engine.get_scheduler_config()
if scheduler_config.num_scheduler_steps != 1:
pytest.skip("no need to test this one with multistep")
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, temperature=0,
min_tokens=10, min_tokens=10,
max_tokens=10, max_tokens=10,
) )
stream = async_engine.generate("test3", stream = async_engine.generate("test3", sampling_params, request_id=uid())
sampling_params,
request_id="test3")
i = 0 i = 0
final_output: Optional[RealRequestOutput] = None final_output: Optional[RealRequestOutput] = None
async for output in stream: async for output in stream:

View File

@ -39,7 +39,7 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata, Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus) SequenceStatus)
@ -225,9 +225,6 @@ class LLMEngine:
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
# To improve performance, only final requests outputs may be required.
# If this set to true, then no intermediate outputs will be returned.
step_return_finished_only: bool = False,
) -> None: ) -> None:
logger.info( logger.info(
"Initializing an LLM engine (v%s) with config: " "Initializing an LLM engine (v%s) with config: "
@ -295,7 +292,6 @@ class LLMEngine:
self.observability_config = observability_config or ObservabilityConfig( self.observability_config = observability_config or ObservabilityConfig(
) )
self.log_stats = log_stats self.log_stats = log_stats
self.step_return_finished_only = step_return_finished_only
if not self.model_config.skip_tokenizer_init: if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer() self.tokenizer = self._init_tokenizer()
@ -1378,7 +1374,8 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group) request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output) if request_output:
ctx.request_outputs.append(request_output)
# When we process a single request, we skip it for the next time, # When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output) # and invoke the request output callback (if there was final output)
@ -1415,14 +1412,19 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
if (seq_group.is_finished() request_output = RequestOutputFactory.create(seq_group)
if self.step_return_finished_only else True): if request_output:
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups: for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params
if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished():
continue
request_output = RequestOutputFactory.create(seq_group) request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output) if request_output:
ctx.request_outputs.append(request_output)
# Immediately process request outputs here (if callback is given) # Immediately process request outputs here (if callback is given)
if (ctx.request_outputs if (ctx.request_outputs

View File

@ -19,7 +19,7 @@ from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer) get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@ -642,14 +642,12 @@ class LLM:
raise ValueError("The lengths of prompts and lora_request " raise ValueError("The lengths of prompts and lora_request "
"must be the same.") "must be the same.")
if isinstance(params, list): for sp in params if isinstance(params, list) else (params, ):
params = [ if isinstance(sp, SamplingParams):
self._add_guided_processor(param, guided_options) self._add_guided_processor(sp, guided_options)
if isinstance(param, SamplingParams) else param
for param in params # We only care about the final output
] sp.output_kind = RequestOutputKind.FINAL_ONLY
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)
# Add requests to the engine. # Add requests to the engine.
for i, request_inputs in enumerate(inputs): for i, request_inputs in enumerate(inputs):
@ -709,9 +707,6 @@ class LLM:
f"output: {0:.2f} toks/s"), f"output: {0:.2f} toks/s"),
) )
# In the loop below, only finished outputs are used
self.llm_engine.step_return_finished_only = True
# Run the engine. # Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0 total_in_toks = 0
@ -724,6 +719,7 @@ class LLM:
if use_tqdm: if use_tqdm:
if isinstance(output, RequestOutput): if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput # Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids) total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"] in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum( total_out_toks += sum(
@ -735,9 +731,6 @@ class LLM:
f"output: {out_spd:.2f} toks/s") f"output: {out_spd:.2f} toks/s")
pbar.update(1) pbar.update(1)
# Restore original behavior
self.llm_engine.step_return_finished_only = False
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
# Sort the outputs by request ID. # Sort the outputs by request ID.

View File

@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
SamplingParams)
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
) )
@model_validator(mode="before") @model_validator(mode="before")
@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
) )
@model_validator(mode="before") @model_validator(mode="before")

View File

@ -246,8 +246,7 @@ class OpenAIServingChat(OpenAIServing):
def get_chat_request_role(self, request: ChatCompletionRequest) -> str: def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt: if request.add_generation_prompt:
return self.response_role return self.response_role
else: return request.messages[-1]["role"]
return request.messages[-1]["role"]
async def chat_completion_stream_generator( async def chat_completion_stream_generator(
self, self,
@ -264,15 +263,37 @@ class OpenAIServingChat(OpenAIServing):
# Send response for each token for each request.n (index) # Send response for each token for each request.n (index)
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices
previous_num_tokens = [0] * num_choices previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0
tool_parser: Optional[ToolParser] = self.tool_parser( tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None tokenizer) if self.tool_parser else None
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
else:
tool_choice_function_name = None
# Determine whether tools are in use with "auto" tool choice
tool_choice_auto = (
not tool_choice_function_name
and self._should_stream_with_auto_tool_parsing(request))
all_previous_token_ids: Optional[List[List[int]]]
if tool_choice_auto:
# These are only required in "auto" tool choice case
previous_texts = [""] * num_choices
all_previous_token_ids = [[]] * num_choices
else:
previous_texts, all_previous_token_ids = None, None
try: try:
async for res in result_generator: async for res in result_generator:
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
# We need to do it here, because if there are exceptions in # We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST # the result_generator, it needs to be sent as the FIRST
# response (by the try...catch). # response (by the try...catch).
@ -305,10 +326,10 @@ class OpenAIServingChat(OpenAIServing):
and request.stream_options.include_usage): and request.stream_options.include_usage):
# if continuous usage stats are requested, add it # if continuous usage stats are requested, add it
if request.stream_options.continuous_usage_stats: if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids) usage = UsageInfo(
usage = UsageInfo(prompt_tokens=prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=0, completion_tokens=0,
total_tokens=prompt_tokens) total_tokens=num_prompt_tokens)
chunk.usage = usage chunk.usage = usage
# otherwise don't # otherwise don't
else: else:
@ -344,12 +365,10 @@ class OpenAIServingChat(OpenAIServing):
request.stream_options.include_usage): request.stream_options.include_usage):
if (request.stream_options. if (request.stream_options.
continuous_usage_stats): continuous_usage_stats):
prompt_tokens = len(
res.prompt_token_ids)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=0, completion_tokens=0,
total_tokens=prompt_tokens) total_tokens=num_prompt_tokens)
chunk.usage = usage chunk.usage = usage
else: else:
chunk.usage = None chunk.usage = None
@ -360,65 +379,66 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = False first_iteration = False
for output in res.outputs: for output in res.outputs:
i = output.index i = output.index
if finish_reason_sent[i]: if finish_reason_sent[i]:
continue continue
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
out_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs and request.top_logprobs is not None: if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, ( assert output.logprobs is not None, (
"Did not output logprobs") "Did not output logprobs")
logprobs = self._create_chat_logprobs( logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids, token_ids=output.token_ids,
top_logprobs=out_logprobs, top_logprobs=output.logprobs,
tokenizer=tokenizer, tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs, num_output_top_logprobs=request.top_logprobs,
) )
else: else:
logprobs = None logprobs = None
delta_text = output.text[len(previous_texts[i]):] delta_text = output.text
delta_message: Optional[DeltaMessage] = None delta_message: Optional[DeltaMessage]
# handle streaming deltas for tools with named tool_choice # handle streaming deltas for tools with named tool_choice
if (request.tool_choice and type(request.tool_choice) is if tool_choice_function_name:
ChatCompletionNamedToolChoiceParam):
delta_message = DeltaMessage(tool_calls=[ delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall( DeltaToolCall(function=DeltaFunctionCall(
name=request.tool_choice.function.name, name=tool_choice_function_name,
arguments=delta_text), arguments=delta_text),
index=i) index=i)
]) ])
# handle streaming deltas for tools with "auto" tool choice # handle streaming deltas for tools with "auto" tool choice
elif (self._should_stream_with_auto_tool_parsing(request) elif tool_choice_auto:
and tool_parser): assert previous_texts is not None
assert all_previous_token_ids is not None
assert tool_parser is not None
#TODO optimize manipulation of these lists
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list(
output.token_ids)
delta_message = ( delta_message = (
tool_parser.extract_tool_calls_streaming( tool_parser.extract_tool_calls_streaming(
previous_text=previous_texts[i], previous_text=previous_text,
current_text=output.text, current_text=current_text,
delta_text=delta_text, delta_text=delta_text,
previous_token_ids= \ previous_token_ids=previous_token_ids,
output.token_ids[ current_token_ids=current_token_ids,
:-1 * len(delta_token_ids) delta_token_ids=output.token_ids))
],
current_token_ids=output.token_ids, # update the previous values for the next iteration
delta_token_ids=delta_token_ids previous_texts[i] = current_text
) all_previous_token_ids[i] = current_token_ids
)
# handle streaming just a content delta # handle streaming just a content delta
else: else:
delta_message = DeltaMessage(content=delta_text) delta_message = DeltaMessage(content=delta_text)
# set the previous values for the next iteration # set the previous values for the next iteration
previous_texts[i] = output.text previous_num_tokens[i] += len(output.token_ids)
previous_num_tokens[i] = len(output.token_ids)
# if the message delta is None (e.g. because it was a # if the message delta is None (e.g. because it was a
# "control token" for tool calls or the parser otherwise # "control token" for tool calls or the parser otherwise
@ -445,13 +465,12 @@ class OpenAIServingChat(OpenAIServing):
# handle usage stats if requested & if continuous # handle usage stats if requested & if continuous
if (request.stream_options if (request.stream_options
and request.stream_options.include_usage): and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats): if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids) completion_tokens = len(output.token_ids)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + total_tokens=num_prompt_tokens +
completion_tokens, completion_tokens,
) )
chunk.usage = usage chunk.usage = usage
@ -482,7 +501,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser.prev_tool_call_arr[index].get( tool_parser.prev_tool_call_arr[index].get(
"arguments", {})) "arguments", {}))
# get what we've streamed so for for arguments # get what we've streamed so far for arguments
# for the current tool # for the current tool
actual_call = tool_parser.streamed_args_for_tool[ actual_call = tool_parser.streamed_args_for_tool[
index] index]
@ -500,7 +519,6 @@ class OpenAIServingChat(OpenAIServing):
]) ])
# Send the finish response for each request.n only once # Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=i, index=i,
delta=delta_message, delta=delta_message,
@ -518,13 +536,12 @@ class OpenAIServingChat(OpenAIServing):
model=model_name) model=model_name)
if (request.stream_options if (request.stream_options
and request.stream_options.include_usage): and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats): if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids) completion_tokens = len(output.token_ids)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + total_tokens=num_prompt_tokens +
completion_tokens, completion_tokens,
) )
chunk.usage = usage chunk.usage = usage
@ -538,10 +555,11 @@ class OpenAIServingChat(OpenAIServing):
# is sent, send the usage # is sent, send the usage
if (request.stream_options if (request.stream_options
and request.stream_options.include_usage): and request.stream_options.include_usage):
completion_tokens = previous_num_tokens[i]
final_usage = UsageInfo( final_usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=previous_num_tokens[i], completion_tokens=completion_tokens,
total_tokens=prompt_tokens + previous_num_tokens[i], total_tokens=num_prompt_tokens + completion_tokens,
) )
final_usage_chunk = ChatCompletionStreamResponse( final_usage_chunk = ChatCompletionStreamResponse(
@ -680,6 +698,7 @@ class OpenAIServingChat(OpenAIServing):
or "") or "")
choice.message.content = full_message choice.message.content = full_message
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids) num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum( num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs) len(output.token_ids) for output in final_res.outputs)
@ -789,9 +808,9 @@ class OpenAIServingChat(OpenAIServing):
return bool( return bool(
# if there is a delta message that includes tool calls which # if there is a delta message that includes tool calls which
# include a function that has arguments # include a function that has arguments
self.enable_auto_tools and self.tool_parser and delta_message output.finish_reason is not None
and self.enable_auto_tools and self.tool_parser and delta_message
and delta_message.tool_calls and delta_message.tool_calls[0] and delta_message.tool_calls and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None and delta_message.tool_calls[0].function.arguments is not None
and output.finish_reason is not None
) )

View File

@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
try: try:
async for prompt_idx, res in result_generator: async for prompt_idx, res in result_generator:
@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = res.prompt_logprobs prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt prompt_text = res.prompt
# Prompt details are excluded from later streamed outputs
if res.prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
delta_token_ids: GenericSequence[int] delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[ out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]] int, Logprob]]]]
@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
assert prompt_token_ids is not None
assert prompt_text is not None assert prompt_text is not None
# only return the prompt # only return the prompt
delta_text = prompt_text delta_text = prompt_text
@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True has_echoed[i] = True
elif (request.echo and request.max_tokens > 0 elif (request.echo and request.max_tokens > 0
and not has_echoed[i]): and not has_echoed[i]):
assert prompt_token_ids is not None
assert prompt_text is not None assert prompt_text is not None
assert prompt_logprobs is not None assert prompt_logprobs is not None
# echo the prompt and first token # echo the prompt and first token
@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True has_echoed[i] = True
else: else:
# return just the delta # return just the delta
delta_text = output.text[len(previous_texts[i]):] delta_text = output.text
delta_token_ids = output.token_ids[ delta_token_ids = output.token_ids
previous_num_tokens[i]:] out_logprobs = output.logprobs
out_logprobs = output.logprobs[previous_num_tokens[
i]:] if output.logprobs else None
if request.logprobs is not None: if request.logprobs is not None:
assert out_logprobs is not None, ( assert out_logprobs is not None, (
@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing):
top_logprobs=out_logprobs, top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs, num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer, tokenizer=tokenizer,
initial_text_offset=len(previous_texts[i]), initial_text_offset=previous_text_lens[i],
) )
else: else:
logprobs = None logprobs = None
previous_texts[i] = output.text previous_text_lens[i] += len(output.text)
previous_num_tokens[i] = len(output.token_ids) previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason finish_reason = output.finish_reason
stop_reason = output.stop_reason stop_reason = output.stop_reason
@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage): and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None): or output.finish_reason is not None):
prompt_tokens = len(prompt_token_ids) prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = len(output.token_ids) completion_tokens = previous_num_tokens[i]
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing):
for final_res in final_res_batch: for final_res in final_res_batch:
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = final_res.prompt_logprobs prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt prompt_text = final_res.prompt
@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing):
) )
choices.append(choice_data) choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
num_prompt_tokens += len(prompt_token_ids) num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,

View File

@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Union from typing import Union
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus) SequenceGroup, SequenceStatus)
@ -92,7 +93,7 @@ class RequestOutput:
self, self,
request_id: str, request_id: str,
prompt: Optional[str], prompt: Optional[str],
prompt_token_ids: List[int], prompt_token_ids: Optional[List[int]],
prompt_logprobs: Optional[PromptLogprobs], prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput], outputs: List[CompletionOutput],
finished: bool, finished: bool,
@ -113,19 +114,26 @@ class RequestOutput:
self.encoder_prompt_token_ids = encoder_prompt_token_ids self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod @classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": def from_seq_group(cls,
if seq_group.sampling_params is None: seq_group: SequenceGroup) -> Optional["RequestOutput"]:
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError( raise ValueError(
"Sampling parameters are missing for a CompletionRequest.") "Sampling parameters are missing for a CompletionRequest.")
finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished):
return None
seqs = seq_group.get_seqs() seqs = seq_group.get_seqs()
if len(seqs) == 1: if len(seqs) == 1:
top_n_seqs = seqs top_n_seqs = seqs
else: else:
# Get the top-n sequences. # Get the top-n sequences.
n = seq_group.sampling_params.n n = sampling_params.n
if seq_group.sampling_params.use_beam_search: if sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score( sorting_key = lambda seq: seq.get_beam_search_score(
seq_group.sampling_params.length_penalty) sampling_params.length_penalty)
else: else:
sorting_key = lambda seq: seq.get_cumulative_logprob() sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
@ -135,26 +143,49 @@ class RequestOutput:
# NOTE: We need omit logprobs here explicitly because the sequence # NOTE: We need omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the # always has the logprobs of the sampled tokens even if the
# logprobs are not requested. # logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs is not None include_logprobs = sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length text_buffer_length = sampling_params.output_text_buffer_length
outputs = [ delta = sampling_params.output_kind == RequestOutputKind.DELTA
CompletionOutput(
seqs.index(seq), outputs = []
seq.get_output_text_to_return(text_buffer_length), include_prompt = True
seq.data._output_token_ids, for seq in top_n_seqs:
seq.get_cumulative_logprob() if include_logprobs else None, output_text = seq.get_output_text_to_return(
seq.output_logprobs if include_logprobs else None, text_buffer_length, delta)
SequenceStatus.get_finished_reason(seq.status), output_token_ids = seq.get_output_token_ids_to_return(delta)
seq.stop_reason) for seq in top_n_seqs output_logprobs = seq.output_logprobs if include_logprobs else None
]
if delta:
# Slice logprobs delta if applicable
if output_logprobs:
output_logprobs = output_logprobs[-len(output_token_ids):]
# Don't include prompt if this is after the first output
# containing decode token ids
if include_prompt and seq.get_output_len() > len(
output_token_ids):
include_prompt = False
outputs.append(
CompletionOutput(
seqs.index(seq), output_text, output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason))
# Every sequence in the sequence group should have the same prompt. # Every sequence in the sequence group should have the same prompt.
prompt = seq_group.prompt if include_prompt:
prompt_token_ids = seq_group.prompt_token_ids prompt = seq_group.prompt
encoder_prompt = seq_group.encoder_prompt prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids encoder_prompt = seq_group.encoder_prompt
prompt_logprobs = seq_group.prompt_logprobs encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
finished = seq_group.is_finished() prompt_logprobs = seq_group.prompt_logprobs
else:
prompt = None
prompt_token_ids = None
encoder_prompt = None
encoder_prompt_token_ids = None
prompt_logprobs = None
finished_time = time.time() if finished else None finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time) seq_group.set_finished_time(finished_time)
return cls(seq_group.request_id, return cls(seq_group.request_id,

View File

@ -1,6 +1,6 @@
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
import copy import copy
from enum import IntEnum from enum import Enum, IntEnum
from functools import cached_property from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union from typing import Any, Callable, Dict, List, Optional, Set, Union
@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
to sample from.""" to sample from."""
class RequestOutputKind(Enum):
    """Selects how much generated content each RequestOutput carries."""

    CUMULATIVE = 0  # every RequestOutput holds the entire output so far
    DELTA = 1  # every RequestOutput holds only the deltas
    FINAL_ONLY = 2  # intermediate RequestOutputs are not returned
class SamplingParams( class SamplingParams(
msgspec.Struct, msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
@ -147,6 +156,7 @@ class SamplingParams(
logits_processors: Optional[Any] = None logits_processors: Optional[Any] = None
include_stop_str_in_output: bool = False include_stop_str_in_output: bool = False
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
# The below fields are not supposed to be used as an input. # The below fields are not supposed to be used as an input.
# They are set in post_init. # They are set in post_init.
@ -182,6 +192,7 @@ class SamplingParams(
logits_processors: Optional[List[LogitsProcessor]] = None, logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int, truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None, msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
) -> "SamplingParams": ) -> "SamplingParams":
return SamplingParams( return SamplingParams(
n=1 if n is None else n, n=1 if n is None else n,
@ -213,6 +224,7 @@ class SamplingParams(
spaces_between_special_tokens=spaces_between_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind,
) )
def __post_init__(self) -> None: def __post_init__(self) -> None:
@ -317,6 +329,9 @@ class SamplingParams(
raise ValueError( raise ValueError(
"stop strings are only supported when detokenize is True. " "stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.") "Set detokenize=True to use stop.")
if self.best_of != self.n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_beam_search(self) -> None: def _verify_beam_search(self) -> None:
if self.best_of == 1: if self.best_of == 1:

View File

@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from array import array from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
Optional, Set, Tuple, Union, cast) from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast
import msgspec import msgspec
import torch import torch
@ -407,6 +408,10 @@ class Sequence:
self.status = SequenceStatus.WAITING self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None self.stop_reason: Union[int, str, None] = None
# These are used to keep track of delta outputs
self._last_token_ids_offset: int = 0
self._last_output_text_offset: int = 0
# Used for incremental detokenization # Used for incremental detokenization
self.prefix_offset = 0 self.prefix_offset = 0
self.read_offset = 0 self.read_offset = 0
@ -462,11 +467,35 @@ class Sequence:
return self.prompt_adapter_request.prompt_adapter_id \ return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0 if self.prompt_adapter_request else 0
def get_output_text_to_return(self, buffer_length: int,
                              delta: bool) -> str:
    """Return the portion of the output text to hand back to the caller.

    While the sequence is unfinished, the trailing ``buffer_length``
    characters of ``self.output_text`` are withheld. Once the sequence is
    finished nothing is withheld.

    Args:
        buffer_length: Number of trailing characters to hold back while
            the sequence is still running. ``0`` disables buffering.
        delta: If True, return only the text that became returnable since
            the previous delta call and advance the internal offset. If
            False, return the whole returnable text and leave the offset
            untouched.

    Returns:
        The text to return; an empty string when ``delta`` is True and
        nothing new is available.
    """
    # We return the full output text if the sequence is finished.
    truncate = buffer_length and not self.is_finished()
    if not delta:
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    # The buffer must only be held back while the sequence is running.
    # Subtracting it unconditionally would drop the last `buffer_length`
    # characters from the final delta of a finished sequence.
    length = len(self.output_text)
    if truncate:
        length -= buffer_length

    last_offset = self._last_output_text_offset
    if last_offset < length:
        self._last_output_text_offset = length
        return self.output_text[last_offset:length]
    return ""
def get_output_token_ids_to_return(self,
                                   delta: bool) -> GenericSequence[int]:
    """Return the output token ids to hand back to the caller.

    With ``delta`` False the result of ``get_output_token_ids()`` is
    returned unchanged. With ``delta`` True only the ids appended since
    the previous delta call are returned and the stored offset is moved
    to the current output length; an empty tuple means nothing is new.
    """
    if delta:
        total = self.get_output_len()
        start = self._last_token_ids_offset
        if start >= total:
            # Nothing was generated since the previous delta call.
            return ()
        self._last_token_ids_offset = total
        return self.data._output_token_ids[start:]
    return self.get_output_token_ids()
def hash_of_block(self, logical_idx: int) -> int: def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size # TODO This can produce incorrect hash when block size > prompt size