[Feature] Add load generation config from model (#11164)

Signed-off-by: liuyanyi <wolfsonliu@163.com>
Signed-off-by: Yanyi Liu <wolfsonliu@163.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Yanyi Liu 2024-12-19 18:50:38 +08:00 committed by GitHub
parent 98356735ac
commit 5aef49806d
10 changed files with 307 additions and 74 deletions

View File

@ -0,0 +1,30 @@
from vllm import LLM
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM that uses the model's built-in generation config.
# generation_config is set to None by default to keep the behavior
# consistent with previous vLLM versions. To use the default generation
# config from the model, set generation_config to "auto".
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto")
# Load the default sampling parameters from the model.
sampling_params = llm.get_default_sampling_params()
# Modify the sampling parameters if needed.
sampling_params.temperature = 0.5
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@ -1,6 +1,7 @@
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from typing import Optional
from unittest.mock import MagicMock
from vllm.config import MultiModalConfig
@ -31,6 +32,10 @@ class MockModelConfig:
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@dataclass @dataclass
@ -94,3 +99,59 @@ def test_serving_chat_should_set_correct_max_tokens():
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[1].max_tokens == 10
def test_serving_chat_could_load_correct_generation_config():
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
"temperature": 0.5,
"repetition_penalty": 1.05
}
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
# Initialize the serving chat
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
BASE_MODEL_PATHS,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
lora_modules=None,
prompt_adapters=None,
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?"
}],
guided_decoding_backend="outlines",
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[1].temperature == 0.5
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
# Test the parameter when the user sets it explicitly
req.temperature = 0.1
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[1].temperature == 0.1
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
# Test when temperature == 0.0
req.temperature = 0.0
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[1].temperature == 0.0
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05

View File

@ -27,7 +27,8 @@ from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, print_warning_once, random_uuid,
resolve_obj_by_qualname)
@ -160,6 +161,7 @@ class ModelConfig:
logits processor qualified names that can be passed with the
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
generation_config: The folder path to the generation config, or
"auto" to load it from the model. Defaults to None.
"""
def compute_hash(self) -> str:
@ -218,7 +220,8 @@ class ModelConfig:
disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
generation_config: Optional[str] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@ -348,6 +351,8 @@ class ModelConfig:
self.pooler_config = self._init_pooler_config(override_pooler_config)
self.logits_processor_pattern = logits_processor_pattern
self.generation_config = generation_config
self._verify_quantization()
self._verify_cuda_graph()
self._verify_bnb_config()
@ -813,6 +818,56 @@ class ModelConfig:
return self.multimodal_config
def try_get_generation_config(self) -> Dict[str, Any]:
if self.generation_config is None or self.generation_config == "auto":
config = try_get_generation_config(
self.model,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
)
else:
config = try_get_generation_config(
self.generation_config,
trust_remote_code=self.trust_remote_code,
)
if config is None:
return {}
return config.to_diff_dict()
def get_diff_sampling_param(self) -> Dict[str, Any]:
"""
This method returns a dictionary containing the parameters
that differ from the default sampling parameters, but only
if `generation_config` is set. If `generation_config` is not
set, an empty dictionary is returned.
Returns:
Dict[str, Any]: A dictionary with the differing sampling
parameters if `generation_config` is set, otherwise an
empty dictionary.
"""
if self.generation_config is None:
# When generation_config is not set
return {}
config = self.try_get_generation_config()
available_params = [
"repetition_penalty",
"temperature",
"top_k",
"top_p",
"min_p",
]
if any(p in config for p in available_params):
diff_sampling_param = {
p: config.get(p)
for p in available_params if config.get(p) is not None
}
else:
diff_sampling_param = {}
return diff_sampling_param
@property
def is_encoder_decoder(self) -> bool:
"""Extract the HF encoder/decoder model flag."""

View File

@ -197,6 +197,8 @@ class EngineArgs:
kv_transfer_config: Optional[KVTransferConfig] = None
generation_config: Optional[str] = None
def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
@ -942,6 +944,16 @@ class EngineArgs:
default="auto", default="auto",
help='The worker class to use for distributed execution.') help='The worker class to use for distributed execution.')
parser.add_argument(
"--generation-config",
type=nullable_str,
default=None,
help="The folder path to the generation config. "
"Defaults to None, will use the default generation config in vLLM. "
"If set to 'auto', the generation config will be automatically "
"loaded from model. If set to a folder path, the generation config "
"will be loaded from the specified folder path.")
return parser
@classmethod
@ -985,7 +997,8 @@
disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config)
def create_load_config(self) -> LoadConfig:
return LoadConfig(

View File

@ -5,8 +5,8 @@ from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
@ -52,7 +52,6 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
@ -65,20 +64,6 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
config = try_get_generation_config(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.revision,
)
if config is None:
return {}
return config.to_diff_dict()
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
@ -274,8 +259,8 @@ class LLMEngine:
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,

View File

@ -258,6 +258,13 @@ class LLM:
else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
def get_default_sampling_params(self) -> SamplingParams:
diff_sampling_param = (
self.llm_engine.model_config.get_diff_sampling_param())
if diff_sampling_param:
return SamplingParams.from_optional(**diff_sampling_param)
return SamplingParams()
@overload
def generate(
self,
@ -441,7 +448,7 @@ class LLM:
if sampling_params is None:
# Use default sampling params.
sampling_params = self.get_default_sampling_params()
self._validate_and_add_requests(
prompts=parsed_prompts,

View File

@ -211,8 +211,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
ChatCompletionNamedToolChoiceParam]] = "none"
@ -224,9 +224,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: bool = False
top_k: Optional[int] = None
min_p: Optional[float] = None
repetition_penalty: Optional[float] = None
length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
@ -348,15 +348,32 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
# Default sampling parameters for chat completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
return BeamSearchParams(
beam_width=n,
@ -367,13 +384,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
@ -403,11 +443,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
@ -584,15 +624,15 @@ class CompletionRequest(OpenAIBaseModel):
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
user: Optional[str] = None
# doc: begin-completion-sampling-params
use_beam_search: bool = False
top_k: Optional[int] = None
min_p: Optional[float] = None
repetition_penalty: Optional[float] = None
length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
@ -669,14 +709,30 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
# Default sampling parameters for completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0)
return BeamSearchParams(
beam_width=n,
@ -687,12 +743,35 @@ class CompletionRequest(OpenAIBaseModel):
include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs
@ -718,11 +797,11 @@ class CompletionRequest(OpenAIBaseModel):
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
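
To make the precedence concrete, here is a hypothetical sketch (not part of this diff): a value set on the request wins, then the model defaults passed in via default_sampling_params, then the hard-coded _DEFAULT_SAMPLING_PARAMS fallback.

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

# Temperature is left unset on the request, so the model default
# supplied through default_sampling_params takes effect.
req = ChatCompletionRequest(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    messages=[{"role": "user", "content": "what is 1+1?"}],
)
params = req.to_sampling_params(
    default_max_tokens=128,
    logits_processor_pattern=None,
    default_sampling_params={"temperature": 0.5},
)
assert params.temperature == 0.5  # model default beats the 1.0 fallback

# An explicit request value overrides the model default.
req.temperature = 0.1
params = req.to_sampling_params(
    default_max_tokens=128,
    logits_processor_pattern=None,
    default_sampling_params={"temperature": 0.5},
)
assert params.temperature == 0.1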

View File

@ -91,6 +91,10 @@ class OpenAIServingChat(OpenAIServing):
"been registered") from e "been registered") from e
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
diff_sampling_param = self.model_config.get_diff_sampling_param()
if diff_sampling_param:
logger.info("Overwriting default chat sampling param with: %s",
diff_sampling_param)
async def create_chat_completion(
self,
@ -191,13 +195,17 @@
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
# Build default sampling params
default_sampling_params = (
self.model_config.get_diff_sampling_param())
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens, default_sampling_params)
else:
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params)
self._log_inputs(request_id,
request_prompts[i],

View File

@ -55,6 +55,11 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapters=prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
diff_sampling_param = self.model_config.get_diff_sampling_param()
if diff_sampling_param:
logger.info(
"Overwriting default completion sampling param with: %s",
diff_sampling_param)
async def create_completion(
self,
@ -118,13 +123,17 @@
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
# Build default sampling params
default_sampling_params = (
self.model_config.get_diff_sampling_param())
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens, default_sampling_params)
else:
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params)
request_id_item = f"{request_id}-{i}"

View File

@ -1,5 +1,5 @@
import time
from typing import Mapping, Optional, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
@ -12,7 +12,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
@ -34,8 +33,8 @@ class Processor:
self.lora_config = lora_config
self.tokenizer = tokenizer
self.generation_config_fields = model_config.try_get_generation_config()
self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer,
mm_registry)
@ -181,16 +180,3 @@ class Processor:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
config = try_get_generation_config(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.revision,
)
if config is None:
return {}
return config.to_diff_dict()