From fd85c9f4263de8cf9bc9f51bef9471344436614c Mon Sep 17 00:00:00 2001 From: Max Wittig Date: Tue, 14 Oct 2025 09:17:39 +0200 Subject: [PATCH] [Bugfix][FE]: Always include usage with `--enable-force-include-usage ` (#20983) Signed-off-by: Max Wittig Signed-off-by: Antoine Auger Co-authored-by: Antoine Auger --- pyproject.toml | 1 + .../openai/test_enable_force_include_usage.py | 126 ++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 2 + vllm/entrypoints/openai/run_batch.py | 8 ++ vllm/entrypoints/openai/serving_chat.py | 15 +-- vllm/entrypoints/openai/serving_completion.py | 16 +-- vllm/entrypoints/openai/serving_engine.py | 3 - vllm/entrypoints/openai/serving_responses.py | 1 - .../openai/serving_transcription.py | 4 + vllm/entrypoints/openai/speech_to_text.py | 7 +- vllm/entrypoints/utils.py | 19 ++- 11 files changed, 172 insertions(+), 30 deletions(-) create mode 100644 tests/entrypoints/openai/test_enable_force_include_usage.py diff --git a/pyproject.toml b/pyproject.toml index eb9bdb593baa..95dda76063bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ markers = [ "distributed: run this test only in distributed GPU tests", "skip_v1: do not run this test with v1", "optional: optional tests that are automatically skipped, include --optional to run them", + "extra_server_args: extra arguments to pass to the server fixture", ] [tool.ty.src] diff --git a/tests/entrypoints/openai/test_enable_force_include_usage.py b/tests/entrypoints/openai/test_enable_force_include_usage.py new file mode 100644 index 000000000000..3ddf2308eb1d --- /dev/null +++ b/tests/entrypoints/openai/test_enable_force_include_usage.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import openai +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + + +@pytest.fixture(scope="module") +def chat_server_with_force_include_usage(request): # noqa: F811 + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "128", + "--enforce-eager", + "--max-num-seqs", + "1", + "--enable-force-include-usage", + "--port", + "55857", + "--gpu-memory-utilization", + "0.2", + ] + + with RemoteOpenAIServer("Qwen/Qwen3-0.6B", args, auto_port=False) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def chat_client_with_force_include_usage(chat_server_with_force_include_usage): + async with chat_server_with_force_include_usage.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_chat_with_enable_force_include_usage( + chat_client_with_force_include_usage: openai.AsyncOpenAI, +): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ] + + stream = await chat_client_with_force_include_usage.chat.completions.create( + model="Qwen/Qwen3-0.6B", + messages=messages, + max_completion_tokens=10, + extra_body=dict(min_tokens=10), + temperature=0.0, + stream=True, + ) + last_completion_tokens = 0 + async for chunk in stream: + if not len(chunk.choices): + assert chunk.usage.prompt_tokens >= 0 + assert ( + last_completion_tokens == 0 + or chunk.usage.completion_tokens > last_completion_tokens + or ( + not chunk.choices + and chunk.usage.completion_tokens == last_completion_tokens + ) + ) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) 
+ else: + assert chunk.usage is None + + +@pytest.fixture(scope="module") +def transcription_server_with_force_include_usage(): + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-num-seqs", + "1", + "--enforce-eager", + "--enable-force-include-usage", + "--gpu-memory-utilization", + "0.2", + ] + + with RemoteOpenAIServer("openai/whisper-large-v3-turbo", args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def transcription_client_with_force_include_usage( + transcription_server_with_force_include_usage, +): + async with ( + transcription_server_with_force_include_usage.get_async_client() as async_client + ): + yield async_client + + +@pytest.mark.asyncio +async def test_transcription_with_enable_force_include_usage( + transcription_client_with_force_include_usage, winning_call +): + res = ( + await transcription_client_with_force_include_usage.audio.transcriptions.create( + model="openai/whisper-large-v3-turbo", + file=winning_call, + language="en", + temperature=0.0, + stream=True, + timeout=30, + ) + ) + + async for chunk in res: + if not len(chunk.choices): + # final usage sent + usage = chunk.usage + assert isinstance(usage, dict) + assert usage["prompt_tokens"] > 0 + assert usage["completion_tokens"] > 0 + assert usage["total_tokens"] > 0 + else: + assert not hasattr(chunk, "usage") diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index ec5632523fe3..fd80ba7a9afc 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1808,6 +1808,7 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger, log_error_stack=args.log_error_stack, + enable_force_include_usage=args.enable_force_include_usage, ) if "transcription" in supported_tasks else None @@ -1818,6 +1819,7 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger, log_error_stack=args.log_error_stack, + enable_force_include_usage=args.enable_force_include_usage, ) if "transcription" in supported_tasks else None diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index ecee27a329d2..c8ca6e7d29ba 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -104,6 +104,13 @@ def make_arg_parser(parser: FlexibleArgumentParser): default=False, help="If set to True, enable prompt_tokens_details in usage.", ) + parser.add_argument( + "--enable-force-include-usage", + action="store_true", + default=False, + help="If set to True, include usage on every request " + "(even when stream_options is not specified)", + ) return parser @@ -361,6 +368,7 @@ async def run_batch( chat_template=None, chat_template_content_format="auto", enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, ) if "generate" in supported_tasks else None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 96525f206859..26027112eb58 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -58,7 +58,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_l from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall -from 
vllm.entrypoints.utils import get_max_tokens +from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob @@ -101,7 +101,6 @@ class OpenAIServingChat(OpenAIServing): models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage, log_error_stack=log_error_stack, ) @@ -352,7 +351,6 @@ class OpenAIServingChat(OpenAIServing): conversation, tokenizer, request_metadata, - enable_force_include_usage=self.enable_force_include_usage, ) try: @@ -518,7 +516,6 @@ class OpenAIServingChat(OpenAIServing): conversation: list[ConversationMessage], tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, - enable_force_include_usage: bool, ) -> AsyncGenerator[str, None]: created_time = int(time.time()) chunk_object_type: Final = "chat.completion.chunk" @@ -596,13 +593,9 @@ class OpenAIServingChat(OpenAIServing): return stream_options = request.stream_options - if stream_options: - include_usage = stream_options.include_usage or enable_force_include_usage - include_continuous_usage = ( - include_usage and stream_options.continuous_usage_stats - ) - else: - include_usage, include_continuous_usage = False, False + include_usage, include_continuous_usage = should_include_usage( + stream_options, self.enable_force_include_usage + ) try: async for res in result_generator: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7af64306023a..7cbe9c69435c 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -27,7 +27,7 @@ from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig -from vllm.entrypoints.utils import get_max_tokens +from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.logger import init_logger from vllm.logprobs import Logprob @@ -56,11 +56,11 @@ class OpenAIServingCompletion(OpenAIServing): models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage, log_error_stack=log_error_stack, ) self.enable_prompt_tokens_details = enable_prompt_tokens_details self.default_sampling_params = self.model_config.get_diff_sampling_param() + self.enable_force_include_usage = enable_force_include_usage if self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source @@ -256,7 +256,6 @@ class OpenAIServingCompletion(OpenAIServing): num_prompts=num_prompts, tokenizer=tokenizer, request_metadata=request_metadata, - enable_force_include_usage=self.enable_force_include_usage, ) # Non-streaming response @@ -320,7 +319,6 @@ class OpenAIServingCompletion(OpenAIServing): num_prompts: int, tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, - enable_force_include_usage: bool, ) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n previous_text_lens = [0] * num_choices * num_prompts @@ -331,13 +329,9 @@ class OpenAIServingCompletion(OpenAIServing): first_iteration = True 
stream_options = request.stream_options - if stream_options: - include_usage = stream_options.include_usage or enable_force_include_usage - include_continuous_usage = ( - include_usage and stream_options.continuous_usage_stats - ) - else: - include_usage, include_continuous_usage = False, False + include_usage, include_continuous_usage = should_include_usage( + stream_options, self.enable_force_include_usage + ) try: async for prompt_idx, res in result_generator: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index a041950ffd20..3965d2dac088 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -249,7 +249,6 @@ class OpenAIServing: *, request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, - enable_force_include_usage: bool = False, log_error_stack: bool = False, ): super().__init__() @@ -260,8 +259,6 @@ class OpenAIServing: self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids - self.enable_force_include_usage = enable_force_include_usage - self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) self._apply_mistral_chat_template_async = make_async( apply_mistral_chat_template, executor=self._tokenizer_executor diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 3b9015efd305..744df98a4278 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -127,7 +127,6 @@ class OpenAIServingResponses(OpenAIServing): models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage, log_error_stack=log_error_stack, ) diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index d043f55648d2..33da7034afab 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -37,6 +37,7 @@ class OpenAIServingTranscription(OpenAISpeechToText): request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, log_error_stack: bool = False, + enable_force_include_usage: bool = False, ): super().__init__( engine_client=engine_client, @@ -45,6 +46,7 @@ class OpenAIServingTranscription(OpenAISpeechToText): return_tokens_as_token_ids=return_tokens_as_token_ids, task_type="transcribe", log_error_stack=log_error_stack, + enable_force_include_usage=enable_force_include_usage, ) async def create_transcription( @@ -96,6 +98,7 @@ class OpenAIServingTranslation(OpenAISpeechToText): request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, log_error_stack: bool = False, + enable_force_include_usage: bool = False, ): super().__init__( engine_client=engine_client, @@ -104,6 +107,7 @@ class OpenAIServingTranslation(OpenAISpeechToText): return_tokens_as_token_ids=return_tokens_as_token_ids, task_type="translate", log_error_stack=log_error_stack, + enable_force_include_usage=enable_force_include_usage, ) async def create_translation( diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index fa6e962a1dd7..e012f43260c2 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -58,6 +58,7 @@ class OpenAISpeechToText(OpenAIServing): return_tokens_as_token_ids: bool = False, task_type: Literal["transcribe", "translate"] 
= "transcribe", log_error_stack: bool = False, + enable_force_include_usage: bool = False, ): super().__init__( engine_client=engine_client, @@ -74,6 +75,8 @@ class OpenAISpeechToText(OpenAIServing): self.model_config, task_type ) + self.enable_force_include_usage = enable_force_include_usage + self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB if self.default_sampling_params: @@ -261,9 +264,7 @@ class OpenAISpeechToText(OpenAIServing): completion_tokens = 0 num_prompt_tokens = 0 - include_usage = ( - request.stream_include_usage if request.stream_include_usage else False - ) + include_usage = self.enable_force_include_usage or request.stream_include_usage include_continuous_usage = ( request.stream_continuous_usage_stats if include_usage and request.stream_continuous_usage_stats diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 1504705cf0e2..c006a76d3cdf 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -14,7 +14,11 @@ from starlette.background import BackgroundTask, BackgroundTasks from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + CompletionRequest, + StreamOptions, +) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser @@ -237,3 +241,16 @@ def log_non_default_args(args: Namespace | EngineArgs): ) logger.info("non-default args: %s", non_default_args) + + +def should_include_usage( + stream_options: StreamOptions | None, enable_force_include_usage: bool +) -> tuple[bool, bool]: + if stream_options: + include_usage = stream_options.include_usage or enable_force_include_usage + include_continuous_usage = include_usage and bool( + stream_options.continuous_usage_stats + ) + else: + include_usage, include_continuous_usage = enable_force_include_usage, False + return include_usage, include_continuous_usage