mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 05:05:01 +08:00
Default to generation_config from model (#12622)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
3b9c6c6947
commit
47512b3200
@ -20,7 +20,7 @@ NUM_CONCURRENT = 500
|
|||||||
TASK = "gsm8k"
|
TASK = "gsm8k"
|
||||||
FILTER = "exact_match,strict-match"
|
FILTER = "exact_match,strict-match"
|
||||||
RTOL = 0.03
|
RTOL = 0.03
|
||||||
EXPECTED_VALUE = 0.58
|
EXPECTED_VALUE = 0.54
|
||||||
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
|
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
|
||||||
MORE_ARGS_LIST = [
|
MORE_ARGS_LIST = [
|
||||||
[], # Default
|
[], # Default
|
||||||
|
|||||||
@ -38,6 +38,7 @@ class MockModelConfig:
|
|||||||
diff_sampling_param: Optional[dict] = None
|
diff_sampling_param: Optional[dict] = None
|
||||||
allowed_local_media_path: str = ""
|
allowed_local_media_path: str = ""
|
||||||
encoder_config = None
|
encoder_config = None
|
||||||
|
generation_config: str = "auto"
|
||||||
|
|
||||||
def get_diff_sampling_param(self):
|
def get_diff_sampling_param(self):
|
||||||
return self.diff_sampling_param or {}
|
return self.diff_sampling_param or {}
|
||||||
|
|||||||
@ -289,7 +289,7 @@ def test_uses_mrope(model_id, uses_mrope):
|
|||||||
def test_generation_config_loading():
|
def test_generation_config_loading():
|
||||||
model_id = "Qwen/Qwen2.5-1.5B-Instruct"
|
model_id = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
|
|
||||||
# When set generation_config to None, the default generation config
|
# When set generation_config to "vllm", the default generation config
|
||||||
# will not be loaded.
|
# will not be loaded.
|
||||||
model_config = ModelConfig(model_id,
|
model_config = ModelConfig(model_id,
|
||||||
task="auto",
|
task="auto",
|
||||||
@ -298,7 +298,7 @@ def test_generation_config_loading():
|
|||||||
trust_remote_code=False,
|
trust_remote_code=False,
|
||||||
seed=0,
|
seed=0,
|
||||||
dtype="float16",
|
dtype="float16",
|
||||||
generation_config=None)
|
generation_config="vllm")
|
||||||
assert model_config.get_diff_sampling_param() == {}
|
assert model_config.get_diff_sampling_param() == {}
|
||||||
|
|
||||||
# When set generation_config to "auto", the default generation config
|
# When set generation_config to "auto", the default generation config
|
||||||
@ -340,7 +340,7 @@ def test_generation_config_loading():
|
|||||||
|
|
||||||
assert model_config.get_diff_sampling_param() == override_result
|
assert model_config.get_diff_sampling_param() == override_result
|
||||||
|
|
||||||
# When generation_config is set to None and override_generation_config
|
# When generation_config is set to "vllm" and override_generation_config
|
||||||
# is set, the override_generation_config should be used directly.
|
# is set, the override_generation_config should be used directly.
|
||||||
model_config = ModelConfig(
|
model_config = ModelConfig(
|
||||||
model_id,
|
model_id,
|
||||||
@ -350,7 +350,7 @@ def test_generation_config_loading():
|
|||||||
trust_remote_code=False,
|
trust_remote_code=False,
|
||||||
seed=0,
|
seed=0,
|
||||||
dtype="float16",
|
dtype="float16",
|
||||||
generation_config=None,
|
generation_config="vllm",
|
||||||
override_generation_config=override_generation_config)
|
override_generation_config=override_generation_config)
|
||||||
|
|
||||||
assert model_config.get_diff_sampling_param() == override_generation_config
|
assert model_config.get_diff_sampling_param() == override_generation_config
|
||||||
|
|||||||
@ -255,7 +255,7 @@ class ModelConfig:
|
|||||||
override_neuron_config: Optional[dict[str, Any]] = None,
|
override_neuron_config: Optional[dict[str, Any]] = None,
|
||||||
override_pooler_config: Optional["PoolerConfig"] = None,
|
override_pooler_config: Optional["PoolerConfig"] = None,
|
||||||
logits_processor_pattern: Optional[str] = None,
|
logits_processor_pattern: Optional[str] = None,
|
||||||
generation_config: Optional[str] = None,
|
generation_config: str = "auto",
|
||||||
enable_sleep_mode: bool = False,
|
enable_sleep_mode: bool = False,
|
||||||
override_generation_config: Optional[dict[str, Any]] = None,
|
override_generation_config: Optional[dict[str, Any]] = None,
|
||||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||||
@ -951,7 +951,7 @@ class ModelConfig:
|
|||||||
return self.multimodal_config
|
return self.multimodal_config
|
||||||
|
|
||||||
def try_get_generation_config(self) -> dict[str, Any]:
|
def try_get_generation_config(self) -> dict[str, Any]:
|
||||||
if self.generation_config is None or self.generation_config == "auto":
|
if self.generation_config in ("auto", "vllm"):
|
||||||
config = try_get_generation_config(
|
config = try_get_generation_config(
|
||||||
self.hf_config_path or self.model,
|
self.hf_config_path or self.model,
|
||||||
trust_remote_code=self.trust_remote_code,
|
trust_remote_code=self.trust_remote_code,
|
||||||
@ -971,17 +971,14 @@ class ModelConfig:
|
|||||||
def get_diff_sampling_param(self) -> dict[str, Any]:
|
def get_diff_sampling_param(self) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
This method returns a dictionary containing the parameters
|
This method returns a dictionary containing the parameters
|
||||||
that differ from the default sampling parameters, but only
|
that differ from the default sampling parameters. If
|
||||||
if `generation_config` is set. If `generation_config` is not
|
`generation_config` is `"vllm"`, an empty dictionary is returned.
|
||||||
set, an empty dictionary is returned.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, Any]: A dictionary with the differing sampling
|
dict[str, Any]: A dictionary with the differing sampling
|
||||||
parameters if `generation_config` is set, otherwise an
|
parameters, if `generation_config` is `"vllm"` an empty dictionary.
|
||||||
empty dictionary.
|
|
||||||
"""
|
"""
|
||||||
if self.generation_config is None:
|
if self.generation_config == "vllm":
|
||||||
# When generation_config is not set
|
|
||||||
config = {}
|
config = {}
|
||||||
else:
|
else:
|
||||||
config = self.try_get_generation_config()
|
config = self.try_get_generation_config()
|
||||||
|
|||||||
@ -207,7 +207,7 @@ class EngineArgs:
|
|||||||
|
|
||||||
kv_transfer_config: Optional[KVTransferConfig] = None
|
kv_transfer_config: Optional[KVTransferConfig] = None
|
||||||
|
|
||||||
generation_config: Optional[str] = None
|
generation_config: Optional[str] = "auto"
|
||||||
override_generation_config: Optional[Dict[str, Any]] = None
|
override_generation_config: Optional[Dict[str, Any]] = None
|
||||||
enable_sleep_mode: bool = False
|
enable_sleep_mode: bool = False
|
||||||
model_impl: str = "auto"
|
model_impl: str = "auto"
|
||||||
@ -1034,13 +1034,13 @@ class EngineArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--generation-config",
|
"--generation-config",
|
||||||
type=nullable_str,
|
type=nullable_str,
|
||||||
default=None,
|
default="auto",
|
||||||
help="The folder path to the generation config. "
|
help="The folder path to the generation config. "
|
||||||
"Defaults to None, no generation config is loaded, vLLM defaults "
|
"Defaults to 'auto', the generation config will be loaded from "
|
||||||
"will be used. If set to 'auto', the generation config will be "
|
"model path. If set to 'vllm', no generation config is loaded, "
|
||||||
"loaded from model path. If set to a folder path, the generation "
|
"vLLM defaults will be used. If set to a folder path, the "
|
||||||
"config will be loaded from the specified folder path. If "
|
"generation config will be loaded from the specified folder path. "
|
||||||
"`max_new_tokens` is specified in generation config, then "
|
"If `max_new_tokens` is specified in generation config, then "
|
||||||
"it sets a server-wide limit on the number of output tokens "
|
"it sets a server-wide limit on the number of output tokens "
|
||||||
"for all requests.")
|
"for all requests.")
|
||||||
|
|
||||||
|
|||||||
@ -109,8 +109,10 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
self.default_sampling_params = (
|
self.default_sampling_params = (
|
||||||
self.model_config.get_diff_sampling_param())
|
self.model_config.get_diff_sampling_param())
|
||||||
if self.default_sampling_params:
|
if self.default_sampling_params:
|
||||||
logger.info("Overwriting default chat sampling param with: %s",
|
source = self.model_config.generation_config
|
||||||
self.default_sampling_params)
|
source = "model" if source == "auto" else source
|
||||||
|
logger.info("Using default chat sampling params from %s: %s",
|
||||||
|
source, self.default_sampling_params)
|
||||||
|
|
||||||
async def create_chat_completion(
|
async def create_chat_completion(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -55,9 +55,10 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
self.default_sampling_params = (
|
self.default_sampling_params = (
|
||||||
self.model_config.get_diff_sampling_param())
|
self.model_config.get_diff_sampling_param())
|
||||||
if self.default_sampling_params:
|
if self.default_sampling_params:
|
||||||
logger.info(
|
source = self.model_config.generation_config
|
||||||
"Overwriting default completion sampling param with: %s",
|
source = "model" if source == "auto" else source
|
||||||
self.default_sampling_params)
|
logger.info("Using default completion sampling params from %s: %s",
|
||||||
|
source, self.default_sampling_params)
|
||||||
|
|
||||||
async def create_completion(
|
async def create_completion(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user