Only patch original_max_position_embeddings for Transformers v4 (#31214)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-12-23 16:46:32 +00:00, committed by GitHub
parent b94f80ffb8
commit 1339878e13


@@ -330,19 +330,25 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
     rope_theta = getattr_iter(config, names, None, warn=True)
     names = ["partial_rotary_factor", "rotary_pct", "rotary_emb_fraction"]
     partial_rotary_factor = getattr_iter(config, names, None, warn=True)
+    ompe = getattr(config, "original_max_position_embeddings", None)
     if Version(version("transformers")) < Version("5.0.0.dev0"):
         # Transformers v4 installed, legacy config fields may be present
         if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
             config.rope_parameters = rope_scaling
         if (
-            rope_theta is not None or partial_rotary_factor is not None
+            rope_theta is not None
+            or partial_rotary_factor is not None
+            or ompe is not None
         ) and not getattr(config, "rope_parameters", None):
             config.rope_parameters = {"rope_type": "default"}
         # Patch legacy fields into rope_parameters
         if rope_theta is not None:
             config.rope_parameters["rope_theta"] = rope_theta
         if partial_rotary_factor is not None:
             config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
+        if ompe is not None:
+            config.rope_parameters["original_max_position_embeddings"] = ompe
     elif rope_theta is not None or getattr(config, "rope_parameters", None):
         # Transformers v5 installed
         # Patch these fields in case they used non-standard names
@@ -358,10 +364,6 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
     if getattr(config, "rope_parameters", None) is None:
         return
-    # Add original_max_position_embeddings if present
-    if ompe := getattr(config, "original_max_position_embeddings", None):
-        config.rope_parameters["original_max_position_embeddings"] = ompe
     # Handle nested rope_parameters in interleaved sliding attention models
     if is_rope_parameters_nested(config.rope_parameters):
         for rope_parameters_layer_type in config.rope_parameters.values():
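
For context, a minimal sketch (not part of the commit) of the behaviour this change targets: previously original_max_position_embeddings was copied into rope_parameters unconditionally after the version branch, and the diff moves that copy into the Transformers v4 branch only. Only patch_rope_parameters and the field names come from the diff; the concrete config values and the assumption that the function is in scope are illustrative.

# Minimal sketch, assuming Transformers v4 (< 5.0.0.dev0) is installed and
# that patch_rope_parameters from the file shown above is importable here.
from transformers import PretrainedConfig

config = PretrainedConfig()
config.rope_theta = 10000.0                      # legacy v4 field, hypothetical value
config.original_max_position_embeddings = 4096   # legacy v4 field, hypothetical value

patch_rope_parameters(config)

# In the v4 branch, both legacy fields are folded into the unified dict:
#   config.rope_parameters == {
#       "rope_type": "default",
#       "rope_theta": 10000.0,
#       "original_max_position_embeddings": 4096,
#   }
# Under Transformers v5, original_max_position_embeddings is no longer
# patched here, which is the point of this commit.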