Fix RoPE related failures in Transformers nightly tests (#29333)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor, 2025-11-25 16:23:45 +00:00 (committed by GitHub)
parent a1f2676879
commit 0353d2e162
5 changed files with 37 additions and 33 deletions


@@ -233,7 +233,7 @@ class BaiChuanDecoderLayer(nn.Module):
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             position_embedding=position_embedding,
-            rope_parameters=config.rope_parameters,
+            rope_parameters=getattr(config, "rope_parameters", None),
             max_position_embeddings=max_position_embeddings,
             cache_config=cache_config,
             quant_config=quant_config,


@@ -100,7 +100,7 @@ class GPTJAttention(nn.Module):
             self.head_size,
             rotary_dim=config.rotary_dim,
             max_position=max_position_embeddings,
-            rope_parameters=config.rope_parameters,
+            rope_parameters=getattr(config, "rope_parameters", None),
             is_neox_style=False,
         )
         self.attn = Attention(


@@ -239,7 +239,7 @@ class Grok1DecoderLayer(nn.Module):
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
-            rope_parameters=config.rope_parameters,
+            rope_parameters=getattr(config, "rope_parameters", None),
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",


@@ -262,7 +262,7 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
-            rope_parameters=config.rope_parameters,
+            rope_parameters=getattr(config, "rope_parameters", None),
             is_neox_style=is_neox_style,
             partial_rotary_factor=self.partial_rotary_factor,
         )
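
All four hunks above make the same one-line change at the attention call sites: read rope_parameters through getattr with a None default instead of accessing the attribute directly, so configs that never gained the field (for example, custom-code models still written against Transformers v4) no longer raise AttributeError during model construction. A minimal sketch of the difference, using types.SimpleNamespace as a stand-in for a real HF config (hypothetical example, not vLLM code):

from types import SimpleNamespace

# Stand-in for a config written for Transformers v4 that never defines
# `rope_parameters`.
legacy_config = SimpleNamespace(rope_theta=10000.0)

# Direct attribute access, as in the old call sites, fails for such configs.
try:
    _ = legacy_config.rope_parameters
except AttributeError:
    print("direct access raises AttributeError")

# The patched call sites fall back to None, letting the RoPE helpers apply
# their own defaults downstream.
rope_parameters = getattr(legacy_config, "rope_parameters", None)
print(rope_parameters)  # -> None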


@@ -456,51 +456,55 @@ def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> No
 def patch_rope_parameters(config: PretrainedConfig) -> None:
     """Provide backwards compatibility for RoPE."""
-    # Retrieve rope_parameters differently based on Transformers version
+    # Patch rope_parameters differently based on Transformers version
     if Version(version("transformers")) >= Version("5.0.0.dev0"):
-        from transformers.modeling_rope_utils import RopeParameters
-
-        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = getattr(
-            config, "rope_parameters", None
-        )
-    elif hasattr(config, "rope_parameters"):
-        # We are in Transformers v4 and rope_parameters
-        # has already been patched for this config
-        return
+        from transformers.modeling_rope_utils import (
+            rope_config_validation,
+            standardize_rope_params,
+        )
+        # When Transformers v5 is installed, legacy rope_theta may be present
+        # when using custom code models written for Transformers v4
+        if (rope_theta := getattr(config, "rope_theta", None)) is not None:
+            standardize_rope_params(config, rope_theta=rope_theta)
+            rope_config_validation(config)
+            # Delete rope_theta to avoid confusion in downstream code
+            del config.rope_theta
     else:
-        # Convert Transformers v4 rope_theta and rope_scaling into rope_parameters
-        rope_theta: float | None = getattr(config, "rope_theta", None)
-        rope_scaling: dict | None = getattr(config, "rope_scaling", None)
-        rope_parameters = rope_scaling
-        # Move rope_theta into rope_parameters
-        if rope_theta is not None:
-            rope_parameters = rope_parameters or {"rope_type": "default"}
-            rope_parameters["rope_theta"] = rope_theta
-        # Add original_max_position_embeddings if present
-        if rope_parameters and (
-            ompe := getattr(config, "original_max_position_embeddings", None)
-        ):
-            rope_parameters["original_max_position_embeddings"] = ompe
-        # Write back to config
-        config.rope_parameters = rope_parameters
+        # When Transformers v4 is installed, legacy rope_scaling may be present
+        if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
+            config.rope_parameters = rope_scaling
+        # When Transformers v4 is installed, legacy rope_theta may be present
+        if (rope_theta := getattr(config, "rope_theta", None)) is not None:
+            if not hasattr(config, "rope_parameters"):
+                config.rope_parameters = {"rope_type": "default"}
+            config.rope_parameters["rope_theta"] = rope_theta
 
     # No RoPE parameters to patch
-    if rope_parameters is None:
+    if not hasattr(config, "rope_parameters"):
         return
 
+    # Add original_max_position_embeddings if present
+    if ompe := getattr(config, "original_max_position_embeddings", None):
+        config.rope_parameters["original_max_position_embeddings"] = ompe
+
     # Handle nested rope_parameters in interleaved sliding attention models
-    if set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
-        for rope_parameters_layer_type in rope_parameters.values():
+    if set(config.rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
+        for rope_parameters_layer_type in config.rope_parameters.values():
             patch_rope_parameters_dict(rope_parameters_layer_type)
     else:
-        patch_rope_parameters_dict(rope_parameters)
+        patch_rope_parameters_dict(config.rope_parameters)
 
 
 def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
     if "rope_type" in rope_parameters and "type" in rope_parameters:
         rope_type = rope_parameters["rope_type"]
         rope_type_legacy = rope_parameters["type"]
-        if rope_type != rope_type_legacy:
+        if (rope_type_legacy == "su" and rope_type == "longrope") or (
+            rope_type_legacy == "mrope" and rope_type == "default"
+        ):
+            pass  # No action needed
+        elif rope_type != rope_type_legacy:
             raise ValueError(
                 f"Found conflicts between 'rope_type={rope_type}' (modern "
                 f"field) and 'type={rope_type_legacy}' (legacy field). "