[Feat][CLI] enforce-include-usage (#19695)

Signed-off-by: Max Wittig <max.wittig@siemens.com>
Authored by Max Wittig on 2025-06-25 07:43:04 +02:00; committed by GitHub
parent 879f69bed3
commit f59fc60fb3
5 changed files with 34 additions and 9 deletions


@@ -1190,6 +1190,7 @@ async def init_app_state(
tool_parser=args.tool_call_parser,
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
@@ -1197,6 +1198,7 @@ async def init_app_state(
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,


@@ -272,6 +272,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action='store_true',
default=False,
help="If set to True, enable prompt_tokens_details in usage.")
+parser.add_argument(
+"--enable-force-include-usage",
+action='store_true',
+default=False,
help="If set to True, including usage on every request.")
parser.add_argument(
"--enable-server-load-tracking",
action='store_true',
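
Below is a minimal, standalone sketch of how the new flag behaves, using plain argparse instead of vLLM's FlexibleArgumentParser; the flag name, action, default, and help text come from the hunk above, everything else is illustrative.

# Standalone sketch of the --enable-force-include-usage flag added above.
# Plain argparse is used here for illustration; vLLM registers the flag on
# its FlexibleArgumentParser inside make_arg_parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-force-include-usage",
    action="store_true",
    default=False,
    help="If set to True, include usage on every request.")

args = parser.parse_args(["--enable-force-include-usage"])
assert args.enable_force_include_usage is True

When the OpenAI-compatible server is launched with this flag (for example via vllm serve <model> --enable-force-include-usage), the parsed value reaches init_app_state as args.enable_force_include_usage and is forwarded to the chat and completion handlers, as shown in the first file above.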


@@ -64,12 +64,14 @@ class OpenAIServingChat(OpenAIServing):
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False,
+enable_force_include_usage: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
-return_tokens_as_token_ids=return_tokens_as_token_ids)
+return_tokens_as_token_ids=return_tokens_as_token_ids,
+enable_force_include_usage=enable_force_include_usage)
self.response_role = response_role
self.chat_template = chat_template
@@ -110,6 +112,7 @@ class OpenAIServingChat(OpenAIServing):
"been registered") from e
self.enable_prompt_tokens_details = enable_prompt_tokens_details
+self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
@@ -261,8 +264,14 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
-request, result_generator, request_id, model_name,
-conversation, tokenizer, request_metadata)
+request,
+result_generator,
+request_id,
+model_name,
+conversation,
+tokenizer,
+request_metadata,
+enable_force_include_usage=self.enable_force_include_usage)
try:
return await self.chat_completion_full_generator(
@@ -405,6 +414,7 @@ class OpenAIServingChat(OpenAIServing):
conversation: list[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
+enable_force_include_usage: bool,
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
@@ -471,7 +481,8 @@ class OpenAIServingChat(OpenAIServing):
stream_options = request.stream_options
if stream_options:
-include_usage = stream_options.include_usage
+include_usage = stream_options.include_usage \
+or enable_force_include_usage
include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats
else:
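
To make the streaming change concrete, here is a hedged sketch of the include_usage resolution introduced above; resolve_stream_usage is a hypothetical helper name, and the behaviour when stream_options is absent is not visible in this hunk, so it is marked as an assumption in the comments.

# Hypothetical helper distilling the logic added to
# chat_completion_stream_generator (and mirrored in the completion path).
def resolve_stream_usage(stream_options, enable_force_include_usage: bool):
    if stream_options:
        # Usage is emitted if the client asked for it OR the server forces it.
        include_usage = (stream_options.include_usage
                         or enable_force_include_usage)
        include_continuous_usage = (include_usage
                                    and stream_options.continuous_usage_stats)
    else:
        # Assumption: the else branch (not shown in the hunk) keeps the
        # pre-existing default of emitting no usage.
        include_usage = False
        include_continuous_usage = False
    return include_usage, include_continuous_usage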


@@ -52,12 +52,14 @@ class OpenAIServingCompletion(OpenAIServing):
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
+enable_force_include_usage: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
-return_tokens_as_token_ids=return_tokens_as_token_ids)
+return_tokens_as_token_ids=return_tokens_as_token_ids,
+enable_force_include_usage=enable_force_include_usage)
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
@@ -227,7 +229,8 @@ class OpenAIServingCompletion(OpenAIServing):
model_name,
num_prompts=num_prompts,
tokenizer=tokenizer,
-request_metadata=request_metadata)
+request_metadata=request_metadata,
+enable_force_include_usage=self.enable_force_include_usage)
# Non-streaming response
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
@@ -289,6 +292,7 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts: int,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
+enable_force_include_usage: bool,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
@@ -298,7 +302,8 @@ class OpenAIServingCompletion(OpenAIServing):
stream_options = request.stream_options
if stream_options:
-include_usage = stream_options.include_usage
+include_usage = stream_options.include_usage or \
+enable_force_include_usage
include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats
else:
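
For a client-side view, here is a hedged sketch using the OpenAI Python SDK; it assumes a vLLM OpenAI-compatible server started with --enable-force-include-usage, and the base URL, API key, and model name are placeholders. Per the or-condition in the two hunks above, a client that sends stream_options with include_usage set to False still has usage resolved to True on the server.

# Client-side sketch (OpenAI Python SDK). Server assumed to be launched
# with --enable-force-include-usage; endpoint and model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
    # The client does not ask for usage, but the server-side or-condition
    # (stream_options.include_usage or enable_force_include_usage) still
    # turns usage reporting on.
    stream_options={"include_usage": False},
)

for chunk in stream:
    # Most chunks carry usage=None; any chunk that does carry usage is
    # printed here.
    if chunk.usage is not None:
        print(chunk.usage)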


@@ -132,7 +132,7 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
class RequestProcessingMixin(BaseModel):
"""
-Mixin for request processing,
+Mixin for request processing,
handling prompt preparation and engine input.
"""
request_prompts: Optional[Sequence[RequestPrompt]] = []
@@ -144,7 +144,7 @@ class RequestProcessingMixin(BaseModel):
class ResponseGenerationMixin(BaseModel):
"""
-Mixin for response generation,
+Mixin for response generation,
managing result generators and final batch results.
"""
result_generator: Optional[AsyncGenerator[tuple[int, Union[
@@ -208,6 +208,7 @@ class OpenAIServing:
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
+enable_force_include_usage: bool = False,
):
super().__init__()
@@ -219,6 +220,7 @@ class OpenAIServing:
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
+self.enable_force_include_usage = enable_force_include_usage
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)