[Platform] Add custom default max tokens (#18557)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
parent 7e1665b089
commit a4113b035c
@@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
     logit_bias: Optional[dict[str, float]] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0
-    # TODO(#9845): remove max_tokens when field is removed from OpenAI API
     max_tokens: Optional[int] = Field(
         default=None,
         deprecated=
@@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
         }
 
     def to_beam_search_params(
-            self,
-            default_max_tokens: int,
-            default_sampling_params: Optional[dict] = None
-    ) -> BeamSearchParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
-        max_tokens = self.max_completion_tokens or self.max_tokens
+            self, max_tokens: int,
+            default_sampling_params: dict) -> BeamSearchParams:
-
-        if default_sampling_params is None:
-            default_sampling_params = {}
         n = self.n if self.n is not None else 1
 
-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)
-
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get(
                 "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
@@ -465,21 +451,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     def to_sampling_params(
         self,
-        default_max_tokens: int,
+        max_tokens: int,
         logits_processor_pattern: Optional[str],
-        default_sampling_params: Optional[dict] = None,
+        default_sampling_params: dict,
     ) -> SamplingParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
-        max_tokens = self.max_completion_tokens or self.max_tokens
-
-        if default_sampling_params is None:
-            default_sampling_params = {}
-
-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)
 
         # Default parameters
         if (repetition_penalty := self.repetition_penalty) is None:
@@ -898,22 +873,15 @@ class CompletionRequest(OpenAIBaseModel):
         }
 
     def to_beam_search_params(
-            self,
-            default_max_tokens: int,
-            default_sampling_params: Optional[dict] = None
+        self,
+        max_tokens: int,
+        default_sampling_params: Optional[dict] = None,
     ) -> BeamSearchParams:
-        max_tokens = self.max_tokens
 
         if default_sampling_params is None:
             default_sampling_params = {}
         n = self.n if self.n is not None else 1
 
-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)
-
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get("temperature", 1.0)
 
@@ -928,21 +896,14 @@ class CompletionRequest(OpenAIBaseModel):
 
     def to_sampling_params(
         self,
-        default_max_tokens: int,
+        max_tokens: int,
         logits_processor_pattern: Optional[str],
         default_sampling_params: Optional[dict] = None,
     ) -> SamplingParams:
-        max_tokens = self.max_tokens
 
         if default_sampling_params is None:
             default_sampling_params = {}
 
-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)
-
         # Default parameters
         if (repetition_penalty := self.repetition_penalty) is None:
             repetition_penalty = default_sampling_params.get(
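The four method hunks above all remove the same inline clamping logic from ChatCompletionRequest and CompletionRequest: the effective max_tokens was the smallest of the candidate limits that were actually set. A minimal, self-contained sketch of that rule, with purely hypothetical numbers:

```python
# The rule the deleted blocks implemented: pick the smallest limit that is
# actually set, skipping candidates that are None. All values are hypothetical.
candidates = (2596,   # context window minus prompt length
              256,    # max_tokens requested by the user
              None)   # server-side default max_tokens, unset here
print(min(v for v in candidates if v is not None))  # 256
```

With this commit that rule is applied once in the serving layer, through the get_max_tokens helper added further down, and these methods simply receive the already-resolved max_tokens.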
@@ -1813,7 +1774,7 @@ class TranscriptionRequest(OpenAIBaseModel):
             self,
             default_max_tokens: int,
             default_sampling_params: Optional[dict] = None) -> SamplingParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+
         max_tokens = default_max_tokens
 
         if default_sampling_params is None:
@@ -2029,7 +1990,7 @@ class TranslationRequest(OpenAIBaseModel):
             self,
             default_max_tokens: int,
             default_sampling_params: Optional[dict] = None) -> SamplingParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+
         max_tokens = default_max_tokens
 
         if default_sampling_params is None:
@@ -34,6 +34,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
     MistralToolCall)
+from vllm.entrypoints.utils import get_max_tokens
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
@@ -233,15 +234,22 @@ class OpenAIServingChat(OpenAIServing):
         try:
             for i, engine_prompt in enumerate(engine_prompts):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
-                default_max_tokens = self.max_model_len - len(
-                    engine_prompt["prompt_token_ids"])
+
+                if self.default_sampling_params is None:
+                    self.default_sampling_params = {}
+
+                max_tokens = get_max_tokens(
+                    max_model_len=self.max_model_len,
+                    request=request,
+                    input_length=len(engine_prompt["prompt_token_ids"]),
+                    default_sampling_params=self.default_sampling_params)
+
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, self.default_sampling_params)
+                        max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens,
-                        self.model_config.logits_processor_pattern,
+                        max_tokens, self.model_config.logits_processor_pattern,
                         self.default_sampling_params)
 
                 self._log_inputs(request_id,
@@ -33,6 +33,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                      is_text_tokens_prompt)
 # yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.utils import get_max_tokens
 from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
                               is_tokens_prompt)
 from vllm.logger import init_logger
@@ -160,15 +161,22 @@ class OpenAIServingCompletion(OpenAIServing):
                     input_length = len(engine_prompt["prompt_token_ids"])
                 else:
                     assert_never(engine_prompt)
-                default_max_tokens = self.max_model_len - input_length
 
+                if self.default_sampling_params is None:
+                    self.default_sampling_params = {}
+
+                max_tokens = get_max_tokens(
+                    max_model_len=self.max_model_len,
+                    request=request,
+                    input_length=input_length,
+                    default_sampling_params=self.default_sampling_params)
+
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, self.default_sampling_params)
+                        max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens,
-                        self.model_config.logits_processor_pattern,
+                        max_tokens, self.model_config.logits_processor_pattern,
                         self.default_sampling_params)
 
                 request_id_item = f"{request_id}-{i}"
@@ -5,13 +5,17 @@ import argparse
 import asyncio
 import functools
 import os
-from typing import Any, Optional
+import sys
+from typing import Any, Optional, Union
 
 from fastapi import Request
 from fastapi.responses import JSONResponse, StreamingResponse
 from starlette.background import BackgroundTask, BackgroundTasks
 
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              CompletionRequest)
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
@@ -181,7 +185,6 @@ def _validate_truncation_size(
 
 def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
                                               subcommand_name: list[str]):
-    import sys
 
     # Only handle --help=<keyword> for the current subcommand.
     # Since subparser_init() runs for all subcommands during CLI setup,
@@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
             print(f"\nNo group or parameter matching '{search_keyword}'")
             print("Tip: use `--help=listgroup` to view all groups.")
         sys.exit(1)
+
+
+def get_max_tokens(max_model_len: int, request: Union[ChatCompletionRequest,
+                                                      CompletionRequest],
+                   input_length: int, default_sampling_params: dict) -> int:
+
+    max_tokens = getattr(request, "max_completion_tokens",
+                         None) or request.max_tokens
+    default_max_tokens = max_model_len - input_length
+    max_output_tokens = current_platform.get_max_output_tokens(input_length)
+
+    return min(val
+               for val in (default_max_tokens, max_tokens, max_output_tokens,
+                           default_sampling_params.get("max_tokens"))
+               if val is not None)
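The new helper applies the same rule as before but folds in a fourth candidate, the platform's output cap, and prefers max_completion_tokens while falling back to the deprecated max_tokens field (the getattr also covers CompletionRequest, which has no max_completion_tokens attribute). A standalone sketch of the selection it performs; the request object and every number here are hypothetical stand-ins, not real vLLM types:

```python
import sys
from types import SimpleNamespace

# Duck-typed stand-in for a parsed request; only the two attributes that
# get_max_tokens reads are modelled. All values are hypothetical.
request = SimpleNamespace(max_completion_tokens=None, max_tokens=512)

# Prefer max_completion_tokens, fall back to the deprecated max_tokens field.
requested = getattr(request, "max_completion_tokens",
                    None) or request.max_tokens

default_max_tokens = 4096 - 1000   # context window minus prompt length
max_output_tokens = sys.maxsize    # base Platform default: no extra cap
default_sampling_params = {}       # no server-wide max_tokens default

print(min(val
          for val in (default_max_tokens, requested, max_output_tokens,
                      default_sampling_params.get("max_tokens"))
          if val is not None))     # 512
```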
@@ -4,6 +4,7 @@ import enum
 import os
 import platform
 import random
+import sys
 from datetime import timedelta
 from platform import uname
 from typing import TYPE_CHECKING, NamedTuple, Optional, Union
@@ -164,6 +165,9 @@ class Platform:
     def is_out_of_tree(self) -> bool:
         return self._enum == PlatformEnum.OOT
 
+    def get_max_output_tokens(self, prompt_len: int) -> int:
+        return sys.maxsize
+
     def is_cuda_alike(self) -> bool:
         """Stateless version of [torch.cuda.is_available][]."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
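This hook is what the commit title refers to: a platform can now impose its own default cap on generated tokens, and get_max_tokens folds that cap into the min(), so the sys.maxsize base default never changes behavior on platforms that do not override it. A hedged sketch of how an out-of-tree platform might use the hook; the Platform class below is a minimal stand-in for the real vllm.platforms base class, and MyAcceleratorPlatform with its 8192-token limit is hypothetical:

```python
import sys


class Platform:
    """Minimal stand-in for the real base class, reduced to the new hook."""

    def get_max_output_tokens(self, prompt_len: int) -> int:
        # Default from the diff above: effectively no cap.
        return sys.maxsize


class MyAcceleratorPlatform(Platform):
    # Hypothetical device-wide limit on prompt plus generated tokens.
    MAX_SEQ_LEN = 8192

    def get_max_output_tokens(self, prompt_len: int) -> int:
        # Cap generation so the whole sequence fits the device limit,
        # leaving at least one token of headroom.
        return max(1, self.MAX_SEQ_LEN - prompt_len)


print(MyAcceleratorPlatform().get_max_output_tokens(6000))  # 2192
```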