mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 18:44:28 +08:00
[Platform] Add custom default max tokens (#18557)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
This commit is contained in:
parent
7e1665b089
commit
a4113b035c
@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
logit_bias: Optional[dict[str, float]] = None
|
logit_bias: Optional[dict[str, float]] = None
|
||||||
logprobs: Optional[bool] = False
|
logprobs: Optional[bool] = False
|
||||||
top_logprobs: Optional[int] = 0
|
top_logprobs: Optional[int] = 0
|
||||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
|
||||||
max_tokens: Optional[int] = Field(
|
max_tokens: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
deprecated=
|
deprecated=
|
||||||
@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def to_beam_search_params(
|
def to_beam_search_params(
|
||||||
self,
|
self, max_tokens: int,
|
||||||
default_max_tokens: int,
|
default_sampling_params: dict) -> BeamSearchParams:
|
||||||
default_sampling_params: Optional[dict] = None
|
|
||||||
) -> BeamSearchParams:
|
|
||||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
|
||||||
max_tokens = self.max_completion_tokens or self.max_tokens
|
|
||||||
|
|
||||||
if default_sampling_params is None:
|
|
||||||
default_sampling_params = {}
|
|
||||||
n = self.n if self.n is not None else 1
|
n = self.n if self.n is not None else 1
|
||||||
|
|
||||||
# Use minimum of context window, user request & server limit.
|
|
||||||
max_tokens = min(
|
|
||||||
val for val in (default_max_tokens, max_tokens,
|
|
||||||
default_sampling_params.get("max_tokens", None))
|
|
||||||
if val is not None)
|
|
||||||
|
|
||||||
if (temperature := self.temperature) is None:
|
if (temperature := self.temperature) is None:
|
||||||
temperature = default_sampling_params.get(
|
temperature = default_sampling_params.get(
|
||||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
|
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||||
@ -465,21 +451,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
|
|
||||||
def to_sampling_params(
|
def to_sampling_params(
|
||||||
self,
|
self,
|
||||||
default_max_tokens: int,
|
max_tokens: int,
|
||||||
logits_processor_pattern: Optional[str],
|
logits_processor_pattern: Optional[str],
|
||||||
default_sampling_params: Optional[dict] = None,
|
default_sampling_params: dict,
|
||||||
) -> SamplingParams:
|
) -> SamplingParams:
|
||||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
|
||||||
max_tokens = self.max_completion_tokens or self.max_tokens
|
|
||||||
|
|
||||||
if default_sampling_params is None:
|
|
||||||
default_sampling_params = {}
|
|
||||||
|
|
||||||
# Use minimum of context window, user request & server limit.
|
|
||||||
max_tokens = min(
|
|
||||||
val for val in (default_max_tokens, max_tokens,
|
|
||||||
default_sampling_params.get("max_tokens", None))
|
|
||||||
if val is not None)
|
|
||||||
|
|
||||||
# Default parameters
|
# Default parameters
|
||||||
if (repetition_penalty := self.repetition_penalty) is None:
|
if (repetition_penalty := self.repetition_penalty) is None:
|
||||||
@ -898,22 +873,15 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def to_beam_search_params(
|
def to_beam_search_params(
|
||||||
self,
|
self,
|
||||||
default_max_tokens: int,
|
max_tokens: int,
|
||||||
default_sampling_params: Optional[dict] = None
|
default_sampling_params: Optional[dict] = None,
|
||||||
) -> BeamSearchParams:
|
) -> BeamSearchParams:
|
||||||
max_tokens = self.max_tokens
|
|
||||||
|
|
||||||
if default_sampling_params is None:
|
if default_sampling_params is None:
|
||||||
default_sampling_params = {}
|
default_sampling_params = {}
|
||||||
n = self.n if self.n is not None else 1
|
n = self.n if self.n is not None else 1
|
||||||
|
|
||||||
# Use minimum of context window, user request & server limit.
|
|
||||||
max_tokens = min(
|
|
||||||
val for val in (default_max_tokens, max_tokens,
|
|
||||||
default_sampling_params.get("max_tokens", None))
|
|
||||||
if val is not None)
|
|
||||||
|
|
||||||
if (temperature := self.temperature) is None:
|
if (temperature := self.temperature) is None:
|
||||||
temperature = default_sampling_params.get("temperature", 1.0)
|
temperature = default_sampling_params.get("temperature", 1.0)
|
||||||
|
|
||||||
@ -928,21 +896,14 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
|
|
||||||
def to_sampling_params(
|
def to_sampling_params(
|
||||||
self,
|
self,
|
||||||
default_max_tokens: int,
|
max_tokens: int,
|
||||||
logits_processor_pattern: Optional[str],
|
logits_processor_pattern: Optional[str],
|
||||||
default_sampling_params: Optional[dict] = None,
|
default_sampling_params: Optional[dict] = None,
|
||||||
) -> SamplingParams:
|
) -> SamplingParams:
|
||||||
max_tokens = self.max_tokens
|
|
||||||
|
|
||||||
if default_sampling_params is None:
|
if default_sampling_params is None:
|
||||||
default_sampling_params = {}
|
default_sampling_params = {}
|
||||||
|
|
||||||
# Use minimum of context window, user request & server limit.
|
|
||||||
max_tokens = min(
|
|
||||||
val for val in (default_max_tokens, max_tokens,
|
|
||||||
default_sampling_params.get("max_tokens", None))
|
|
||||||
if val is not None)
|
|
||||||
|
|
||||||
# Default parameters
|
# Default parameters
|
||||||
if (repetition_penalty := self.repetition_penalty) is None:
|
if (repetition_penalty := self.repetition_penalty) is None:
|
||||||
repetition_penalty = default_sampling_params.get(
|
repetition_penalty = default_sampling_params.get(
|
||||||
@ -1813,7 +1774,7 @@ class TranscriptionRequest(OpenAIBaseModel):
|
|||||||
self,
|
self,
|
||||||
default_max_tokens: int,
|
default_max_tokens: int,
|
||||||
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
||||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
|
||||||
max_tokens = default_max_tokens
|
max_tokens = default_max_tokens
|
||||||
|
|
||||||
if default_sampling_params is None:
|
if default_sampling_params is None:
|
||||||
@ -2029,7 +1990,7 @@ class TranslationRequest(OpenAIBaseModel):
|
|||||||
self,
|
self,
|
||||||
default_max_tokens: int,
|
default_max_tokens: int,
|
||||||
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
||||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
|
||||||
max_tokens = default_max_tokens
|
max_tokens = default_max_tokens
|
||||||
|
|
||||||
if default_sampling_params is None:
|
if default_sampling_params is None:
|
||||||
|
|||||||
@ -34,6 +34,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
|||||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||||
MistralToolCall)
|
MistralToolCall)
|
||||||
|
from vllm.entrypoints.utils import get_max_tokens
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.outputs import CompletionOutput, RequestOutput
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||||
@ -233,15 +234,22 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
try:
|
try:
|
||||||
for i, engine_prompt in enumerate(engine_prompts):
|
for i, engine_prompt in enumerate(engine_prompts):
|
||||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||||
default_max_tokens = self.max_model_len - len(
|
|
||||||
engine_prompt["prompt_token_ids"])
|
if self.default_sampling_params is None:
|
||||||
|
self.default_sampling_params = {}
|
||||||
|
|
||||||
|
max_tokens = get_max_tokens(
|
||||||
|
max_model_len=self.max_model_len,
|
||||||
|
request=request,
|
||||||
|
input_length=len(engine_prompt["prompt_token_ids"]),
|
||||||
|
default_sampling_params=self.default_sampling_params)
|
||||||
|
|
||||||
if request.use_beam_search:
|
if request.use_beam_search:
|
||||||
sampling_params = request.to_beam_search_params(
|
sampling_params = request.to_beam_search_params(
|
||||||
default_max_tokens, self.default_sampling_params)
|
max_tokens, self.default_sampling_params)
|
||||||
else:
|
else:
|
||||||
sampling_params = request.to_sampling_params(
|
sampling_params = request.to_sampling_params(
|
||||||
default_max_tokens,
|
max_tokens, self.model_config.logits_processor_pattern,
|
||||||
self.model_config.logits_processor_pattern,
|
|
||||||
self.default_sampling_params)
|
self.default_sampling_params)
|
||||||
|
|
||||||
self._log_inputs(request_id,
|
self._log_inputs(request_id,
|
||||||
|
|||||||
@ -33,6 +33,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
|||||||
is_text_tokens_prompt)
|
is_text_tokens_prompt)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
|
from vllm.entrypoints.utils import get_max_tokens
|
||||||
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
|
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
|
||||||
is_tokens_prompt)
|
is_tokens_prompt)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -160,15 +161,22 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
input_length = len(engine_prompt["prompt_token_ids"])
|
input_length = len(engine_prompt["prompt_token_ids"])
|
||||||
else:
|
else:
|
||||||
assert_never(engine_prompt)
|
assert_never(engine_prompt)
|
||||||
default_max_tokens = self.max_model_len - input_length
|
|
||||||
|
if self.default_sampling_params is None:
|
||||||
|
self.default_sampling_params = {}
|
||||||
|
|
||||||
|
max_tokens = get_max_tokens(
|
||||||
|
max_model_len=self.max_model_len,
|
||||||
|
request=request,
|
||||||
|
input_length=input_length,
|
||||||
|
default_sampling_params=self.default_sampling_params)
|
||||||
|
|
||||||
if request.use_beam_search:
|
if request.use_beam_search:
|
||||||
sampling_params = request.to_beam_search_params(
|
sampling_params = request.to_beam_search_params(
|
||||||
default_max_tokens, self.default_sampling_params)
|
max_tokens, self.default_sampling_params)
|
||||||
else:
|
else:
|
||||||
sampling_params = request.to_sampling_params(
|
sampling_params = request.to_sampling_params(
|
||||||
default_max_tokens,
|
max_tokens, self.model_config.logits_processor_pattern,
|
||||||
self.model_config.logits_processor_pattern,
|
|
||||||
self.default_sampling_params)
|
self.default_sampling_params)
|
||||||
|
|
||||||
request_id_item = f"{request_id}-{i}"
|
request_id_item = f"{request_id}-{i}"
|
||||||
|
|||||||
@ -5,13 +5,17 @@ import argparse
|
|||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional
|
import sys
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from starlette.background import BackgroundTask, BackgroundTasks
|
from starlette.background import BackgroundTask, BackgroundTasks
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
|
CompletionRequest)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -181,7 +185,6 @@ def _validate_truncation_size(
|
|||||||
|
|
||||||
def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
|
def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
|
||||||
subcommand_name: list[str]):
|
subcommand_name: list[str]):
|
||||||
import sys
|
|
||||||
|
|
||||||
# Only handle --help=<keyword> for the current subcommand.
|
# Only handle --help=<keyword> for the current subcommand.
|
||||||
# Since subparser_init() runs for all subcommands during CLI setup,
|
# Since subparser_init() runs for all subcommands during CLI setup,
|
||||||
@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
|
|||||||
print(f"\nNo group or parameter matching '{search_keyword}'")
|
print(f"\nNo group or parameter matching '{search_keyword}'")
|
||||||
print("Tip: use `--help=listgroup` to view all groups.")
|
print("Tip: use `--help=listgroup` to view all groups.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_tokens(max_model_len: int, request: Union[ChatCompletionRequest,
|
||||||
|
CompletionRequest],
|
||||||
|
input_length: int, default_sampling_params: dict) -> int:
|
||||||
|
|
||||||
|
max_tokens = getattr(request, "max_completion_tokens",
|
||||||
|
None) or request.max_tokens
|
||||||
|
default_max_tokens = max_model_len - input_length
|
||||||
|
max_output_tokens = current_platform.get_max_output_tokens(input_length)
|
||||||
|
|
||||||
|
return min(val
|
||||||
|
for val in (default_max_tokens, max_tokens, max_output_tokens,
|
||||||
|
default_sampling_params.get("max_tokens"))
|
||||||
|
if val is not None)
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import enum
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import random
|
import random
|
||||||
|
import sys
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from platform import uname
|
from platform import uname
|
||||||
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
|
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
|
||||||
@ -164,6 +165,9 @@ class Platform:
|
|||||||
def is_out_of_tree(self) -> bool:
|
def is_out_of_tree(self) -> bool:
|
||||||
return self._enum == PlatformEnum.OOT
|
return self._enum == PlatformEnum.OOT
|
||||||
|
|
||||||
|
def get_max_output_tokens(self, prompt_len: int) -> int:
|
||||||
|
return sys.maxsize
|
||||||
|
|
||||||
def is_cuda_alike(self) -> bool:
|
def is_cuda_alike(self) -> bool:
|
||||||
"""Stateless version of [torch.cuda.is_available][]."""
|
"""Stateless version of [torch.cuda.is_available][]."""
|
||||||
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
|
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user