[Platform] Add custom default max tokens (#18557)

Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Authored by Gabriel Marinho on 2025-07-03 23:50:17 -03:00, committed by GitHub
parent 7e1665b089
commit a4113b035c
5 changed files with 59 additions and 60 deletions

View File

@@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
     logit_bias: Optional[dict[str, float]] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0
-    # TODO(#9845): remove max_tokens when field is removed from OpenAI API
     max_tokens: Optional[int] = Field(
         default=None,
         deprecated=
@@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
     }
     def to_beam_search_params(
-        self,
-        default_max_tokens: int,
-        default_sampling_params: Optional[dict] = None
-    ) -> BeamSearchParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
-        max_tokens = self.max_completion_tokens or self.max_tokens
+            self, max_tokens: int,
+            default_sampling_params: dict) -> BeamSearchParams:
-        if default_sampling_params is None:
-            default_sampling_params = {}
         n = self.n if self.n is not None else 1
-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get(
                 "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
@@ -465,21 +451,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
     def to_sampling_params(
         self,
-        default_max_tokens: int,
+        max_tokens: int,
         logits_processor_pattern: Optional[str],
-        default_sampling_params: Optional[dict] = None,
+        default_sampling_params: dict,
     ) -> SamplingParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
-        max_tokens = self.max_completion_tokens or self.max_tokens
-        if default_sampling_params is None:
-            default_sampling_params = {}
-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)
         # Default parameters
         if (repetition_penalty := self.repetition_penalty) is None:
@@ -898,22 +873,15 @@ class CompletionRequest(OpenAIBaseModel):
     }
     def to_beam_search_params(
-        self,
-        default_max_tokens: int,
-        default_sampling_params: Optional[dict] = None
+        self,
+        max_tokens: int,
+        default_sampling_params: Optional[dict] = None,
     ) -> BeamSearchParams:
-        max_tokens = self.max_tokens
         if default_sampling_params is None:
             default_sampling_params = {}
         n = self.n if self.n is not None else 1
-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get("temperature", 1.0)
@@ -928,21 +896,14 @@ class CompletionRequest(OpenAIBaseModel):
     def to_sampling_params(
         self,
-        default_max_tokens: int,
+        max_tokens: int,
         logits_processor_pattern: Optional[str],
         default_sampling_params: Optional[dict] = None,
     ) -> SamplingParams:
-        max_tokens = self.max_tokens
         if default_sampling_params is None:
             default_sampling_params = {}
-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)
         # Default parameters
         if (repetition_penalty := self.repetition_penalty) is None:
             repetition_penalty = default_sampling_params.get(
@@ -1813,7 +1774,7 @@ class TranscriptionRequest(OpenAIBaseModel):
             self,
             default_max_tokens: int,
             default_sampling_params: Optional[dict] = None) -> SamplingParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = default_max_tokens
         if default_sampling_params is None:
@@ -2029,7 +1990,7 @@ class TranslationRequest(OpenAIBaseModel):
             self,
             default_max_tokens: int,
             default_sampling_params: Optional[dict] = None) -> SamplingParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = default_max_tokens
         if default_sampling_params is None:

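Note on the API change above: the request objects no longer clamp the token budget themselves; the caller resolves max_tokens first and always passes a concrete defaults dict. A minimal sketch of the new calling convention, assuming a vLLM checkout with this commit applied (the model name, message, and the literal 256 are illustrative):

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    model="dummy-model",
    messages=[{"role": "user", "content": "Hello"}],
)
# max_tokens is now resolved by the caller (see get_max_tokens further below);
# default_sampling_params must be a real dict, not None.
sampling_params = request.to_sampling_params(
    max_tokens=256,
    logits_processor_pattern=None,
    default_sampling_params={},
)
print(sampling_params.max_tokens)  # 256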
View File

@@ -34,6 +34,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
     MistralToolCall)
+from vllm.entrypoints.utils import get_max_tokens
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
@@ -233,15 +234,22 @@ class OpenAIServingChat(OpenAIServing):
         try:
             for i, engine_prompt in enumerate(engine_prompts):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
-                default_max_tokens = self.max_model_len - len(
-                    engine_prompt["prompt_token_ids"])
+                if self.default_sampling_params is None:
+                    self.default_sampling_params = {}
+                max_tokens = get_max_tokens(
+                    max_model_len=self.max_model_len,
+                    request=request,
+                    input_length=len(engine_prompt["prompt_token_ids"]),
+                    default_sampling_params=self.default_sampling_params)
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, self.default_sampling_params)
+                        max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens,
-                        self.model_config.logits_processor_pattern,
+                        max_tokens, self.model_config.logits_processor_pattern,
                         self.default_sampling_params)
                 self._log_inputs(request_id,

View File

@@ -33,6 +33,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                      is_text_tokens_prompt)
 # yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.utils import get_max_tokens
 from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
                               is_tokens_prompt)
 from vllm.logger import init_logger
@@ -160,15 +161,22 @@ class OpenAIServingCompletion(OpenAIServing):
                     input_length = len(engine_prompt["prompt_token_ids"])
                 else:
                     assert_never(engine_prompt)
-                default_max_tokens = self.max_model_len - input_length
+                if self.default_sampling_params is None:
+                    self.default_sampling_params = {}
+                max_tokens = get_max_tokens(
+                    max_model_len=self.max_model_len,
+                    request=request,
+                    input_length=input_length,
+                    default_sampling_params=self.default_sampling_params)
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, self.default_sampling_params)
+                        max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens,
-                        self.model_config.logits_processor_pattern,
+                        max_tokens, self.model_config.logits_processor_pattern,
                         self.default_sampling_params)
                 request_id_item = f"{request_id}-{i}"

View File

@@ -5,13 +5,17 @@ import argparse
 import asyncio
 import functools
 import os
-from typing import Any, Optional
+import sys
+from typing import Any, Optional, Union
 from fastapi import Request
 from fastapi.responses import JSONResponse, StreamingResponse
 from starlette.background import BackgroundTask, BackgroundTasks
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              CompletionRequest)
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 logger = init_logger(__name__)
@@ -181,7 +185,6 @@ def _validate_truncation_size(
 def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
                                               subcommand_name: list[str]):
-    import sys
     # Only handle --help=<keyword> for the current subcommand.
     # Since subparser_init() runs for all subcommands during CLI setup,
@@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
             print(f"\nNo group or parameter matching '{search_keyword}'")
             print("Tip: use `--help=listgroup` to view all groups.")
             sys.exit(1)
+def get_max_tokens(max_model_len: int, request: Union[ChatCompletionRequest,
+                                                       CompletionRequest],
+                   input_length: int, default_sampling_params: dict) -> int:
+    max_tokens = getattr(request, "max_completion_tokens",
+                         None) or request.max_tokens
+    default_max_tokens = max_model_len - input_length
+    max_output_tokens = current_platform.get_max_output_tokens(input_length)
+    return min(val
+               for val in (default_max_tokens, max_tokens, max_output_tokens,
+                           default_sampling_params.get("max_tokens"))
+               if val is not None)

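In short, get_max_tokens picks the smallest applicable limit: the remaining context window, the user-requested maximum (max_completion_tokens or max_tokens), the platform's output cap, and any server-side default. A standalone sketch of that resolution order with made-up numbers (server_default stands in for a max_tokens entry in default_sampling_params, e.g. from the model's generation config):

import sys

max_model_len = 8192        # model context window
input_length = 1000         # prompt tokens already consumed
requested_max = None        # request carried no max_tokens / max_completion_tokens
platform_cap = sys.maxsize  # default Platform.get_max_output_tokens()
server_default = 512        # illustrative default_sampling_params["max_tokens"]

candidates = (max_model_len - input_length, requested_max,
              platform_cap, server_default)
effective_max_tokens = min(v for v in candidates if v is not None)
print(effective_max_tokens)  # 512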
View File

@@ -4,6 +4,7 @@ import enum
 import os
 import platform
 import random
+import sys
 from datetime import timedelta
 from platform import uname
 from typing import TYPE_CHECKING, NamedTuple, Optional, Union
@@ -164,6 +165,9 @@ class Platform:
     def is_out_of_tree(self) -> bool:
         return self._enum == PlatformEnum.OOT
+    def get_max_output_tokens(self, prompt_len: int) -> int:
+        return sys.maxsize
     def is_cuda_alike(self) -> bool:
         """Stateless version of [torch.cuda.is_available][]."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
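The base class returns sys.maxsize, so in-tree platforms keep their current behavior; the hook only takes effect for platforms that override it. A rough sketch of how an out-of-tree platform plugin might use the new hook (the class name and the 4096-token limit are invented for illustration; only get_max_output_tokens and PlatformEnum.OOT come from the code above):

from vllm.platforms.interface import Platform, PlatformEnum


class MyAcceleratorPlatform(Platform):
    """Hypothetical out-of-tree platform with a hardware output limit."""
    _enum = PlatformEnum.OOT

    def get_max_output_tokens(self, prompt_len: int) -> int:
        # Cap generation at an assumed 4096 tokens, shrinking as the prompt grows.
        return max(1, 4096 - prompt_len)

Once such a platform is registered as current_platform, its cap flows into get_max_tokens() and from there into every chat and completion request.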