[Bugfix] Use HF config fields as fallback when loading Mistral config (#29239)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 730bd35378
commit d1cf8214e5
@@ -754,6 +754,7 @@ steps:
   torch_nightly: true
   source_file_dependencies:
   - vllm/model_executor/models/
+  - vllm/transformers_utils/
   - tests/models/test_initialization.py
   commands:
   # Only when vLLM model source is modified - test initialization of a large
@@ -691,6 +691,7 @@ steps:
   torch_nightly: true
   source_file_dependencies:
   - vllm/model_executor/models/
+  - vllm/transformers_utils/
   - tests/models/test_initialization.py
   commands:
   # Only when vLLM model source is modified - test initialization of a large
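Both pipeline entries above gain vllm/transformers_utils/ as a source-file dependency, as shown in the + lines. As a rough illustration of how such path-prefix gating behaves (a toy sketch with made-up names, not vLLM's actual Buildkite tooling):

# Toy model of source_file_dependencies gating; the real CI logic
# lives in vLLM's Buildkite pipeline generator, not here.
def step_should_run(changed_files: list[str], dependencies: list[str]) -> bool:
    # A step runs if any changed file falls under any dependency prefix.
    return any(f.startswith(dep) for dep in dependencies for f in changed_files)

deps = [
    "vllm/model_executor/models/",
    "vllm/transformers_utils/",
    "tests/models/test_initialization.py",
]
# With the new entry, a change to the Mistral config parser triggers the step.
assert step_should_run(["vllm/transformers_utils/configs/mistral.py"], deps)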
vllm/transformers_utils/config.py
@@ -204,7 +204,19 @@ class MistralConfigParser(ConfigParserBase):
 
         from vllm.transformers_utils.configs.mistral import adapt_config_dict
 
-        config = adapt_config_dict(config_dict)
+        # Get missing fields from HF config if available
+        try:
+            hf_config_dict, _ = PretrainedConfig.get_config_dict(
+                model,
+                revision=revision,
+                code_revision=code_revision,
+                token=_get_hf_token(),
+                **kwargs,
+            )
+        except OSError:  # Not found
+            hf_config_dict = {}
+
+        config = adapt_config_dict(config_dict, defaults=hf_config_dict)
 
         # Mistral configs may define sliding_window as list[int]. Convert it
         # to int and add the layer_types list[str] to make it HF compatible
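This hunk is the core of the fix: the HF config is fetched best-effort and handed to adapt_config_dict as defaults, so a missing or partial HF config.json no longer breaks loading. A minimal standalone sketch of the same pattern, using only the public transformers API; the model name is just an example:

from transformers import PretrainedConfig

# get_config_dict returns (config_dict, unused_kwargs); a missing repo
# or config.json raises OSError, which here maps to "no defaults".
try:
    hf_config_dict, _ = PretrainedConfig.get_config_dict("mistralai/Mistral-7B-v0.1")
except OSError:
    hf_config_dict = {}

# Fields already present in the Mistral-format config are kept; the HF
# dict only fills in whatever is missing.
mistral_dict = {"hidden_size": 4096}
for k, v in hf_config_dict.items():
    mistral_dict.setdefault(k, v)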
vllm/transformers_utils/configs/mistral.py
@@ -9,14 +9,18 @@ from vllm.logger import init_logger
 logger = init_logger(__name__)
 
 
-def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig:
-    config_dict.update(kwargs)
+def adapt_config_dict(
+    config_dict: dict[str, Any],
+    defaults: dict[str, Any],
+) -> PretrainedConfig:
     config_dict = _remap_general_mistral_args(config_dict)
 
     if bool(config_dict.get("quantization")):
         config_dict = _remap_mistral_quantization_args(config_dict)
 
-    if bool(config_dict.get("moe")):
+    if config_dict.get("model_type") == "mamba":
+        config_dict["architectures"] = ["Mamba2ForCausalLM"]
+    elif bool(config_dict.get("moe")):
         config_dict["architectures"] = ["MixtralForCausalLM"]
     else:
         config_dict["architectures"] = ["MistralForCausalLM"]
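A standalone illustration of the architecture dispatch this hunk introduces; the checks mirror the diff, while the function name and sample dicts are invented for the example:

def pick_architecture(config_dict: dict) -> str:
    # mamba takes priority, then MoE, then the dense Mistral default.
    if config_dict.get("model_type") == "mamba":
        return "Mamba2ForCausalLM"
    if bool(config_dict.get("moe")):
        return "MixtralForCausalLM"
    return "MistralForCausalLM"

assert pick_architecture({"model_type": "mamba"}) == "Mamba2ForCausalLM"
assert pick_architecture({"moe": {"num_experts": 8}}) == "MixtralForCausalLM"
assert pick_architecture({}) == "MistralForCausalLM"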
@@ -52,6 +56,9 @@ def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig
     if is_audio:
         config_dict = _remap_mistral_audio_args(config_dict)
 
+    for k, v in defaults.items():
+        config_dict.setdefault(k, v)
+
     config = PretrainedConfig.from_dict(config_dict)
 
     logger.debug("Initialized config %s", config)
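The setdefault loop gives the merge its one-way semantics: a value the Mistral-format config already defines is never overwritten by the HF fallback. A tiny self-contained example with invented values:

config_dict = {"sliding_window": 4096}                # e.g. from params.json
defaults = {"sliding_window": 8192, "head_dim": 128}  # e.g. from config.json
for k, v in defaults.items():
    config_dict.setdefault(k, v)
# The Mistral value wins; only the missing key is filled in.
assert config_dict == {"sliding_window": 4096, "head_dim": 128}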