Standardise get_rope to use rope_parameters["partial_rotary_factor"], not rotary_dim (#30389)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-12-11 20:45:23 +00:00 committed by GitHub
parent 92fea56fd1
commit cf3eacfe58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
83 changed files with 260 additions and 314 deletions

View File

@@ -99,7 +99,6 @@ def benchmark_mrope(
# the parameters to compute the q k v size based on tp_size # the parameters to compute the q k v size based on tp_size
mrope_helper_class = get_rope( mrope_helper_class = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position, max_position=max_position,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,

View File

@@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
def benchmark(batch_size, seq_len, num_heads, provider): def benchmark(batch_size, seq_len, num_heads, provider):
dtype = torch.bfloat16 dtype = torch.bfloat16
max_position = 8192 max_position = 8192
base = 10000 rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
rope = rope.to(dtype=dtype, device=device) rope = rope.to(dtype=dtype, device=device)
cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device) cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)

View File

@@ -128,14 +128,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
class TestRotaryEmbedding(torch.nn.Module): class TestRotaryEmbedding(torch.nn.Module):
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000): def __init__(self, head_dim=64, max_position=2048, base=10000):
super().__init__() super().__init__()
self.head_dim = head_dim self.head_dim = head_dim
self.rotary_dim = rotary_dim or head_dim
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position, max_position=max_position,
rope_parameters={"rope_type": "default", "rope_theta": base}, rope_parameters={"rope_type": "default", "rope_theta": base},
) )
@@ -170,7 +168,6 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters={"rope_type": "default", "rope_theta": base}, rope_parameters={"rope_type": "default", "rope_theta": base},
) )

View File

@@ -116,7 +116,6 @@ def test_mrope(
mrope_helper_class = get_rope( mrope_helper_class = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position, max_position=max_position,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
@@ -185,7 +184,6 @@ def test_mrope_torch_compile_tracing(
mrope_helper_class = get_rope( mrope_helper_class = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position, max_position=max_position,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,

View File

@@ -83,8 +83,12 @@ def test_rotary_embedding(
torch.set_default_device(device) torch.set_default_device(device)
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta} rope_parameters = {
rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters) "rope_type": "default",
"rope_theta": rope_theta,
"partial_rotary_factor": rotary_dim / head_size,
}
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
rope = rope.to(dtype=dtype, device=torch.get_default_device()) rope = rope.to(dtype=dtype, device=torch.get_default_device())
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
@@ -150,9 +154,9 @@ def test_rope_module_cache():
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta rope_parameters["rope_theta"] = rope_theta
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
rope = get_rope( rope = get_rope(
head_size, head_size,
rotary_dim,
max_position, max_position,
is_neox_style, is_neox_style,
rope_parameters, rope_parameters,
@@ -177,9 +181,9 @@ def test_rope_module_cache():
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta rope_parameters["rope_theta"] = rope_theta
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
rope = get_rope( rope = get_rope(
head_size, head_size,
rotary_dim,
max_position, max_position,
is_neox_style, is_neox_style,
rope_parameters, rope_parameters,

View File

@@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field:
) )
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any: def getattr_iter(
object: object, names: Iterable[str], default: Any, warn: bool = False
) -> Any:
""" """
A helper function that retrieves an attribute from an object which may A helper function that retrieves an attribute from an object which may
have multiple possible names. This is useful when fetching attributes from have multiple possible names. This is useful when fetching attributes from
arbitrary `transformers.PretrainedConfig` instances. arbitrary `transformers.PretrainedConfig` instances.
In the case where the first name in `names` is the preferred name, and
any other names are deprecated aliases, setting `warn=True` will log a
warning when a deprecated name is used.
""" """
for name in names: for i, name in enumerate(names):
if hasattr(object, name): if hasattr(object, name):
if warn and i > 0:
logger.warning_once(
"%s contains a deprecated attribute name '%s'. "
"Please use the preferred attribute name '%s' instead.",
type(object).__name__,
name,
names[0],
)
return getattr(object, name) return getattr(object, name)
return default return default

View File

@@ -25,7 +25,6 @@ _ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
def get_rope( def get_rope(
head_size: int, head_size: int,
rotary_dim: int,
max_position: int, max_position: int,
is_neox_style: bool = True, is_neox_style: bool = True,
rope_parameters: dict[str, Any] | None = None, rope_parameters: dict[str, Any] | None = None,
@@ -54,12 +53,15 @@ def get_rope(
else: else:
dual_chunk_attention_args = None dual_chunk_attention_args = None
partial_rotary_factor = 1.0 rope_parameters = rope_parameters or {}
if rope_parameters is not None: base = rope_parameters.get("rope_theta", 10000)
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) scaling_type = rope_parameters.get("rope_type", "default")
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0:
raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0")
rotary_dim = int(head_size * partial_rotary_factor)
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = ( key = (
head_size, head_size,
rotary_dim, rotary_dim,
@@ -72,7 +74,6 @@ def get_rope(
if key in _ROPE_DICT: if key in _ROPE_DICT:
return _ROPE_DICT[key] return _ROPE_DICT[key]
base = rope_parameters["rope_theta"] if rope_parameters else 10000
if dual_chunk_attention_config is not None: if dual_chunk_attention_config is not None:
extra_kwargs = { extra_kwargs = {
k: v k: v
@@ -88,109 +89,76 @@ def get_rope(
dtype, dtype,
**extra_kwargs, **extra_kwargs,
) )
elif not rope_parameters: elif scaling_type == "default":
rotary_emb = RotaryEmbedding( if "mrope_section" in rope_parameters:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
)
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "llama3":
scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"]
high_freq_factor = rope_parameters["high_freq_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype head_size, rotary_dim, max_position, base, is_neox_style, dtype
) )
else: elif scaling_type == "linear":
scaling_type = rope_parameters["rope_type"] scaling_factor = rope_parameters["factor"]
rotary_emb = LinearScalingRotaryEmbedding(
if scaling_type == "llama3": head_size,
scaling_factor = rope_parameters["factor"] rotary_dim,
low_freq_factor = rope_parameters["low_freq_factor"] max_position,
high_freq_factor = rope_parameters["high_freq_factor"] base,
original_max_position = rope_parameters["original_max_position_embeddings"] is_neox_style,
rotary_emb = Llama3RotaryEmbedding( scaling_factor,
head_size, dtype,
rotary_dim, )
max_position, elif scaling_type == "ntk":
base, scaling_factor = rope_parameters["factor"]
is_neox_style, mixed_b = rope_parameters.get("mixed_b")
dtype, rotary_emb = NTKScalingRotaryEmbedding(
scaling_factor, head_size,
low_freq_factor, rotary_dim,
high_freq_factor, max_position,
original_max_position, base,
) is_neox_style,
elif scaling_type == "mllama4": scaling_factor,
rotary_emb = Llama4VisionRotaryEmbedding( dtype,
head_size, rotary_dim, max_position, base, is_neox_style, dtype mixed_b,
) )
elif scaling_type == "default": elif scaling_type == "dynamic":
if "mrope_section" in rope_parameters: if "alpha" in rope_parameters:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
)
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "linear":
scaling_factor = rope_parameters["factor"]
rotary_emb = LinearScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "ntk":
scaling_factor = rope_parameters["factor"]
mixed_b = rope_parameters.get("mixed_b")
rotary_emb = NTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
mixed_b,
)
elif scaling_type == "dynamic":
if "alpha" in rope_parameters:
scaling_alpha = rope_parameters["alpha"]
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
)
elif "factor" in rope_parameters:
scaling_factor = rope_parameters["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
else:
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif scaling_type == "xdrope":
scaling_alpha = rope_parameters["alpha"] scaling_alpha = rope_parameters["alpha"]
rotary_emb = XDRotaryEmbedding( rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
@@ -198,67 +166,66 @@ def get_rope(
is_neox_style, is_neox_style,
scaling_alpha, scaling_alpha,
dtype, dtype,
xdrope_section=rope_parameters["xdrope_section"],
) )
elif scaling_type == "yarn": elif "factor" in rope_parameters:
scaling_factor = rope_parameters["factor"] scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"] rotary_emb = DynamicNTKScalingRotaryEmbedding(
extra_kwargs = { head_size,
k: v rotary_dim,
for k, v in rope_parameters.items() max_position,
if k base,
in ( is_neox_style,
"extrapolation_factor", scaling_factor,
"attn_factor", dtype,
"beta_fast", )
"beta_slow", else:
"apply_yarn_scaling", raise ValueError(
"truncate", "Dynamic rope scaling must contain either 'alpha' or 'factor' field"
) )
} elif scaling_type == "xdrope":
if "mrope_section" in rope_parameters: scaling_alpha = rope_parameters["alpha"]
extra_kwargs.pop("apply_yarn_scaling", None) rotary_emb = XDRotaryEmbedding(
rotary_emb = MRotaryEmbedding( head_size,
head_size, rotary_dim,
rotary_dim, max_position,
original_max_position, base,
base, is_neox_style,
is_neox_style, scaling_alpha,
dtype, dtype,
mrope_section=rope_parameters["mrope_section"], xdrope_section=rope_parameters["xdrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False), )
scaling_factor=scaling_factor, elif scaling_type == "yarn":
**extra_kwargs, scaling_factor = rope_parameters["factor"]
) original_max_position = rope_parameters["original_max_position_embeddings"]
else: extra_kwargs = {
rotary_emb = YaRNScalingRotaryEmbedding( k: v
head_size, for k, v in rope_parameters.items()
rotary_dim, if k
original_max_position, in (
base, "extrapolation_factor",
is_neox_style, "attn_factor",
scaling_factor, "beta_fast",
dtype, "beta_slow",
**extra_kwargs, "apply_yarn_scaling",
) "truncate",
elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]: )
scaling_factor = rope_parameters["factor"] }
original_max_position = rope_parameters["original_max_position_embeddings"] if "mrope_section" in rope_parameters:
# assert max_position == original_max_position * scaling_factor extra_kwargs.pop("apply_yarn_scaling", None)
extra_kwargs = { rotary_emb = MRotaryEmbedding(
k: v head_size,
for k, v in rope_parameters.items() rotary_dim,
if k original_max_position,
in ( base,
"extrapolation_factor", is_neox_style,
"attn_factor", dtype,
"beta_fast", mrope_section=rope_parameters["mrope_section"],
"beta_slow", mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
"mscale", scaling_factor=scaling_factor,
"mscale_all_dim", **extra_kwargs,
) )
} else:
rotary_emb = DeepseekScalingRotaryEmbedding( rotary_emb = YaRNScalingRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
original_max_position, original_max_position,
@@ -268,28 +235,55 @@ def get_rope(
dtype, dtype,
**extra_kwargs, **extra_kwargs,
) )
elif scaling_type == "longrope": elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
short_factor = rope_parameters["short_factor"] scaling_factor = rope_parameters["factor"]
long_factor = rope_parameters["long_factor"] original_max_position = rope_parameters["original_max_position_embeddings"]
original_max_position = rope_parameters["original_max_position_embeddings"] # assert max_position == original_max_position * scaling_factor
extra_kwargs = { extra_kwargs = {
k: v k: v
for k, v in rope_parameters.items() for k, v in rope_parameters.items()
if k in ("short_mscale", "long_mscale") if k
} in (
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( "extrapolation_factor",
head_size, "attn_factor",
rotary_dim, "beta_fast",
max_position, "beta_slow",
original_max_position, "mscale",
base, "mscale_all_dim",
is_neox_style,
dtype,
short_factor,
long_factor,
**extra_kwargs,
) )
else: }
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") rotary_emb = DeepseekScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type == "longrope":
short_factor = rope_parameters["short_factor"]
long_factor = rope_parameters["long_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size,
rotary_dim,
max_position,
original_max_position,
base,
is_neox_style,
dtype,
short_factor,
long_factor,
**extra_kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb _ROPE_DICT[key] = rotary_emb
return rotary_emb return rotary_emb

View File

@@ -241,7 +241,6 @@ class AfmoeAttention(nn.Module):
if self.is_local_attention: if self.is_local_attention:
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config["rope_parameters"], rope_parameters=config["rope_parameters"],
is_neox_style=True, is_neox_style=True,

View File

@@ -226,7 +226,6 @@ class ApertusAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,

View File

@@ -314,7 +314,6 @@ class ArcticAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@@ -189,7 +189,6 @@ class BaiChuanAttention(nn.Module):
else: else:
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )

View File

@@ -127,11 +127,11 @@ class BailingAttention(nn.Module):
prefix=f"{prefix}.dense", prefix=f"{prefix}.dense",
) )
self.rotary_dim = getattr(config, "rotary_dim", self.head_dim) rotary_dim = getattr(config, "rotary_dim", self.head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.rotary_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@@ -178,14 +178,11 @@ class BambaAttentionDecoderLayer(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
if hasattr(config, "attn_rotary_emb"): rotary_dim = getattr(config, "attn_rotary_emb", self.head_dim)
rotary_dim = config.attn_rotary_emb # for backward compatibility config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
else:
rotary_dim = self.head_dim # default
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@@ -314,7 +314,6 @@ class ChameleonAttention(nn.Module):
self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim)) self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )

View File

@@ -99,13 +99,16 @@ class GLMAttention(nn.Module):
# https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 # https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0) rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192) max_positions = getattr(config, "seq_length", 8192)
rope_parameters = {"rope_type": "default", "rope_theta": 10000 * rope_ratio} rope_parameters = {
"rope_type": "default",
"rope_theta": 10000 * rope_ratio,
"partial_rotary_factor": 0.5,
}
# NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False, # NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
# which is equivalent to is_neox_style=True # which is equivalent to is_neox_style=True
is_neox_style = not config.original_rope is_neox_style = not config.original_rope
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim // 2,
max_position=max_positions, max_position=max_positions,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,

View File

@@ -175,7 +175,6 @@ class CohereAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=False, is_neox_style=False,

View File

@@ -42,9 +42,10 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
config.hidden_act = "geglu" config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = { config.rotary_kwargs = {
"head_size": head_dim, "head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings, "max_position": config.max_position_embeddings,
"rope_parameters": config.rope_parameters, "rope_parameters": config.rope_parameters,
} }
@@ -77,9 +78,11 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
if not model_config.enforce_eager: if not model_config.enforce_eager:
max_position = round_up(max_position, 8) max_position = round_up(max_position, 8)
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = { config.rotary_kwargs = {
"head_size": head_dim, "head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": max_position, "max_position": max_position,
"rope_parameters": config.rope_parameters, "rope_parameters": config.rope_parameters,
} }
@@ -113,12 +116,10 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
config.num_hidden_layers = config.n_layer config.num_hidden_layers = config.n_layer
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
max_trained_positions = getattr(config, "max_trained_positions", 2048) max_trained_positions = getattr(config, "max_trained_positions", 2048)
config.rotary_kwargs = { config.rotary_kwargs = {
"head_size": head_dim, "head_size": head_dim,
"rotary_dim": rotary_emb_dim,
"max_position": max_trained_positions, "max_position": max_trained_positions,
"rope_parameters": config.rope_parameters, "rope_parameters": config.rope_parameters,
} }
@@ -240,9 +241,10 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
config.hidden_act = "geglu" config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = { config.rotary_kwargs = {
"head_size": head_dim, "head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings, "max_position": config.max_position_embeddings,
"rope_parameters": config.rope_parameters, "rope_parameters": config.rope_parameters,
} }

View File

@@ -222,7 +222,6 @@ class DbrxAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position, max_position=self.max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@@ -156,7 +156,6 @@ class DeepseekAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
@@ -499,7 +498,6 @@ class DeepseekV2Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
qk_rope_head_dim, qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=False, is_neox_style=False,
@@ -1018,7 +1016,6 @@ class DeepseekV2MLAAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
qk_rope_head_dim, qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=False, is_neox_style=False,
@@ -1038,7 +1035,6 @@ class DeepseekV2MLAAttention(nn.Module):
if self.is_v32: if self.is_v32:
self.indexer_rope_emb = get_rope( self.indexer_rope_emb = get_rope(
qk_rope_head_dim, qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@@ -250,7 +250,6 @@ class Dots1Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@@ -288,7 +288,6 @@ class Ernie4_5_MoeAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=False, is_neox_style=False,

View File

@@ -167,7 +167,6 @@ class ExaoneAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,

View File

@@ -176,7 +176,6 @@ class Exaone4Attention(nn.Module):
set_default_rope_theta(config, default_theta=1000000) set_default_rope_theta(config, default_theta=1000000)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,

View File

@@ -167,7 +167,6 @@ class FalconAttention(nn.Module):
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@@ -242,14 +242,11 @@ class FalconH1AttentionDecoderLayer(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
if hasattr(config, "attn_rotary_emb"): rotary_dim = getattr(config, "attn_rotary_emb", self.head_dim)
rotary_dim = config.attn_rotary_emb # for backward compatibility config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
else:
rotary_dim = self.head_dim # default
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@@ -174,7 +174,6 @@ class GemmaAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@@ -152,7 +152,6 @@ class Gemma2Attention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@@ -176,7 +176,6 @@ class Gemma3Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@@ -384,7 +384,6 @@ class Gemma3nAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@@ -81,7 +81,6 @@ class Glm4Attention(nn.Module):
config.rope_parameters.setdefault("partial_rotary_factor", 0.5) config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads self.head_dim = head_dim or hidden_size // self.total_num_heads
self.rotary_dim = self.head_dim
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
@@ -103,7 +102,6 @@ class Glm4Attention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position, max_position=max_position,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=False, is_neox_style=False,

View File

@ -678,9 +678,9 @@ class Glm4vVisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope( self.rotary_pos_emb = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192, max_position=8192,
is_neox_style=True, is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [

View File

@ -285,7 +285,6 @@ class Glm4MoeAttention(nn.Module):
config.rope_parameters.setdefault("partial_rotary_factor", 0.5) config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@ -95,12 +95,13 @@ class GPTJAttention(nn.Module):
scaling = self.head_size**-0.5 scaling = self.head_size**-0.5
assert getattr(config, "rotary", True) assert getattr(config, "rotary", True)
assert config.rotary_dim % 2 == 0 assert config.rotary_dim % 2 == 0
rope_parameters = getattr(config, "rope_parameters", {})
rope_parameters["partial_rotary_factor"] = config.rotary_dim / self.head_size
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_size, self.head_size,
rotary_dim=config.rotary_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=getattr(config, "rope_parameters", None), rope_parameters=rope_parameters,
is_neox_style=False, is_neox_style=False,
) )
self.attn = Attention( self.attn = Attention(

View File

@ -92,7 +92,6 @@ class GPTNeoXAttention(nn.Module):
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_size, self.head_size,
rotary_dim=self.head_size,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@ -67,7 +67,6 @@ class OAIAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
dtype=torch.float32, dtype=torch.float32,
rope_parameters={ rope_parameters={

View File

@ -160,7 +160,6 @@ class GraniteAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@ -190,7 +190,6 @@ class GraniteMoeAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@ -271,7 +271,6 @@ class GraniteMoeHybridAttention(nn.Module):
if config.position_embedding_type == "rope": if config.position_embedding_type == "rope":
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@ -181,7 +181,6 @@ class Grok1Attention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@ -199,7 +199,6 @@ class HunYuanAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
@ -305,7 +304,6 @@ class HunYuanCrossAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@ -140,7 +140,6 @@ class InternLM2Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )

View File

@ -143,7 +143,6 @@ class Lfm2Attention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@ -236,7 +236,6 @@ class Lfm2MoeAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@ -259,7 +259,6 @@ class LlamaAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=getattr(config, "rope_parameters", None), rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=is_neox_style, is_neox_style=is_neox_style,

View File

@ -243,7 +243,6 @@ class Llama4Attention(nn.Module):
self.rotary_emb = ( self.rotary_emb = (
get_rope( get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,

View File

@ -277,7 +277,6 @@ class MiniCPMAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )

View File

@ -120,7 +120,6 @@ class MiniCPM3Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.qk_rope_head_dim, self.qk_rope_head_dim,
rotary_dim=self.qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@ -199,9 +199,13 @@ class MiniMaxM2Attention(nn.Module):
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
) )
if (
rope_parameters is not None
and "partial_rotary_factor" not in rope_parameters
):
rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )

View File

@ -187,7 +187,6 @@ class MiniMaxText01Attention(nn.Module):
num_heads: int, num_heads: int,
head_dim: int, head_dim: int,
num_kv_heads: int, num_kv_heads: int,
rotary_dim: int,
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_parameters: dict | None = None, rope_parameters: dict | None = None,
sliding_window: int | None = None, sliding_window: int | None = None,
@ -245,7 +244,6 @@ class MiniMaxText01Attention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,
@ -290,6 +288,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
head_dim = getattr(config, "head_dim", None) head_dim = getattr(config, "head_dim", None)
if head_dim is None: if head_dim is None:
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int): if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
max_position_embeddings = min( max_position_embeddings = min(
config.max_position_embeddings, config.max_model_len config.max_position_embeddings, config.max_model_len
@ -321,9 +321,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
head_dim=head_dim, head_dim=head_dim,
rotary_dim=config.rotary_dim
if hasattr(config, "rotary_dim")
else head_dim,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,

View File

@ -206,7 +206,6 @@ class MixtralAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@ -295,11 +295,11 @@ class Llama4VisionAttention(nn.Module):
rope_parameters = { rope_parameters = {
"rope_type": "mllama4", "rope_type": "mllama4",
"rope_theta": config.rope_parameters["rope_theta"], "rope_theta": config.rope_parameters["rope_theta"],
"partial_rotary_factor": 0.5,
} }
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=config.hidden_size // config.num_attention_heads // 2,
# number of image patches # number of image patches
max_position=(config.image_size // config.patch_size) ** 2, max_position=(config.image_size // config.patch_size) ** 2,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,

View File

@ -105,7 +105,6 @@ class ModernBertAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
dtype=torch.float16, dtype=torch.float16,

View File

@ -433,7 +433,6 @@ class MolmoAttention(nn.Module):
# Rotary embeddings. # Rotary embeddings.
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@ -199,7 +199,6 @@ class NemotronAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@ -118,7 +118,6 @@ class DeciLMAttention(LlamaAttention):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,

View File

@ -102,7 +102,6 @@ class OlmoAttention(nn.Module):
# Rotary embeddings. # Rotary embeddings.
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@ -146,7 +146,6 @@ class Olmo2Attention(nn.Module):
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta} rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )

View File

@ -171,7 +171,6 @@ class OlmoeAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@ -352,7 +352,6 @@ class OpenPanguMLAAttention(nn.Module):
} }
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
qk_rope_head_dim, qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=False, is_neox_style=False,
@ -525,7 +524,6 @@ class OpenPanguEmbeddedAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,

View File

@ -135,7 +135,6 @@ class OrionAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )

View File

@ -166,7 +166,6 @@ class OuroAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config, dual_chunk_attention_config=dual_chunk_attention_config,

View File

@ -134,7 +134,6 @@ class PersimmonAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@ -84,19 +84,18 @@ class PhiAttention(nn.Module):
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads self.head_size = self.hidden_size // config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tensor_model_parallel_world_size == 0 assert config.num_attention_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.num_heads = config.num_attention_heads // tensor_model_parallel_world_size
# pylint: disable=C0103 # pylint: disable=C0103
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
self.hidden_size, self.hidden_size,
self.head_size, self.head_size,
self.total_num_heads, config.num_attention_heads,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
@ -109,13 +108,10 @@ class PhiAttention(nn.Module):
) )
scaling = self.head_size**-0.5 scaling = self.head_size**-0.5
rotary_dim = config.hidden_size // config.num_attention_heads
assert rotary_dim % 2 == 0
max_position_embeddings = getattr(config, "max_position_embeddings", 2048) max_position_embeddings = getattr(config, "max_position_embeddings", 2048)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_size, self.head_size,
rotary_dim=rotary_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@ -352,7 +352,6 @@ class PhiMoEAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@ -574,7 +574,6 @@ class Plamo2AttentionMixer(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@ -179,7 +179,6 @@ class Plamo3AttentionMixer(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )

View File

@ -114,7 +114,6 @@ class QWenAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )

View File

@ -164,7 +164,6 @@ class Qwen2Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config, dual_chunk_attention_config=dual_chunk_attention_config,

View File

@ -624,9 +624,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope( self.rotary_pos_emb = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192, max_position=8192,
is_neox_style=True, is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
) )
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(

View File

@ -244,7 +244,6 @@ class Qwen2MoeAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config, dual_chunk_attention_config=dual_chunk_attention_config,

View File

@ -621,9 +621,9 @@ class Qwen2VisionTransformer(nn.Module):
head_dim = embed_dim // num_heads head_dim = embed_dim // num_heads
self.rotary_pos_emb = get_rope( self.rotary_pos_emb = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192, max_position=8192,
is_neox_style=True, is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(

View File

@ -111,7 +111,6 @@ class Qwen3Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config, dual_chunk_attention_config=dual_chunk_attention_config,

View File

@ -269,7 +269,6 @@ class Qwen3MoeAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config, dual_chunk_attention_config=dual_chunk_attention_config,

View File

@ -747,7 +747,6 @@ class Qwen3NextAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
dual_chunk_attention_config=self.dual_chunk_attention_config, dual_chunk_attention_config=self.dual_chunk_attention_config,

View File

@ -333,9 +333,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope( self.rotary_pos_emb = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192, max_position=8192,
is_neox_style=True, is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(

View File

@ -340,9 +340,9 @@ class Qwen3_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope( self.rotary_pos_emb = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192, max_position=8192,
is_neox_style=True, is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
) )
self.merger = Qwen3_VisionPatchMerger( self.merger = Qwen3_VisionPatchMerger(

View File

@ -161,7 +161,6 @@ class SeedOssAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )

View File

@ -160,7 +160,6 @@ class SolarAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )

View File

@ -148,7 +148,6 @@ class StablelmAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.config.max_position_embeddings, max_position=self.config.max_position_embeddings,
rope_parameters=self.config.rope_parameters, rope_parameters=self.config.rope_parameters,
) )

View File

@ -112,7 +112,6 @@ class Starcoder2Attention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@ -196,7 +196,6 @@ class Step3TextAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embedding, max_position=max_position_embedding,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )

View File

@ -230,7 +230,6 @@ class Zamba2Attention(nn.Module):
if config.use_mem_rope: if config.use_mem_rope:
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.attention_head_dim, head_size=self.attention_head_dim,
rotary_dim=self.attention_head_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,

View File

@ -306,8 +306,13 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
"""Provide backwards compatibility for RoPE.""" """Provide backwards compatibility for RoPE."""
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
rope_theta_names = ("rope_theta", "rotary_emb_base") # Older custom models may use non-standard field names
rope_theta = getattr_iter(config, rope_theta_names, None) # which need patching for both Transformers v4 and v5.
names = ["rope_theta", "rotary_emb_base"]
rope_theta = getattr_iter(config, names, None, warn=True)
names = ["partial_rotary_factor", "rotary_pct", "rotary_emb_fraction"]
partial_rotary_factor = getattr_iter(config, names, None, warn=True)
if Version(version("transformers")) < Version("5.0.0.dev0"): if Version(version("transformers")) < Version("5.0.0.dev0"):
# Transformers v4 installed, legacy config fields may be present # Transformers v4 installed, legacy config fields may be present
if (rope_scaling := getattr(config, "rope_scaling", None)) is not None: if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
@ -316,14 +321,18 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
if not hasattr(config, "rope_parameters"): if not hasattr(config, "rope_parameters"):
config.rope_parameters = {"rope_type": "default"} config.rope_parameters = {"rope_type": "default"}
config.rope_parameters["rope_theta"] = rope_theta config.rope_parameters["rope_theta"] = rope_theta
partial_rotary_factor_names = ("partial_rotary_factor", "rotary_pct")
partial_rotary_factor = getattr_iter(config, partial_rotary_factor_names, None)
if partial_rotary_factor is not None: if partial_rotary_factor is not None:
if not hasattr(config, "rope_parameters"): if not hasattr(config, "rope_parameters"):
config.rope_parameters = {"rope_type": "default"} config.rope_parameters = {"rope_type": "default"}
config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
elif rope_theta is not None or hasattr(config, "rope_parameters"): elif rope_theta is not None or hasattr(config, "rope_parameters"):
# Transformers v5 installed # Transformers v5 installed
# Patch these fields in case they used non-standard names
if rope_theta is not None:
config.rope_theta = rope_theta
if partial_rotary_factor is not None:
config.partial_rotary_factor = partial_rotary_factor
# Standardize and validate RoPE parameters
config.standardize_rope_params() config.standardize_rope_params()
config.validate_rope() config.validate_rope()