Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-23 06:45:01 +08:00)
[Feature] Add load generation config from model (#11164)
Signed-off-by: liuyanyi <wolfsonliu@163.com>
Signed-off-by: Yanyi Liu <wolfsonliu@163.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent 98356735ac
commit 5aef49806d

examples/offline_inference_with_default_generation_config.py (new file, +30 lines)

@@ -0,0 +1,30 @@
+from vllm import LLM
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+# Create an LLM with built-in default generation config.
+# The generation config is set to None by default to keep
+# the behavior consistent with the previous version.
+# If you want to use the default generation config from the model,
+# you should set the generation_config to "auto".
+llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto")
+
+# Load the default sampling parameters from the model.
+sampling_params = llm.get_default_sampling_params()
+# Modify the sampling parameters if needed.
+sampling_params.temperature = 0.5
+
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
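
For reference, a minimal sketch of how to inspect which defaults were actually picked up from the model (assumes a vLLM build that includes this change; the attribute path llm.llm_engine.model_config follows the ModelConfig changes further down in this diff):

from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto")

# get_diff_sampling_param() returns only the sampling-related keys found in the
# model's generation_config.json (repetition_penalty, temperature, top_k, top_p, min_p).
print(llm.llm_engine.model_config.get_diff_sampling_param())

# get_default_sampling_params() wraps those overrides into a SamplingParams object.
print(llm.get_default_sampling_params())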
@@ -1,6 +1,7 @@
 import asyncio
 from contextlib import suppress
 from dataclasses import dataclass
+from typing import Optional
 from unittest.mock import MagicMock
 
 from vllm.config import MultiModalConfig
@@ -31,6 +32,10 @@ class MockModelConfig:
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
     logits_processor_pattern = None
+    diff_sampling_param: Optional[dict] = None
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}
 
 
 @dataclass
@@ -94,3 +99,59 @@ def test_serving_chat_should_set_correct_max_tokens():
         asyncio.run(serving_chat.create_chat_completion(req))
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
+
+
+def test_serving_chat_could_load_correct_generation_config():
+
+    mock_model_config = MockModelConfig()
+    mock_model_config.diff_sampling_param = {
+        "temperature": 0.5,
+        "repetition_penalty": 1.05
+    }
+
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    # Initialize the serving chat
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     mock_model_config,
+                                     BASE_MODEL_PATHS,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     lora_modules=None,
+                                     prompt_adapters=None,
+                                     request_logger=None)
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].temperature == 0.5
+    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
+
+    # Test the param when user set it
+    req.temperature = 0.1
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].temperature == 0.1
+    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
+
+    # Test When temperature==0.0
+    req.temperature = 0.0
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].temperature == 0.0
+    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05

@@ -27,7 +27,8 @@ from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
-    get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
+    get_sentence_transformer_tokenizer_config, is_encoder_decoder,
+    try_get_generation_config, uses_mrope)
 from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
                         get_cpu_memory, print_warning_once, random_uuid,
                         resolve_obj_by_qualname)
@@ -160,6 +161,7 @@ class ModelConfig:
             logits processor qualified names that can be passed with the
             `logits_processors` extra completion argument. Defaults to None,
             which allows no processors.
+        generation_config: Configuration parameter file for generation.
     """
 
     def compute_hash(self) -> str:
@@ -218,7 +220,8 @@ class ModelConfig:
                  disable_mm_preprocessor_cache: bool = False,
                  override_neuron_config: Optional[Dict[str, Any]] = None,
                  override_pooler_config: Optional["PoolerConfig"] = None,
-                 logits_processor_pattern: Optional[str] = None) -> None:
+                 logits_processor_pattern: Optional[str] = None,
+                 generation_config: Optional[str] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -348,6 +351,8 @@ class ModelConfig:
         self.pooler_config = self._init_pooler_config(override_pooler_config)
         self.logits_processor_pattern = logits_processor_pattern
 
+        self.generation_config = generation_config
+
         self._verify_quantization()
         self._verify_cuda_graph()
         self._verify_bnb_config()
@@ -813,6 +818,56 @@ class ModelConfig:
 
         return self.multimodal_config
 
+    def try_get_generation_config(self) -> Dict[str, Any]:
+        if self.generation_config is None or self.generation_config == "auto":
+            config = try_get_generation_config(
+                self.model,
+                trust_remote_code=self.trust_remote_code,
+                revision=self.revision,
+            )
+        else:
+            config = try_get_generation_config(
+                self.generation_config,
+                trust_remote_code=self.trust_remote_code,
+            )
+
+        if config is None:
+            return {}
+
+        return config.to_diff_dict()
+
+    def get_diff_sampling_param(self) -> Dict[str, Any]:
+        """
+        This method returns a dictionary containing the parameters
+        that differ from the default sampling parameters, but only
+        if `generation_config` is set. If `generation_config` is not
+        set, an empty dictionary is returned.
+
+        Returns:
+            Dict[str, Any]: A dictionary with the differing sampling
+            parameters if `generation_config` is set, otherwise an
+            empty dictionary.
+        """
+        if self.generation_config is None:
+            # When generation_config is not set
+            return {}
+        config = self.try_get_generation_config()
+        available_params = [
+            "repetition_penalty",
+            "temperature",
+            "top_k",
+            "top_p",
+            "min_p",
+        ]
+        if any(p in config for p in available_params):
+            diff_sampling_param = {
+                p: config.get(p)
+                for p in available_params if config.get(p) is not None
+            }
+        else:
+            diff_sampling_param = {}
+        return diff_sampling_param
+
     @property
     def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""

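The key behavior added to ModelConfig is the filtering in get_diff_sampling_param: only a small allow-list of sampling keys from the model's generation config is forwarded. A standalone sketch of that idea (the input dict is hypothetical; in vLLM it would come from try_get_generation_config(...).to_diff_dict()):

from typing import Any, Dict

# Keys that ModelConfig.get_diff_sampling_param forwards, per the hunk above.
AVAILABLE_PARAMS = ["repetition_penalty", "temperature", "top_k", "top_p", "min_p"]

def extract_sampling_overrides(generation_config: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the sampling keys that vLLM maps onto SamplingParams."""
    return {
        p: generation_config[p]
        for p in AVAILABLE_PARAMS
        if generation_config.get(p) is not None
    }

# Non-sampling keys such as eos_token_id are ignored; sampling keys pass through.
print(extract_sampling_overrides({
    "temperature": 0.7,
    "top_k": 20,
    "eos_token_id": 151645,  # illustrative value
}))
# -> {'temperature': 0.7, 'top_k': 20}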
@@ -197,6 +197,8 @@ class EngineArgs:
 
     kv_transfer_config: Optional[KVTransferConfig] = None
 
+    generation_config: Optional[str] = None
+
     def __post_init__(self):
         if not self.tokenizer:
             self.tokenizer = self.model
@@ -942,6 +944,16 @@ class EngineArgs:
             default="auto",
             help='The worker class to use for distributed execution.')
 
+        parser.add_argument(
+            "--generation-config",
+            type=nullable_str,
+            default=None,
+            help="The folder path to the generation config. "
+            "Defaults to None, will use the default generation config in vLLM. "
+            "If set to 'auto', the generation config will be automatically "
+            "loaded from model. If set to a folder path, the generation config "
+            "will be loaded from the specified folder path.")
+
         return parser
 
     @classmethod
@@ -985,7 +997,8 @@ class EngineArgs:
             disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
-            logits_processor_pattern=self.logits_processor_pattern)
+            logits_processor_pattern=self.logits_processor_pattern,
+            generation_config=self.generation_config)
 
     def create_load_config(self) -> LoadConfig:
         return LoadConfig(

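A sketch of the new flag as seen through EngineArgs' own CLI plumbing (add_cli_args and from_cli_args are long-standing vLLM helpers; the model name and everything else here are illustrative and assume a build with this change):

import argparse

from vllm.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser()
EngineArgs.add_cli_args(parser)

# "--generation-config auto" loads generation_config.json from the model repo;
# a folder path would load it from that folder instead, per the help text above.
args = parser.parse_args([
    "--model", "Qwen/Qwen2.5-0.5B-Instruct",
    "--generation-config", "auto",
])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.generation_config)  # -> "auto"

The same flag should also be accepted when launching the OpenAI-compatible API server, since that CLI is built from EngineArgs.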
@@ -5,8 +5,8 @@ from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
-                    Iterable, List, Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
+                    List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload
 
@@ -52,7 +52,6 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                            SequenceGroupOutput, SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
-from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import (
@@ -65,20 +64,6 @@ from vllm.version import __version__ as VLLM_VERSION
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
-
-def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
-    config = try_get_generation_config(
-        model_config.model,
-        trust_remote_code=model_config.trust_remote_code,
-        revision=model_config.revision,
-    )
-
-    if config is None:
-        return {}
-
-    return config.to_diff_dict()
-
-
 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
 _O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
 
@@ -274,8 +259,8 @@ class LLMEngine:
             return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
 
         self.seq_counter = Counter()
-        self.generation_config_fields = _load_generation_config_dict(
-            self.model_config)
+        self.generation_config_fields = (
+            self.model_config.try_get_generation_config())
 
         self.input_preprocessor = InputPreprocessor(self.model_config,
                                                     self.tokenizer,

@@ -258,6 +258,13 @@ class LLM:
         else:
             tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
 
+    def get_default_sampling_params(self) -> SamplingParams:
+        diff_sampling_param = (
+            self.llm_engine.model_config.get_diff_sampling_param())
+        if diff_sampling_param:
+            return SamplingParams.from_optional(**diff_sampling_param)
+        return SamplingParams()
+
     @overload
     def generate(
         self,
@@ -441,7 +448,7 @@ class LLM:
 
         if sampling_params is None:
             # Use default sampling params.
-            sampling_params = SamplingParams()
+            sampling_params = self.get_default_sampling_params()
 
         self._validate_and_add_requests(
             prompts=parsed_prompts,

@@ -211,8 +211,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
     tools: Optional[List[ChatCompletionToolsParam]] = None
     tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                 ChatCompletionNamedToolChoiceParam]] = "none"
@@ -224,9 +224,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # doc: begin-chat-completion-sampling-params
     best_of: Optional[int] = None
     use_beam_search: bool = False
-    top_k: int = -1
-    min_p: float = 0.0
-    repetition_penalty: float = 1.0
+    top_k: Optional[int] = None
+    min_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
     length_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     include_stop_str_in_output: bool = False
@@ -348,15 +348,32 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
-    def to_beam_search_params(self,
-                              default_max_tokens: int) -> BeamSearchParams:
+    # Default sampling parameters for chat completion requests
+    _DEFAULT_SAMPLING_PARAMS: dict = {
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
+    }
+
+    def to_beam_search_params(
+            self,
+            default_max_tokens: int,
+            default_sampling_params: Optional[dict] = None
+    ) -> BeamSearchParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
 
+        if default_sampling_params is None:
+            default_sampling_params = {}
         n = self.n if self.n is not None else 1
-        temperature = self.temperature if self.temperature is not None else 0.0
+
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
 
         return BeamSearchParams(
             beam_width=n,
@@ -367,13 +384,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(
-            self, default_max_tokens: int,
-            logits_processor_pattern: Optional[str]) -> SamplingParams:
+            self,
+            default_max_tokens: int,
+            logits_processor_pattern: Optional[str],
+            default_sampling_params: Optional[dict] = None) -> SamplingParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
 
+        if default_sampling_params is None:
+            default_sampling_params = {}
+        # Default parameters
+        if (repetition_penalty := self.repetition_penalty) is None:
+            repetition_penalty = default_sampling_params.get(
+                "repetition_penalty",
+                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
+            )
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+        if (top_p := self.top_p) is None:
+            top_p = default_sampling_params.get(
+                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
+        if (top_k := self.top_k) is None:
+            top_k = default_sampling_params.get(
+                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
+        if (min_p := self.min_p) is None:
+            min_p = default_sampling_params.get(
+                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
+
         prompt_logprobs = self.prompt_logprobs
         if prompt_logprobs is None and self.echo:
             prompt_logprobs = self.top_logprobs
@@ -403,11 +443,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
             best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
-            repetition_penalty=self.repetition_penalty,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            top_k=self.top_k,
-            min_p=self.min_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
             seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
@@ -584,15 +624,15 @@ class CompletionRequest(OpenAIBaseModel):
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
     user: Optional[str] = None
 
     # doc: begin-completion-sampling-params
     use_beam_search: bool = False
-    top_k: int = -1
-    min_p: float = 0.0
-    repetition_penalty: float = 1.0
+    top_k: Optional[int] = None
+    min_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
     length_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     include_stop_str_in_output: bool = False
@@ -669,14 +709,30 @@ class CompletionRequest(OpenAIBaseModel):
 
     # doc: end-completion-extra-params
 
-    def to_beam_search_params(self,
-                              default_max_tokens: int) -> BeamSearchParams:
+    # Default sampling parameters for completion requests
+    _DEFAULT_SAMPLING_PARAMS: dict = {
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
+    }
+
+    def to_beam_search_params(
+            self,
+            default_max_tokens: int,
+            default_sampling_params: Optional[dict] = None
+    ) -> BeamSearchParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
 
+        if default_sampling_params is None:
+            default_sampling_params = {}
         n = self.n if self.n is not None else 1
-        temperature = self.temperature if self.temperature is not None else 0.0
+
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get("temperature", 1.0)
 
         return BeamSearchParams(
             beam_width=n,
@@ -687,12 +743,35 @@ class CompletionRequest(OpenAIBaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(
-            self, default_max_tokens: int,
-            logits_processor_pattern: Optional[str]) -> SamplingParams:
+            self,
+            default_max_tokens: int,
+            logits_processor_pattern: Optional[str],
+            default_sampling_params: Optional[dict] = None) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
 
+        if default_sampling_params is None:
+            default_sampling_params = {}
+        # Default parameters
+        if (repetition_penalty := self.repetition_penalty) is None:
+            repetition_penalty = default_sampling_params.get(
+                "repetition_penalty",
+                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
+            )
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+        if (top_p := self.top_p) is None:
+            top_p = default_sampling_params.get(
+                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
+        if (top_k := self.top_k) is None:
+            top_k = default_sampling_params.get(
+                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
+        if (min_p := self.min_p) is None:
+            min_p = default_sampling_params.get(
+                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
+
         prompt_logprobs = self.prompt_logprobs
         if prompt_logprobs is None and self.echo:
             prompt_logprobs = self.logprobs
@@ -718,11 +797,11 @@ class CompletionRequest(OpenAIBaseModel):
             best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
-            repetition_penalty=self.repetition_penalty,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            top_k=self.top_k,
-            min_p=self.min_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
             seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,

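The net effect of the protocol changes is a three-level precedence: an explicitly set request field wins, otherwise the value from the model's generation config (default_sampling_params) is used, and only then the hard-coded _DEFAULT_SAMPLING_PARAMS. A sketch against the new signature (values and model name are illustrative; assumes a build with this change):

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

req = ChatCompletionRequest(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    messages=[{"role": "user", "content": "Hi"}],
    top_p=0.8,  # explicitly set by the client
)
params = req.to_sampling_params(
    default_max_tokens=512,
    logits_processor_pattern=None,
    default_sampling_params={"temperature": 0.6, "top_p": 0.95},
)
print(params.temperature)  # 0.6: falls back to the generation-config default
print(params.top_p)        # 0.8: the request value takes precedence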
@@ -91,6 +91,10 @@ class OpenAIServingChat(OpenAIServing):
                 "been registered") from e
 
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
+        diff_sampling_param = self.model_config.get_diff_sampling_param()
+        if diff_sampling_param:
+            logger.info("Overwriting default chat sampling param with: %s",
+                        diff_sampling_param)
 
     async def create_chat_completion(
         self,
@@ -191,13 +195,17 @@ class OpenAIServingChat(OpenAIServing):
         sampling_params: Union[SamplingParams, BeamSearchParams]
         default_max_tokens = self.max_model_len - len(
             engine_prompt["prompt_token_ids"])
+        # Build default sampling params
+        default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
         if request.use_beam_search:
             sampling_params = request.to_beam_search_params(
-                default_max_tokens)
+                default_max_tokens, default_sampling_params)
         else:
             sampling_params = request.to_sampling_params(
                 default_max_tokens,
-                self.model_config.logits_processor_pattern)
+                self.model_config.logits_processor_pattern,
+                default_sampling_params)
 
         self._log_inputs(request_id,
                          request_prompts[i],

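On the serving side, requests that leave sampling fields unset now inherit the model's generation-config defaults rather than fixed values. A hypothetical client-side check (host, port and model name are assumptions; requires a server started with --generation-config auto and the openai client package):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# No temperature/top_p/top_k set here, so the server applies the defaults it
# loaded from the model's generation_config.json (logged at startup as
# "Overwriting default chat sampling param with: ...").
completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    messages=[{"role": "user", "content": "what is 1+1?"}],
)
print(completion.choices[0].message.content)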
@@ -55,6 +55,11 @@ class OpenAIServingCompletion(OpenAIServing):
             prompt_adapters=prompt_adapters,
             request_logger=request_logger,
             return_tokens_as_token_ids=return_tokens_as_token_ids)
+        diff_sampling_param = self.model_config.get_diff_sampling_param()
+        if diff_sampling_param:
+            logger.info(
+                "Overwriting default completion sampling param with: %s",
+                diff_sampling_param)
 
     async def create_completion(
         self,
@@ -118,13 +123,17 @@ class OpenAIServingCompletion(OpenAIServing):
         sampling_params: Union[SamplingParams, BeamSearchParams]
         default_max_tokens = self.max_model_len - len(
             engine_prompt["prompt_token_ids"])
+        # Build default sampling params
+        default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
         if request.use_beam_search:
             sampling_params = request.to_beam_search_params(
-                default_max_tokens)
+                default_max_tokens, default_sampling_params)
         else:
             sampling_params = request.to_sampling_params(
                 default_max_tokens,
-                self.model_config.logits_processor_pattern)
+                self.model_config.logits_processor_pattern,
+                default_sampling_params)
 
         request_id_item = f"{request_id}-{i}"
 

@@ -1,5 +1,5 @@
 import time
-from typing import Any, Dict, Mapping, Optional, Tuple, Union
+from typing import Mapping, Optional, Tuple, Union
 
 from vllm.config import CacheConfig, LoRAConfig, ModelConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
@@ -12,7 +12,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
 from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
@@ -34,8 +33,8 @@ class Processor:
         self.lora_config = lora_config
         self.tokenizer = tokenizer
 
-        self.generation_config_fields = _load_generation_config_dict(
-            model_config)
+        self.generation_config_fields = model_config.try_get_generation_config(
+        )
         self.input_preprocessor = InputPreprocessor(model_config,
                                                     self.tokenizer,
                                                     mm_registry)
@@ -181,16 +180,3 @@ class Processor:
         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
         # max_batch_len = self.scheduler_config.max_num_batched_tokens
-
-
-def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
-    config = try_get_generation_config(
-        model_config.model,
-        trust_remote_code=model_config.trust_remote_code,
-        revision=model_config.revision,
-    )
-
-    if config is None:
-        return {}
-
-    return config.to_diff_dict()