[BugFix] Correct max_model_len derivation from config.json for Mistral format (#17937)

Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Co-authored-by: tracelogfb <48808670+tracelogfb@users.noreply.github.com>
Co-authored-by: Stephen Chen <tracelog@meta.com>
This commit is contained in:
汪志鹏 2025-05-16 21:20:13 -07:00 committed by GitHub
parent 60017dc841
commit 4ee4826ede
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -686,9 +686,24 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
config_dict["hidden_act"] = config_dict.get("activation", "silu")
config_dict["tie_word_embeddings"] = config_dict.get(
"tie_embeddings", False)
config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000)
config_dict["max_position_embeddings"] = config_dict.get(
"max_position_embeddings", 128_000)
if config_dict.get("max_position_embeddings") is None:
max_position_embeddings = 128_000
try:
trust_remote_code_val = kwargs.get("trust_remote_code", False)
hf_config = get_config(model=model,
trust_remote_code=trust_remote_code_val,
revision=revision,
config_format=ConfigFormat.HF)
if hf_value := hf_config.get_text_config().max_position_embeddings:
max_position_embeddings = hf_value
except Exception as e:
logger.warning(
"The params.json file is missing 'max_position_embeddings'"
" and could not get a value from the HF config."
" Defaulting to 128000",
exc_info=e)
config_dict["max_position_embeddings"] = max_position_embeddings
if config_dict.get("quantization") is not None:
quantization = config_dict.get("quantization", {})