From a4113b035cd4ee90ec02de0658b75d47f0007ede Mon Sep 17 00:00:00 2001 From: Gabriel Marinho <104592062+gmarinho2@users.noreply.github.com> Date: Thu, 3 Jul 2025 23:50:17 -0300 Subject: [PATCH] [Platform] Add custom default max tokens (#18557) Signed-off-by: Gabriel Marinho --- vllm/entrypoints/openai/protocol.py | 59 ++++--------------- vllm/entrypoints/openai/serving_chat.py | 18 ++++-- vllm/entrypoints/openai/serving_completion.py | 16 +++-- vllm/entrypoints/utils.py | 22 ++++++- vllm/platforms/interface.py | 4 ++ 5 files changed, 59 insertions(+), 60 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3df11db33384b..93d9c588d8d28 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel): logit_bias: Optional[dict[str, float]] = None logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 - # TODO(#9845): remove max_tokens when field is removed from OpenAI API max_tokens: Optional[int] = Field( default=None, deprecated= @@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel): } def to_beam_search_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None - ) -> BeamSearchParams: - # TODO(#9845): remove max_tokens when field is removed from OpenAI API - max_tokens = self.max_completion_tokens or self.max_tokens + self, max_tokens: int, + default_sampling_params: dict) -> BeamSearchParams: - if default_sampling_params is None: - default_sampling_params = {} n = self.n if self.n is not None else 1 - - # Use minimum of context window, user request & server limit. - max_tokens = min( - val for val in (default_max_tokens, max_tokens, - default_sampling_params.get("max_tokens", None)) - if val is not None) - if (temperature := self.temperature) is None: temperature = default_sampling_params.get( "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) @@ -465,21 +451,10 @@ class ChatCompletionRequest(OpenAIBaseModel): def to_sampling_params( self, - default_max_tokens: int, + max_tokens: int, logits_processor_pattern: Optional[str], - default_sampling_params: Optional[dict] = None, + default_sampling_params: dict, ) -> SamplingParams: - # TODO(#9845): remove max_tokens when field is removed from OpenAI API - max_tokens = self.max_completion_tokens or self.max_tokens - - if default_sampling_params is None: - default_sampling_params = {} - - # Use minimum of context window, user request & server limit. - max_tokens = min( - val for val in (default_max_tokens, max_tokens, - default_sampling_params.get("max_tokens", None)) - if val is not None) # Default parameters if (repetition_penalty := self.repetition_penalty) is None: @@ -898,22 +873,15 @@ class CompletionRequest(OpenAIBaseModel): } def to_beam_search_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None + self, + max_tokens: int, + default_sampling_params: Optional[dict] = None, ) -> BeamSearchParams: - max_tokens = self.max_tokens if default_sampling_params is None: default_sampling_params = {} n = self.n if self.n is not None else 1 - # Use minimum of context window, user request & server limit. 
- max_tokens = min( - val for val in (default_max_tokens, max_tokens, - default_sampling_params.get("max_tokens", None)) - if val is not None) - if (temperature := self.temperature) is None: temperature = default_sampling_params.get("temperature", 1.0) @@ -928,21 +896,14 @@ class CompletionRequest(OpenAIBaseModel): def to_sampling_params( self, - default_max_tokens: int, + max_tokens: int, logits_processor_pattern: Optional[str], default_sampling_params: Optional[dict] = None, ) -> SamplingParams: - max_tokens = self.max_tokens if default_sampling_params is None: default_sampling_params = {} - # Use minimum of context window, user request & server limit. - max_tokens = min( - val for val in (default_max_tokens, max_tokens, - default_sampling_params.get("max_tokens", None)) - if val is not None) - # Default parameters if (repetition_penalty := self.repetition_penalty) is None: repetition_penalty = default_sampling_params.get( @@ -1813,7 +1774,7 @@ class TranscriptionRequest(OpenAIBaseModel): self, default_max_tokens: int, default_sampling_params: Optional[dict] = None) -> SamplingParams: - # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = default_max_tokens if default_sampling_params is None: @@ -2029,7 +1990,7 @@ class TranslationRequest(OpenAIBaseModel): self, default_max_tokens: int, default_sampling_params: Optional[dict] = None) -> SamplingParams: - # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = default_max_tokens if default_sampling_params is None: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 299ade4e4d7d6..a802fbc3865f9 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -34,6 +34,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( MistralToolCall) +from vllm.entrypoints.utils import get_max_tokens from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -233,15 +234,22 @@ class OpenAIServingChat(OpenAIServing): try: for i, engine_prompt in enumerate(engine_prompts): sampling_params: Union[SamplingParams, BeamSearchParams] - default_max_tokens = self.max_model_len - len( - engine_prompt["prompt_token_ids"]) + + if self.default_sampling_params is None: + self.default_sampling_params = {} + + max_tokens = get_max_tokens( + max_model_len=self.max_model_len, + request=request, + input_length=len(engine_prompt["prompt_token_ids"]), + default_sampling_params=self.default_sampling_params) + if request.use_beam_search: sampling_params = request.to_beam_search_params( - default_max_tokens, self.default_sampling_params) + max_tokens, self.default_sampling_params) else: sampling_params = request.to_sampling_params( - default_max_tokens, - self.model_config.logits_processor_pattern, + max_tokens, self.model_config.logits_processor_pattern, self.default_sampling_params) self._log_inputs(request_id, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 8171b491aafcc..6c9c29b714457 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -33,6 +33,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing, 
is_text_tokens_prompt) # yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.utils import get_max_tokens from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, is_tokens_prompt) from vllm.logger import init_logger @@ -160,15 +161,22 @@ class OpenAIServingCompletion(OpenAIServing): input_length = len(engine_prompt["prompt_token_ids"]) else: assert_never(engine_prompt) - default_max_tokens = self.max_model_len - input_length + + if self.default_sampling_params is None: + self.default_sampling_params = {} + + max_tokens = get_max_tokens( + max_model_len=self.max_model_len, + request=request, + input_length=input_length, + default_sampling_params=self.default_sampling_params) if request.use_beam_search: sampling_params = request.to_beam_search_params( - default_max_tokens, self.default_sampling_params) + max_tokens, self.default_sampling_params) else: sampling_params = request.to_sampling_params( - default_max_tokens, - self.model_config.logits_processor_pattern, + max_tokens, self.model_config.logits_processor_pattern, self.default_sampling_params) request_id_item = f"{request_id}-{i}" diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 5b085e5b79478..423b99dbe565c 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -5,13 +5,17 @@ import argparse import asyncio import functools import os -from typing import Any, Optional +import sys +from typing import Any, Optional, Union from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse from starlette.background import BackgroundTask, BackgroundTasks +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -181,7 +185,6 @@ def _validate_truncation_size( def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, subcommand_name: list[str]): - import sys # Only handle --help= for the current subcommand. 
# Since subparser_init() runs for all subcommands during CLI setup, @@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, print(f"\nNo group or parameter matching '{search_keyword}'") print("Tip: use `--help=listgroup` to view all groups.") sys.exit(1) + + +def get_max_tokens(max_model_len: int, request: Union[ChatCompletionRequest, + CompletionRequest], + input_length: int, default_sampling_params: dict) -> int: + + max_tokens = getattr(request, "max_completion_tokens", + None) or request.max_tokens + default_max_tokens = max_model_len - input_length + max_output_tokens = current_platform.get_max_output_tokens(input_length) + + return min(val + for val in (default_max_tokens, max_tokens, max_output_tokens, + default_sampling_params.get("max_tokens")) + if val is not None) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 0f08bf986333b..567d5cbf503fe 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -4,6 +4,7 @@ import enum import os import platform import random +import sys from datetime import timedelta from platform import uname from typing import TYPE_CHECKING, NamedTuple, Optional, Union @@ -164,6 +165,9 @@ class Platform: def is_out_of_tree(self) -> bool: return self._enum == PlatformEnum.OOT + def get_max_output_tokens(self, prompt_len: int) -> int: + return sys.maxsize + def is_cuda_alike(self) -> bool: """Stateless version of [torch.cuda.is_available][].""" return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
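
Note (not part of the patch): the new get_max_tokens() helper in vllm/entrypoints/utils.py replaces the per-request default_max_tokens arithmetic that previously lived inside to_sampling_params() / to_beam_search_params(). It takes the minimum over every limit that is actually set: the remaining context window, the user's max_completion_tokens / max_tokens, the platform cap from get_max_output_tokens(), and the server-side max_tokens default. A minimal, self-contained sketch of that clamping follows; the function and variable names are illustrative only, not vLLM API.

    import sys
    from typing import Optional

    def clamp_max_tokens(max_model_len: int,
                         input_length: int,
                         requested: Optional[int],
                         server_default: Optional[int],
                         platform_cap: int) -> int:
        # Tokens still available in the context window after the prompt.
        context_budget = max_model_len - input_length
        # Take the smallest of the limits that are actually set
        # (None means "unset" and is skipped).
        return min(v for v in (context_budget, requested,
                               platform_cap, server_default)
                   if v is not None)

    # 4096-token model, 1000-token prompt, no per-request limit, server
    # default of 512, platform imposes no cap (sys.maxsize): result is 512.
    assert clamp_max_tokens(4096, 1000, None, 512, sys.maxsize) == 512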
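
Note (not part of the patch): Platform.get_max_output_tokens() defaults to sys.maxsize, so behaviour is unchanged unless a platform overrides it; the value it returns is fed into the min() above via current_platform.get_max_output_tokens(input_length). A hypothetical out-of-tree platform could cap generation length per prompt roughly like this (the class name and the 16384-token budget are made up for illustration, and plugin registration is omitted):

    from vllm.platforms.interface import Platform, PlatformEnum

    class MyAcceleratorPlatform(Platform):
        _enum = PlatformEnum.OOT

        def get_max_output_tokens(self, prompt_len: int) -> int:
            # Keep prompt + generated tokens within a fixed 16384-token
            # budget on this device, never returning less than 1.
            return max(1, 16384 - prompt_len)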