[Platform] Add custom default max tokens (#18557)

Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
This commit is contained in:
Gabriel Marinho 2025-07-03 23:50:17 -03:00 committed by GitHub
parent 7e1665b089
commit a4113b035c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 59 additions and 60 deletions

View File

@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
logit_bias: Optional[dict[str, float]] = None logit_bias: Optional[dict[str, float]] = None
logprobs: Optional[bool] = False logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0 top_logprobs: Optional[int] = 0
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens: Optional[int] = Field( max_tokens: Optional[int] = Field(
default=None, default=None,
deprecated= deprecated=
@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
} }
def to_beam_search_params( def to_beam_search_params(
self, self, max_tokens: int,
default_max_tokens: int, default_sampling_params: dict) -> BeamSearchParams:
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1 n = self.n if self.n is not None else 1
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
if (temperature := self.temperature) is None: if (temperature := self.temperature) is None:
temperature = default_sampling_params.get( temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
@ -465,21 +451,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
def to_sampling_params( def to_sampling_params(
self, self,
default_max_tokens: int, max_tokens: int,
logits_processor_pattern: Optional[str], logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None, default_sampling_params: dict,
) -> SamplingParams: ) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
# Default parameters # Default parameters
if (repetition_penalty := self.repetition_penalty) is None: if (repetition_penalty := self.repetition_penalty) is None:
@ -898,22 +873,15 @@ class CompletionRequest(OpenAIBaseModel):
} }
def to_beam_search_params( def to_beam_search_params(
self, self,
default_max_tokens: int, max_tokens: int,
default_sampling_params: Optional[dict] = None default_sampling_params: Optional[dict] = None,
) -> BeamSearchParams: ) -> BeamSearchParams:
max_tokens = self.max_tokens
if default_sampling_params is None: if default_sampling_params is None:
default_sampling_params = {} default_sampling_params = {}
n = self.n if self.n is not None else 1 n = self.n if self.n is not None else 1
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
if (temperature := self.temperature) is None: if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0) temperature = default_sampling_params.get("temperature", 1.0)
@ -928,21 +896,14 @@ class CompletionRequest(OpenAIBaseModel):
def to_sampling_params( def to_sampling_params(
self, self,
default_max_tokens: int, max_tokens: int,
logits_processor_pattern: Optional[str], logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None, default_sampling_params: Optional[dict] = None,
) -> SamplingParams: ) -> SamplingParams:
max_tokens = self.max_tokens
if default_sampling_params is None: if default_sampling_params is None:
default_sampling_params = {} default_sampling_params = {}
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
# Default parameters # Default parameters
if (repetition_penalty := self.repetition_penalty) is None: if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get( repetition_penalty = default_sampling_params.get(
@ -1813,7 +1774,7 @@ class TranscriptionRequest(OpenAIBaseModel):
self, self,
default_max_tokens: int, default_max_tokens: int,
default_sampling_params: Optional[dict] = None) -> SamplingParams: default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = default_max_tokens max_tokens = default_max_tokens
if default_sampling_params is None: if default_sampling_params is None:
@ -2029,7 +1990,7 @@ class TranslationRequest(OpenAIBaseModel):
self, self,
default_max_tokens: int, default_max_tokens: int,
default_sampling_params: Optional[dict] = None) -> SamplingParams: default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = default_max_tokens max_tokens = default_max_tokens
if default_sampling_params is None: if default_sampling_params is None:

View File

@ -34,6 +34,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall) MistralToolCall)
from vllm.entrypoints.utils import get_max_tokens
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
@ -233,15 +234,22 @@ class OpenAIServingChat(OpenAIServing):
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams] sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]) if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=len(engine_prompt["prompt_token_ids"]),
default_sampling_params=self.default_sampling_params)
if request.use_beam_search: if request.use_beam_search:
sampling_params = request.to_beam_search_params( sampling_params = request.to_beam_search_params(
default_max_tokens, self.default_sampling_params) max_tokens, self.default_sampling_params)
else: else:
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens, max_tokens, self.model_config.logits_processor_pattern,
self.model_config.logits_processor_pattern,
self.default_sampling_params) self.default_sampling_params)
self._log_inputs(request_id, self._log_inputs(request_id,

View File

@ -33,6 +33,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
is_text_tokens_prompt) is_text_tokens_prompt)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt) is_tokens_prompt)
from vllm.logger import init_logger from vllm.logger import init_logger
@ -160,15 +161,22 @@ class OpenAIServingCompletion(OpenAIServing):
input_length = len(engine_prompt["prompt_token_ids"]) input_length = len(engine_prompt["prompt_token_ids"])
else: else:
assert_never(engine_prompt) assert_never(engine_prompt)
default_max_tokens = self.max_model_len - input_length
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
default_sampling_params=self.default_sampling_params)
if request.use_beam_search: if request.use_beam_search:
sampling_params = request.to_beam_search_params( sampling_params = request.to_beam_search_params(
default_max_tokens, self.default_sampling_params) max_tokens, self.default_sampling_params)
else: else:
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens, max_tokens, self.model_config.logits_processor_pattern,
self.model_config.logits_processor_pattern,
self.default_sampling_params) self.default_sampling_params)
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"

View File

@ -5,13 +5,17 @@ import argparse
import asyncio import asyncio
import functools import functools
import os import os
from typing import Any, Optional import sys
from typing import Any, Optional, Union
from fastapi import Request from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from starlette.background import BackgroundTask, BackgroundTasks from starlette.background import BackgroundTask, BackgroundTasks
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
@ -181,7 +185,6 @@ def _validate_truncation_size(
def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
subcommand_name: list[str]): subcommand_name: list[str]):
import sys
# Only handle --help=<keyword> for the current subcommand. # Only handle --help=<keyword> for the current subcommand.
# Since subparser_init() runs for all subcommands during CLI setup, # Since subparser_init() runs for all subcommands during CLI setup,
@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
print(f"\nNo group or parameter matching '{search_keyword}'") print(f"\nNo group or parameter matching '{search_keyword}'")
print("Tip: use `--help=listgroup` to view all groups.") print("Tip: use `--help=listgroup` to view all groups.")
sys.exit(1) sys.exit(1)
def get_max_tokens(max_model_len: int, request: Union[ChatCompletionRequest,
                                                      CompletionRequest],
                   input_length: int, default_sampling_params: dict) -> int:
    """Resolve the generation budget (in tokens) for a single request.

    The result is the smallest non-``None`` value among:

    * the room left in the context window
      (``max_model_len - input_length``),
    * the limit carried by the request itself,
    * the limit reported by
      ``current_platform.get_max_output_tokens(input_length)``,
    * ``default_sampling_params["max_tokens"]``, when that key is present.

    Args:
        max_model_len: Context length used to compute the remaining room.
        request: Request object; ``max_completion_tokens`` is read when the
            attribute exists, otherwise ``max_tokens`` is used.
        input_length: Number of prompt tokens already consumed.
        default_sampling_params: Mapping that may hold a ``"max_tokens"``
            entry.

    Returns:
        The minimum of the applicable limits.
    """
    # A falsy ``max_completion_tokens`` (missing, None or 0) defers to the
    # request's ``max_tokens`` field.
    requested = getattr(request, "max_completion_tokens", None)
    if not requested:
        requested = request.max_tokens

    candidates = (
        max_model_len - input_length,
        requested,
        current_platform.get_max_output_tokens(input_length),
        default_sampling_params.get("max_tokens"),
    )
    return min(limit for limit in candidates if limit is not None)

View File

@ -4,6 +4,7 @@ import enum
import os import os
import platform import platform
import random import random
import sys
from datetime import timedelta from datetime import timedelta
from platform import uname from platform import uname
from typing import TYPE_CHECKING, NamedTuple, Optional, Union from typing import TYPE_CHECKING, NamedTuple, Optional, Union
@ -164,6 +165,9 @@ class Platform:
def is_out_of_tree(self) -> bool: def is_out_of_tree(self) -> bool:
return self._enum == PlatformEnum.OOT return self._enum == PlatformEnum.OOT
def get_max_output_tokens(self, prompt_len: int) -> int:
    """Return the platform's cap on generated tokens for a prompt.

    This implementation ignores ``prompt_len`` and returns ``sys.maxsize``.
    """
    no_platform_limit = sys.maxsize
    return no_platform_limit
def is_cuda_alike(self) -> bool: def is_cuda_alike(self) -> bool:
"""Stateless version of [torch.cuda.is_available][].""" """Stateless version of [torch.cuda.is_available][]."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)