[Bugfix] Use HF config fields as fallback when loading Mistral config (#29239)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-11-23 02:22:48 +08:00; committed by GitHub
parent 730bd35378
commit d1cf8214e5
4 changed files with 25 additions and 4 deletions

View File

@@ -754,6 +754,7 @@ steps:
  torch_nightly: true
  source_file_dependencies:
  - vllm/model_executor/models/
  - vllm/transformers_utils/
  - tests/models/test_initialization.py
  commands:
  # Only when vLLM model source is modified - test initialization of a large

View File

@@ -691,6 +691,7 @@ steps:
  torch_nightly: true
  source_file_dependencies:
  - vllm/model_executor/models/
  - vllm/transformers_utils/
  - tests/models/test_initialization.py
  commands:
  # Only when vLLM model source is modified - test initialization of a large

View File

@@ -204,7 +204,19 @@ class MistralConfigParser(ConfigParserBase):
from vllm.transformers_utils.configs.mistral import adapt_config_dict
config = adapt_config_dict(config_dict)
# Get missing fields from HF config if available
try:
    hf_config_dict, _ = PretrainedConfig.get_config_dict(
        model,
        revision=revision,
        code_revision=code_revision,
        token=_get_hf_token(),
        **kwargs,
    )
except OSError:  # Not found
    hf_config_dict = {}
config = adapt_config_dict(config_dict, defaults=hf_config_dict)
# Mistral configs may define sliding_window as list[int]. Convert it
# to int and add the layer_types list[str] to make it HF compatible
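For context (not part of the diff): a minimal sketch of the fallback pattern used here, assuming transformers is installed and using an illustrative repo name. PretrainedConfig.get_config_dict returns the repo's raw config.json contents; if the repo has no config.json, an OSError is caught and the defaults degrade to an empty dict, so values parsed from the Mistral params.json are never overridden.

from transformers import PretrainedConfig

def load_hf_defaults(model: str, revision: str | None = None) -> dict:
    # Fetch the repo's config.json if it exists; otherwise fall back to {}.
    try:
        hf_config_dict, _ = PretrainedConfig.get_config_dict(model, revision=revision)
    except OSError:  # repo has no config.json
        hf_config_dict = {}
    return hf_config_dict

# Illustrative repo name; any Mistral-format checkpoint behaves the same way.
defaults = load_hf_defaults("mistralai/Mistral-7B-Instruct-v0.3")
mistral_dict = {"hidden_size": 4096}  # stand-in for fields parsed from params.json
for k, v in defaults.items():
    mistral_dict.setdefault(k, v)  # HF values only fill fields params.json omits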

View File

@@ -9,14 +9,18 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig:
    config_dict.update(kwargs)
def adapt_config_dict(
    config_dict: dict[str, Any],
    defaults: dict[str, Any],
) -> PretrainedConfig:
    config_dict = _remap_general_mistral_args(config_dict)
    if bool(config_dict.get("quantization")):
        config_dict = _remap_mistral_quantization_args(config_dict)
    if bool(config_dict.get("moe")):
    if config_dict.get("model_type") == "mamba":
        config_dict["architectures"] = ["Mamba2ForCausalLM"]
    elif bool(config_dict.get("moe")):
        config_dict["architectures"] = ["MixtralForCausalLM"]
    else:
        config_dict["architectures"] = ["MistralForCausalLM"]
@@ -52,6 +56,9 @@ def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig
    if is_audio:
        config_dict = _remap_mistral_audio_args(config_dict)
    for k, v in defaults.items():
        config_dict.setdefault(k, v)
    config = PretrainedConfig.from_dict(config_dict)
    logger.debug("Initialized config %s", config)
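To make the new merge semantics concrete, here is a small standalone sketch (not vLLM code; the keys are hypothetical): setdefault means the HF config can only fill gaps, never override values the Mistral config already defines, and the architecture branch now checks for mamba before moe.

config_dict = {"model_type": "mamba", "hidden_size": 4096}           # from params.json
defaults = {"hidden_size": 9999, "max_position_embeddings": 32768}   # from config.json

for k, v in defaults.items():
    config_dict.setdefault(k, v)

assert config_dict["hidden_size"] == 4096                # Mistral value is kept
assert config_dict["max_position_embeddings"] == 32768   # missing field filled in

# Architecture selection mirrors the new branch order above.
if config_dict.get("model_type") == "mamba":
    config_dict["architectures"] = ["Mamba2ForCausalLM"]
elif config_dict.get("moe"):
    config_dict["architectures"] = ["MixtralForCausalLM"]
else:
    config_dict["architectures"] = ["MistralForCausalLM"]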