mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 03:44:58 +08:00
[Frontend][Core] Override HF config.json via CLI (#5836)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
d88bff1b96
commit
b09895a618
@ -200,8 +200,10 @@ def test_rope_customization():
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
rope_scaling=TEST_ROPE_SCALING,
|
||||
rope_theta=TEST_ROPE_THETA,
|
||||
hf_overrides={
|
||||
"rope_scaling": TEST_ROPE_SCALING,
|
||||
"rope_theta": TEST_ROPE_THETA,
|
||||
},
|
||||
)
|
||||
assert getattr(llama_model_config.hf_config, "rope_scaling",
|
||||
None) == TEST_ROPE_SCALING
|
||||
@ -232,7 +234,9 @@ def test_rope_customization():
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
rope_scaling=TEST_ROPE_SCALING,
|
||||
hf_overrides={
|
||||
"rope_scaling": TEST_ROPE_SCALING,
|
||||
},
|
||||
)
|
||||
assert getattr(longchat_model_config.hf_config, "rope_scaling",
|
||||
None) == TEST_ROPE_SCALING
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import enum
|
||||
import json
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
|
||||
Mapping, Optional, Set, Tuple, Type, Union)
|
||||
@ -74,9 +75,6 @@ class ModelConfig:
|
||||
code_revision: The specific revision to use for the model code on
|
||||
Hugging Face Hub. It can be a branch name, a tag name, or a
|
||||
commit id. If unspecified, will use the default version.
|
||||
rope_scaling: Dictionary containing the scaling configuration for the
|
||||
RoPE embeddings. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum.
|
||||
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||
branch name, a tag name, or a commit id. If unspecified, will use
|
||||
the default version.
|
||||
@ -116,6 +114,7 @@ class ModelConfig:
|
||||
can not be gathered from the vllm arguments.
|
||||
config_format: The config format which shall be loaded.
|
||||
Defaults to 'auto' which defaults to 'hf'.
|
||||
hf_overrides: Arguments to be forwarded to the HuggingFace config.
|
||||
mm_processor_kwargs: Arguments to be forwarded to the model's processor
|
||||
for multi-modal data, e.g., image processor.
|
||||
pooling_type: Used to configure the pooling method in the embedding
|
||||
@ -146,7 +145,7 @@ class ModelConfig:
|
||||
allowed_local_media_path: str = "",
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
rope_scaling: Optional[dict] = None,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
rope_theta: Optional[float] = None,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
max_model_len: Optional[int] = None,
|
||||
@ -164,6 +163,7 @@ class ModelConfig:
|
||||
override_neuron_config: Optional[Dict[str, Any]] = None,
|
||||
config_format: ConfigFormat = ConfigFormat.AUTO,
|
||||
chat_template_text_format: str = "string",
|
||||
hf_overrides: Optional[Dict[str, Any]] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
pooling_type: Optional[str] = None,
|
||||
pooling_norm: Optional[bool] = None,
|
||||
@ -178,8 +178,22 @@ class ModelConfig:
|
||||
self.seed = seed
|
||||
self.revision = revision
|
||||
self.code_revision = code_revision
|
||||
self.rope_scaling = rope_scaling
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
if hf_overrides is None:
|
||||
hf_overrides = {}
|
||||
if rope_scaling is not None:
|
||||
hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
|
||||
hf_overrides.update(hf_override)
|
||||
msg = ("`--rope-scaling` will be removed in a future release. "
|
||||
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
|
||||
warnings.warn(DeprecationWarning(msg), stacklevel=2)
|
||||
if rope_theta is not None:
|
||||
hf_override = {"rope_theta": rope_theta}
|
||||
hf_overrides.update(hf_override)
|
||||
msg = ("`--rope-theta` will be removed in a future release. "
|
||||
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
|
||||
warnings.warn(DeprecationWarning(msg), stacklevel=2)
|
||||
|
||||
# The tokenizer version is consistent with the model version by default.
|
||||
if tokenizer_revision is None:
|
||||
self.tokenizer_revision = revision
|
||||
@ -193,8 +207,8 @@ class ModelConfig:
|
||||
self.disable_sliding_window = disable_sliding_window
|
||||
self.skip_tokenizer_init = skip_tokenizer_init
|
||||
self.hf_config = get_config(self.model, trust_remote_code, revision,
|
||||
code_revision, rope_scaling, rope_theta,
|
||||
config_format)
|
||||
code_revision, config_format,
|
||||
**hf_overrides)
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.encoder_config = self._get_encoder_config()
|
||||
self.hf_image_processor_config = get_hf_image_processor_config(
|
||||
|
||||
@ -128,8 +128,9 @@ class EngineArgs:
|
||||
disable_log_stats: bool = False
|
||||
revision: Optional[str] = None
|
||||
code_revision: Optional[str] = None
|
||||
rope_scaling: Optional[dict] = None
|
||||
rope_scaling: Optional[Dict[str, Any]] = None
|
||||
rope_theta: Optional[float] = None
|
||||
hf_overrides: Optional[Dict[str, Any]] = None
|
||||
tokenizer_revision: Optional[str] = None
|
||||
quantization: Optional[str] = None
|
||||
enforce_eager: Optional[bool] = None
|
||||
@ -140,8 +141,9 @@ class EngineArgs:
|
||||
# is intended for expert use only. The API may change without
|
||||
# notice.
|
||||
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
|
||||
tokenizer_pool_extra_config: Optional[dict] = None
|
||||
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
|
||||
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None
|
||||
enable_lora: bool = False
|
||||
max_loras: int = 1
|
||||
max_lora_rank: int = 16
|
||||
@ -187,7 +189,6 @@ class EngineArgs:
|
||||
collect_detailed_traces: Optional[str] = None
|
||||
disable_async_output_proc: bool = False
|
||||
override_neuron_config: Optional[Dict[str, Any]] = None
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None
|
||||
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
|
||||
|
||||
# Pooling configuration.
|
||||
@ -512,6 +513,12 @@ class EngineArgs:
|
||||
help='RoPE theta. Use with `rope_scaling`. In '
|
||||
'some cases, changing the RoPE theta improves the '
|
||||
'performance of the scaled model.')
|
||||
parser.add_argument('--hf-overrides',
|
||||
type=json.loads,
|
||||
default=EngineArgs.hf_overrides,
|
||||
help='Extra arguments for the HuggingFace config.'
|
||||
'This should be a JSON string that will be '
|
||||
'parsed into a dictionary.')
|
||||
parser.add_argument('--enforce-eager',
|
||||
action='store_true',
|
||||
help='Always use eager-mode PyTorch. If False, '
|
||||
@ -940,6 +947,7 @@ class EngineArgs:
|
||||
code_revision=self.code_revision,
|
||||
rope_scaling=self.rope_scaling,
|
||||
rope_theta=self.rope_theta,
|
||||
hf_overrides=self.hf_overrides,
|
||||
tokenizer_revision=self.tokenizer_revision,
|
||||
max_model_len=self.max_model_len,
|
||||
quantization=self.quantization,
|
||||
|
||||
@ -248,8 +248,7 @@ class LLMEngine:
|
||||
"Initializing an LLM engine (v%s) with config: "
|
||||
"model=%r, speculative_config=%r, tokenizer=%r, "
|
||||
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
|
||||
"override_neuron_config=%s, "
|
||||
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
|
||||
"override_neuron_config=%s, tokenizer_revision=%s, "
|
||||
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
|
||||
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
|
||||
"pipeline_parallel_size=%d, "
|
||||
@ -271,8 +270,6 @@ class LLMEngine:
|
||||
model_config.tokenizer_mode,
|
||||
model_config.revision,
|
||||
model_config.override_neuron_config,
|
||||
model_config.rope_scaling,
|
||||
model_config.rope_theta,
|
||||
model_config.tokenizer_revision,
|
||||
model_config.trust_remote_code,
|
||||
model_config.dtype,
|
||||
|
||||
@ -98,7 +98,10 @@ class LLM:
|
||||
to eager mode. Additionally for encoder-decoder models, if the
|
||||
sequence length of the encoder input is larger than this, we fall
|
||||
back to the eager mode.
|
||||
disable_custom_all_reduce: See ParallelConfig
|
||||
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
|
||||
disable_async_output_proc: Disable async output processing.
|
||||
This may result in lower performance.
|
||||
hf_overrides: Arguments to be forwarded to the HuggingFace config.
|
||||
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
|
||||
:ref:`engine_args`)
|
||||
|
||||
@ -153,6 +156,7 @@ class LLM:
|
||||
max_seq_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
disable_async_output_proc: bool = False,
|
||||
hf_overrides: Optional[dict] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
# After positional args are removed, move this right below `model`
|
||||
task: TaskOption = "auto",
|
||||
@ -194,6 +198,7 @@ class LLM:
|
||||
max_seq_len_to_capture=max_seq_len_to_capture,
|
||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
hf_overrides=hf_overrides,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
pooling_type=pooling_type,
|
||||
pooling_norm=pooling_norm,
|
||||
|
||||
@ -146,9 +146,8 @@ def get_config(
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
rope_scaling: Optional[dict] = None,
|
||||
rope_theta: Optional[float] = None,
|
||||
config_format: ConfigFormat = ConfigFormat.AUTO,
|
||||
token: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> PretrainedConfig:
|
||||
# Separate model folder from file path for GGUF models
|
||||
@ -159,39 +158,43 @@ def get_config(
|
||||
model = Path(model).parent
|
||||
|
||||
if config_format == ConfigFormat.AUTO:
|
||||
if is_gguf or file_or_path_exists(model,
|
||||
HF_CONFIG_NAME,
|
||||
revision=revision,
|
||||
token=kwargs.get("token")):
|
||||
if is_gguf or file_or_path_exists(
|
||||
model, HF_CONFIG_NAME, revision=revision, token=token):
|
||||
config_format = ConfigFormat.HF
|
||||
elif file_or_path_exists(model,
|
||||
MISTRAL_CONFIG_NAME,
|
||||
revision=revision,
|
||||
token=kwargs.get("token")):
|
||||
token=token):
|
||||
config_format = ConfigFormat.MISTRAL
|
||||
else:
|
||||
# If we're in offline mode and found no valid config format, then
|
||||
# raise an offline mode error to indicate to the user that they
|
||||
# don't have files cached and may need to go online.
|
||||
# This is conveniently triggered by calling file_exists().
|
||||
file_exists(model,
|
||||
HF_CONFIG_NAME,
|
||||
revision=revision,
|
||||
token=kwargs.get("token"))
|
||||
file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
|
||||
|
||||
raise ValueError(f"No supported config format found in {model}")
|
||||
|
||||
if config_format == ConfigFormat.HF:
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model, revision=revision, code_revision=code_revision, **kwargs)
|
||||
model,
|
||||
revision=revision,
|
||||
code_revision=code_revision,
|
||||
token=token,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Use custom model class if it's in our registry
|
||||
model_type = config_dict.get("model_type")
|
||||
if model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[model_type]
|
||||
config = config_class.from_pretrained(model,
|
||||
revision=revision,
|
||||
code_revision=code_revision)
|
||||
config = config_class.from_pretrained(
|
||||
model,
|
||||
revision=revision,
|
||||
code_revision=code_revision,
|
||||
token=token,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(
|
||||
@ -199,6 +202,7 @@ def get_config(
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
code_revision=code_revision,
|
||||
token=token,
|
||||
**kwargs,
|
||||
)
|
||||
except ValueError as e:
|
||||
@ -216,7 +220,7 @@ def get_config(
|
||||
raise e
|
||||
|
||||
elif config_format == ConfigFormat.MISTRAL:
|
||||
config = load_params_config(model, revision, token=kwargs.get("token"))
|
||||
config = load_params_config(model, revision, token=token, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported config format: {config_format}")
|
||||
|
||||
@ -228,19 +232,6 @@ def get_config(
|
||||
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
|
||||
config.update({"architectures": [model_type]})
|
||||
|
||||
for key, value in [
|
||||
("rope_scaling", rope_scaling),
|
||||
("rope_theta", rope_theta),
|
||||
]:
|
||||
if value is not None:
|
||||
logger.info(
|
||||
"Updating %s from %r to %r",
|
||||
key,
|
||||
getattr(config, key, None),
|
||||
value,
|
||||
)
|
||||
config.update({key: value})
|
||||
|
||||
patch_rope_scaling(config)
|
||||
|
||||
return config
|
||||
@ -462,13 +453,15 @@ def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
|
||||
|
||||
def load_params_config(model: Union[str, Path],
|
||||
revision: Optional[str],
|
||||
token: Optional[str] = None) -> PretrainedConfig:
|
||||
token: Optional[str] = None,
|
||||
**kwargs) -> PretrainedConfig:
|
||||
# This function loads a params.json config which
|
||||
# should be used when loading models in mistral format
|
||||
|
||||
config_file_name = "params.json"
|
||||
|
||||
config_dict = get_hf_file_to_dict(config_file_name, model, revision, token)
|
||||
assert isinstance(config_dict, dict)
|
||||
|
||||
config_mapping = {
|
||||
"dim": "hidden_size",
|
||||
@ -512,6 +505,8 @@ def load_params_config(model: Union[str, Path],
|
||||
config_dict["architectures"] = ["PixtralForConditionalGeneration"]
|
||||
config_dict["model_type"] = "pixtral"
|
||||
|
||||
config_dict.update(kwargs)
|
||||
|
||||
config = recurse_elems(config_dict)
|
||||
return config
|
||||
|
||||
|
||||
@ -74,8 +74,7 @@ class LLMEngine:
|
||||
"Initializing an LLM engine (v%s) with config: "
|
||||
"model=%r, speculative_config=%r, tokenizer=%r, "
|
||||
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
|
||||
"override_neuron_config=%s, "
|
||||
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
|
||||
"override_neuron_config=%s, tokenizer_revision=%s, "
|
||||
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
|
||||
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
|
||||
"pipeline_parallel_size=%d, "
|
||||
@ -94,8 +93,6 @@ class LLMEngine:
|
||||
model_config.tokenizer_mode,
|
||||
model_config.revision,
|
||||
model_config.override_neuron_config,
|
||||
model_config.rope_scaling,
|
||||
model_config.rope_theta,
|
||||
model_config.tokenizer_revision,
|
||||
model_config.trust_remote_code,
|
||||
model_config.dtype,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user