mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-15 04:26:41 +08:00
[Bugfix] Use HF config fields as fallback when loading Mistral config (#29239)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
730bd35378
commit
d1cf8214e5
@ -754,6 +754,7 @@ steps:
|
|||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor/models/
|
- vllm/model_executor/models/
|
||||||
|
- vllm/transformers_utils/
|
||||||
- tests/models/test_initialization.py
|
- tests/models/test_initialization.py
|
||||||
commands:
|
commands:
|
||||||
# Only when vLLM model source is modified - test initialization of a large
|
# Only when vLLM model source is modified - test initialization of a large
|
||||||
|
|||||||
@ -691,6 +691,7 @@ steps:
|
|||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor/models/
|
- vllm/model_executor/models/
|
||||||
|
- vllm/transformers_utils/
|
||||||
- tests/models/test_initialization.py
|
- tests/models/test_initialization.py
|
||||||
commands:
|
commands:
|
||||||
# Only when vLLM model source is modified - test initialization of a large
|
# Only when vLLM model source is modified - test initialization of a large
|
||||||
|
|||||||
@ -204,7 +204,19 @@ class MistralConfigParser(ConfigParserBase):
|
|||||||
|
|
||||||
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
||||||
|
|
||||||
config = adapt_config_dict(config_dict)
|
# Get missing fields from HF config if available
|
||||||
|
try:
|
||||||
|
hf_config_dict, _ = PretrainedConfig.get_config_dict(
|
||||||
|
model,
|
||||||
|
revision=revision,
|
||||||
|
code_revision=code_revision,
|
||||||
|
token=_get_hf_token(),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except OSError: # Not found
|
||||||
|
hf_config_dict = {}
|
||||||
|
|
||||||
|
config = adapt_config_dict(config_dict, defaults=hf_config_dict)
|
||||||
|
|
||||||
# Mistral configs may define sliding_window as list[int]. Convert it
|
# Mistral configs may define sliding_window as list[int]. Convert it
|
||||||
# to int and add the layer_types list[str] to make it HF compatible
|
# to int and add the layer_types list[str] to make it HF compatible
|
||||||
|
|||||||
@ -9,14 +9,18 @@ from vllm.logger import init_logger
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig:
|
def adapt_config_dict(
|
||||||
config_dict.update(kwargs)
|
config_dict: dict[str, Any],
|
||||||
|
defaults: dict[str, Any],
|
||||||
|
) -> PretrainedConfig:
|
||||||
config_dict = _remap_general_mistral_args(config_dict)
|
config_dict = _remap_general_mistral_args(config_dict)
|
||||||
|
|
||||||
if bool(config_dict.get("quantization")):
|
if bool(config_dict.get("quantization")):
|
||||||
config_dict = _remap_mistral_quantization_args(config_dict)
|
config_dict = _remap_mistral_quantization_args(config_dict)
|
||||||
|
|
||||||
if bool(config_dict.get("moe")):
|
if config_dict.get("model_type") == "mamba":
|
||||||
|
config_dict["architectures"] = ["Mamba2ForCausalLM"]
|
||||||
|
elif bool(config_dict.get("moe")):
|
||||||
config_dict["architectures"] = ["MixtralForCausalLM"]
|
config_dict["architectures"] = ["MixtralForCausalLM"]
|
||||||
else:
|
else:
|
||||||
config_dict["architectures"] = ["MistralForCausalLM"]
|
config_dict["architectures"] = ["MistralForCausalLM"]
|
||||||
@ -52,6 +56,9 @@ def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig
|
|||||||
if is_audio:
|
if is_audio:
|
||||||
config_dict = _remap_mistral_audio_args(config_dict)
|
config_dict = _remap_mistral_audio_args(config_dict)
|
||||||
|
|
||||||
|
for k, v in defaults.items():
|
||||||
|
config_dict.setdefault(k, v)
|
||||||
|
|
||||||
config = PretrainedConfig.from_dict(config_dict)
|
config = PretrainedConfig.from_dict(config_dict)
|
||||||
|
|
||||||
logger.debug("Initialized config %s", config)
|
logger.debug("Initialized config %s", config)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user