From e10c84e06af7264d5c0b3e7ec5604ada2eee7094 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Thu, 4 Dec 2025 18:42:49 +0000
Subject: [PATCH] Access `partial_rotary_factor` from `rope_parameters`
 (#29966)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 tests/kernels/core/test_mrope.py              |  8 ++------
 .../layers/rotary_embedding/__init__.py       |  5 ++++-
 vllm/model_executor/models/apertus.py         |  5 +----
 vllm/model_executor/models/bailing_moe.py     |  3 ---
 vllm/model_executor/models/bamba.py           |  4 +---
 vllm/model_executor/models/config.py          |  5 -----
 vllm/model_executor/models/falcon_h1.py       |  4 +---
 vllm/model_executor/models/glm.py             |  3 ++-
 vllm/model_executor/models/glm4.py            |  3 +--
 vllm/model_executor/models/glm4_moe.py        |  3 +--
 vllm/model_executor/models/gpt_neox.py        |  6 ++----
 vllm/model_executor/models/llama.py           |  3 ---
 vllm/model_executor/models/nemotron.py        |  2 --
 vllm/model_executor/models/nemotron_nas.py    |  1 -
 vllm/model_executor/models/persimmon.py       |  2 --
 vllm/model_executor/models/phi.py             |  5 +----
 vllm/model_executor/models/qwen3_next.py      |  1 -
 vllm/model_executor/models/stablelm.py        |  4 ----
 vllm/transformers_utils/config.py             | 10 +++++++++-
 vllm/transformers_utils/configs/nemotron.py   | 20 ++++++++++++-------
 vllm/transformers_utils/configs/qwen3_next.py |  8 +++++---
 21 files changed, 43 insertions(+), 62 deletions(-)

diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py
index 43b242ab2d586..4e1559a049bf9 100644
--- a/tests/kernels/core/test_mrope.py
+++ b/tests/kernels/core/test_mrope.py
@@ -113,12 +113,10 @@ def test_mrope(

     is_neox_style = True
     max_position = config.max_position_embeddings
-    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
-    rotary_dim = int(head_dim * partial_rotary_factor)

     mrope_helper_class = get_rope(
         head_size=head_dim,
-        rotary_dim=rotary_dim,
+        rotary_dim=head_dim,
         max_position=max_position,
         is_neox_style=is_neox_style,
         rope_parameters=config.rope_parameters,
@@ -184,12 +182,10 @@ def test_mrope_torch_compile_tracing(
     )
     is_neox_style = True
     max_position = config.max_position_embeddings
-    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
-    rotary_dim = int(head_dim * partial_rotary_factor)

     mrope_helper_class = get_rope(
         head_size=head_dim,
-        rotary_dim=rotary_dim,
+        rotary_dim=head_dim,
         max_position=max_position,
         is_neox_style=is_neox_style,
         rope_parameters=config.rope_parameters,
diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py
index aa6ece30026d3..4dff984f92be6 100644
--- a/vllm/model_executor/layers/rotary_embedding/__init__.py
+++ b/vllm/model_executor/layers/rotary_embedding/__init__.py
@@ -30,7 +30,6 @@ def get_rope(
     is_neox_style: bool = True,
     rope_parameters: dict[str, Any] | None = None,
     dtype: torch.dtype | None = None,
-    partial_rotary_factor: float = 1.0,
     dual_chunk_attention_config: dict[str, Any] | None = None,
 ) -> RotaryEmbedding:
     if dtype is None:
@@ -55,6 +54,10 @@ def get_rope(
     else:
         dual_chunk_attention_args = None

+    partial_rotary_factor = 1.0
+    if rope_parameters is not None:
+        partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
+
     if partial_rotary_factor < 1.0:
         rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (
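These first two hunks are the crux of the change: `get_rope` drops its `partial_rotary_factor` keyword and instead reads the factor out of `rope_parameters`, scaling `rotary_dim` internally. A minimal sketch of the new calling convention (the dictionary keys follow the diff; the concrete sizes are illustrative, not taken from vLLM's tests):

    from vllm.model_executor.layers.rotary_embedding import get_rope

    rope = get_rope(
        head_size=128,
        rotary_dim=128,  # pass the full head size; no caller-side pre-scaling
        max_position=4096,
        is_neox_style=True,
        rope_parameters={
            "rope_type": "default",
            "rope_theta": 10000.0,
            "partial_rotary_factor": 0.5,  # get_rope scales rotary_dim to 64
        },
    )

Call sites that previously computed `rotary_dim = int(head_dim * partial_rotary_factor)` themselves can now pass the full head size, which is exactly what the per-model hunks below do.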
diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py
index 4a69787af55e2..2a8be29d8d306 100644
--- a/vllm/model_executor/models/apertus.py
+++ b/vllm/model_executor/models/apertus.py
@@ -148,8 +148,6 @@ class ApertusAttention(nn.Module):
         if head_dim is None:
             head_dim = self.hidden_size // self.total_num_heads
         self.head_dim = head_dim
-        # Phi models introduced a partial_rotary_factor parameter in the config
-        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -228,11 +226,10 @@ class ApertusAttention(nn.Module):

         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=int(self.partial_rotary_factor * self.head_dim),
+            rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
             rope_parameters=config.rope_parameters,
             is_neox_style=is_neox_style,
-            partial_rotary_factor=self.partial_rotary_factor,
         )

diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py
index f7a5d4e7889e5..0143e140af265 100644
--- a/vllm/model_executor/models/bailing_moe.py
+++ b/vllm/model_executor/models/bailing_moe.py
@@ -127,8 +127,6 @@ class BailingAttention(nn.Module):
             prefix=f"{prefix}.dense",
         )

-        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
-
         self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)

         self.rotary_emb = get_rope(
@@ -137,7 +135,6 @@ class BailingAttention(nn.Module):
             max_position=config.max_position_embeddings,
             rope_parameters=config.rope_parameters,
             is_neox_style=True,
-            partial_rotary_factor=self.partial_rotary_factor,
         )

         self.attn = Attention(
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index 1d6493b18c343..00d742f84ef79 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -178,9 +178,7 @@ class BambaAttentionDecoderLayer(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.max_position_embeddings = max_position_embeddings

-        if hasattr(config, "partial_rotary_factor"):
-            rotary_dim = int(self.head_dim * config.partial_rotary_factor)
-        elif hasattr(config, "attn_rotary_emb"):
+        if hasattr(config, "attn_rotary_emb"):
             rotary_dim = config.attn_rotary_emb  # for backward compatibility
         else:
             rotary_dim = self.head_dim  # default
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index d7e802ba1aca0..4bca36aa4b7de 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -8,7 +8,6 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
-from vllm.transformers_utils.config import set_default_rope_theta
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
@@ -78,8 +77,6 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
         if not model_config.enforce_eager:
             max_position = round_up(max_position, 8)

-        set_default_rope_theta(config, default_theta=config.rotary_emb_base)
-
         config.rotary_kwargs = {
             "head_size": head_dim,
             "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
@@ -119,8 +116,6 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
         rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
         max_trained_positions = getattr(config, "max_trained_positions", 2048)

-        set_default_rope_theta(config, default_theta=config.rotary_emb_base)
-
         config.rotary_kwargs = {
             "head_size": head_dim,
             "rotary_dim": rotary_emb_dim,
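With the `partial_rotary_factor` branch deleted, the hybrid-attention models (bamba above, falcon_h1 below) pick `rotary_dim` from plain config attributes and leave any partial rotation to `get_rope`. A sketch of the remaining fallback order, using a stand-in config object for illustration:

    from types import SimpleNamespace

    def resolve_rotary_dim(config, head_dim: int) -> int:
        # Mirrors the post-patch branch in bamba.py / falcon_h1.py.
        if hasattr(config, "attn_rotary_emb"):
            return config.attn_rotary_emb  # for backward compatibility
        return head_dim  # default; partial rotation is applied inside get_rope

    assert resolve_rotary_dim(SimpleNamespace(attn_rotary_emb=32), head_dim=64) == 32
    assert resolve_rotary_dim(SimpleNamespace(), head_dim=64) == 64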
diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py
index 83ceb9303cfb5..a1c1263f8d724 100644
--- a/vllm/model_executor/models/falcon_h1.py
+++ b/vllm/model_executor/models/falcon_h1.py
@@ -242,9 +242,7 @@ class FalconH1AttentionDecoderLayer(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.max_position_embeddings = max_position_embeddings

-        if hasattr(config, "partial_rotary_factor"):
-            rotary_dim = self.head_dim * config.partial_rotary_factor
-        elif hasattr(config, "attn_rotary_emb"):
+        if hasattr(config, "attn_rotary_emb"):
             rotary_dim = config.attn_rotary_emb  # for backward compatibility
         else:
             rotary_dim = self.head_dim  # default
diff --git a/vllm/model_executor/models/glm.py b/vllm/model_executor/models/glm.py
index a6991f8e43fef..26d7c29aae6e2 100644
--- a/vllm/model_executor/models/glm.py
+++ b/vllm/model_executor/models/glm.py
@@ -10,7 +10,8 @@ from .utils import PPMissingLayer

 class GlmForCausalLM(LlamaForCausalLM):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        vllm_config.model_config.hf_config.partial_rotary_factor = 0.5
+        hf_config = vllm_config.model_config.hf_config
+        hf_config.rope_parameters["partial_rotary_factor"] = 0.5
         super().__init__(vllm_config=vllm_config, prefix=prefix)
         # Hack Llama model to fit HF format GLM implementation
         # Attention difference between GLM and Llama:
diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py
index 002cdb721e1db..9adfa942b99fa 100644
--- a/vllm/model_executor/models/glm4.py
+++ b/vllm/model_executor/models/glm4.py
@@ -78,7 +78,7 @@ class Glm4Attention(nn.Module):
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
-        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+        config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.rotary_dim = self.head_dim
@@ -106,7 +106,6 @@ class Glm4Attention(nn.Module):
             rotary_dim=self.rotary_dim,
             max_position=max_position,
             rope_parameters=config.rope_parameters,
-            partial_rotary_factor=partial_rotary_factor,
             is_neox_style=False,
         )
         self.attn = Attention(
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index c99f824e1bd4d..8cae5ee425e4d 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -282,13 +282,12 @@ class Glm4MoeAttention(nn.Module):
             prefix=f"{prefix}.o_proj",
         )

-        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+        config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             rope_parameters=config.rope_parameters,
-            partial_rotary_factor=partial_rotary_factor,
         )
         self.attn = Attention(
             self.num_heads,
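The GLM family rotates half of each head by default, so these hunks seed `rope_parameters` with `setdefault` instead of overwriting whatever the checkpoint config specifies. The semantics in isolation:

    # Config without an explicit factor: the GLM default of 0.5 is written in.
    rope_parameters = {"rope_type": "default", "rope_theta": 10000.0}
    rope_parameters.setdefault("partial_rotary_factor", 0.5)
    assert rope_parameters["partial_rotary_factor"] == 0.5

    # Config that already carries a factor: the existing value is preserved.
    rope_parameters = {"rope_type": "default", "partial_rotary_factor": 0.25}
    rope_parameters.setdefault("partial_rotary_factor", 0.5)
    assert rope_parameters["partial_rotary_factor"] == 0.25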
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index b9959682cbcef..212d605c17285 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -89,16 +89,14 @@ class GPTNeoXAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.dense",
         )
-        scaling = self.head_size**-0.5
-        rotary_dim = int(self.head_size * config.rotary_pct)
-        assert rotary_dim % 2 == 0
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.rotary_emb = get_rope(
             self.head_size,
-            rotary_dim=rotary_dim,
+            rotary_dim=self.head_size,
             max_position=max_position_embeddings,
             rope_parameters=config.rope_parameters,
         )
+        scaling = self.head_size**-0.5
         self.attn = Attention(
             self.num_heads,
             self.head_size,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 8f5a967cd422a..167dfbca248ce 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -149,8 +149,6 @@ class LlamaAttention(nn.Module):
         if head_dim is None:
             head_dim = self.hidden_size // self.total_num_heads
         self.head_dim = head_dim
-        # Phi models introduced a partial_rotary_factor parameter in the config
-        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -265,7 +263,6 @@ class LlamaAttention(nn.Module):
             max_position=self.max_position_embeddings,
             rope_parameters=getattr(config, "rope_parameters", None),
             is_neox_style=is_neox_style,
-            partial_rotary_factor=self.partial_rotary_factor,
         )

diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index ffba6c9dfe739..bf83ee5e42a15 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -178,7 +178,6 @@ class NemotronAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.partial_rotary_factor = config.partial_rotary_factor
         self.max_position_embeddings = max_position_embeddings

         self.qkv_proj = QKVParallelLinear(
@@ -203,7 +202,6 @@ class NemotronAttention(nn.Module):
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             rope_parameters=config.rope_parameters,
-            partial_rotary_factor=self.partial_rotary_factor,
         )
         self.attn = Attention(
             self.num_heads,
diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py
index 9d968dee87114..734fbc60709fa 100644
--- a/vllm/model_executor/models/nemotron_nas.py
+++ b/vllm/model_executor/models/nemotron_nas.py
@@ -122,7 +122,6 @@ class DeciLMAttention(LlamaAttention):
             max_position=self.max_position_embeddings,
             rope_parameters=config.rope_parameters,
             is_neox_style=is_neox_style,
-            partial_rotary_factor=self.partial_rotary_factor,
         )

diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py
index 795cd25f16753..8f26c68720a5c 100644
--- a/vllm/model_executor/models/persimmon.py
+++ b/vllm/model_executor/models/persimmon.py
@@ -106,7 +106,6 @@ class PersimmonAttention(nn.Module):
         self.num_heads = self.total_num_heads // tensor_parallel_world_size
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
-        self.partial_rotary_factor = config.partial_rotary_factor
         self.is_causal = True

         assert (self.head_dim * self.total_num_heads) == self.hidden_size
@@ -138,7 +137,6 @@ class PersimmonAttention(nn.Module):
             rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
             rope_parameters=config.rope_parameters,
-            partial_rotary_factor=self.partial_rotary_factor,
         )
         self.scaling = self.head_dim**-0.5
         self.attn = Attention(
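GPT-NeoX spells the factor `rotary_pct` in its config; the call site above now passes the full `head_size` and relies on `patch_rope_parameters` (see the `vllm/transformers_utils/config.py` hunks further down) to copy `rotary_pct` into `rope_parameters`. A back-of-the-envelope check of the resulting rotary dimension, assuming `rotary_pct=0.25` and a 64-dim head:

    head_size = 64
    rope_parameters = {
        "rope_type": "default",
        "rope_theta": 10000.0,
        "partial_rotary_factor": 0.25,  # copied over from the legacy rotary_pct
    }

    # Reproduces the scaling get_rope now performs internally.
    rotary_dim = head_size
    factor = rope_parameters.get("partial_rotary_factor", 1.0)
    if factor < 1.0:
        rotary_dim = int(rotary_dim * factor)
    assert rotary_dim == 16  # matches the old int(head_size * config.rotary_pct)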
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index 70016d9ed246c..253fbbc41330c 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -109,10 +109,7 @@ class PhiAttention(nn.Module):
         )

         scaling = self.head_size**-0.5
-        rotary_dim = int(
-            config.partial_rotary_factor
-            * (config.hidden_size // config.num_attention_heads)
-        )
+        rotary_dim = config.hidden_size // config.num_attention_heads
         assert rotary_dim % 2 == 0

         max_position_embeddings = getattr(config, "max_position_embeddings", 2048)
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 661a182151d74..dd64e3983e381 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -750,7 +750,6 @@ class Qwen3NextAttention(nn.Module):
             rotary_dim=self.head_dim,
             max_position=config.max_position_embeddings,
             rope_parameters=config.rope_parameters,
-            partial_rotary_factor=config.partial_rotary_factor,
             dual_chunk_attention_config=self.dual_chunk_attention_config,
         )

diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 65092584edced..e879599ad3ead 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -119,9 +119,6 @@ class StablelmAttention(nn.Module):
         self.num_key_value_heads = max(1, self.total_num_key_value_heads // tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
-        self.partial_rotary_factor = getattr(
-            config, "rope_pct", getattr(config, "partial_rotary_factor", 1)
-        )
         self.scaling = self.head_dim**-0.5
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_key_value_heads * self.head_dim
@@ -154,7 +151,6 @@ class StablelmAttention(nn.Module):
             rotary_dim=self.head_dim,
             max_position=self.config.max_position_embeddings,
             rope_parameters=self.config.rope_parameters,
-            partial_rotary_factor=self.partial_rotary_factor,
         )
         self.attn = Attention(
             self.num_heads,
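Phi and StableLM illustrate the general call-site migration: stop pre-scaling `rotary_dim` from a config attribute and hand `get_rope` the full head size. A pure-Python sketch of the equivalence (the function names are illustrative, not vLLM APIs):

    def rotary_dim_before(head_size: int, partial_rotary_factor: float) -> int:
        # Old convention: every call site scaled rotary_dim itself.
        return int(partial_rotary_factor * head_size)

    def rotary_dim_after(head_size: int, rope_parameters: dict) -> int:
        # New convention: get_rope reads the factor from rope_parameters.
        factor = rope_parameters.get("partial_rotary_factor", 1.0)
        return int(head_size * factor) if factor < 1.0 else head_size

    assert rotary_dim_before(80, 0.4) == 32
    assert rotary_dim_after(80, {"partial_rotary_factor": 0.4}) == 32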
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 1075bc2449b3c..f926b523afdfa 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -25,6 +25,7 @@ from transformers.models.auto.tokenization_auto import get_tokenizer_config
 from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME

 from vllm import envs
+from vllm.config.utils import getattr_iter
 from vllm.logger import init_logger
 from vllm.transformers_utils.utils import parse_safetensors_file_metadata

@@ -304,7 +305,8 @@ def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> None:

 def patch_rope_parameters(config: PretrainedConfig) -> None:
     """Provide backwards compatibility for RoPE."""
-    rope_theta = getattr(config, "rope_theta", None)
+    rope_theta_names = ("rope_theta", "rotary_emb_base")
+    rope_theta = getattr_iter(config, rope_theta_names, None)
     if Version(version("transformers")) < Version("5.0.0.dev0"):
         # Transformers v4 installed, legacy config fields may be present
         if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
@@ -313,6 +315,12 @@
             if not hasattr(config, "rope_parameters"):
                 config.rope_parameters = {"rope_type": "default"}
             config.rope_parameters["rope_theta"] = rope_theta
+        partial_rotary_factor_names = ("partial_rotary_factor", "rotary_pct")
+        partial_rotary_factor = getattr_iter(config, partial_rotary_factor_names, None)
+        if partial_rotary_factor is not None:
+            if not hasattr(config, "rope_parameters"):
+                config.rope_parameters = {"rope_type": "default"}
+            config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
     elif rope_theta is not None or hasattr(config, "rope_parameters"):
         # Transformers v5 installed
         config.standardize_rope_params()
diff --git a/vllm/transformers_utils/configs/nemotron.py b/vllm/transformers_utils/configs/nemotron.py
index d112c71d7d20b..62f52703029b7 100644
--- a/vllm/transformers_utils/configs/nemotron.py
+++ b/vllm/transformers_utils/configs/nemotron.py
@@ -89,9 +89,14 @@ class NemotronConfig(PretrainedConfig):
         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_parameters (`dict`, *optional*):
-            The parameters of the RoPE embeddings.
-        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
-            Percentage of the query and keys which will have rotary embedding.
+            The parameters of the RoPE embeddings. Expected contents:
+            `rope_theta` (`float`): The base period of the RoPE embeddings.
+            `rope_type` (`str`):
+                The sub-variant of RoPE to use. Can be one of ['default', 'linear',
+                'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the
+                original RoPE implementation.
+            `partial_rotary_factor` (`float`, *optional*, defaults to 0.5):
+                Percentage of the query and keys which will have rotary embedding.
         attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output projection
             layers during self-attention.
@@ -133,7 +138,6 @@ class NemotronConfig(PretrainedConfig):
         eos_token_id=3,
         tie_word_embeddings=False,
         rope_parameters=None,
-        partial_rotary_factor=0.5,
         attention_bias=False,
         attention_dropout=0.0,
         mlp_bias=False,
@@ -165,14 +169,16 @@ class NemotronConfig(PretrainedConfig):
         rope_theta = kwargs.pop("rope_theta", 10000.0)
         if "rope_theta" not in rope_parameters:
             rope_parameters["rope_theta"] = rope_theta
-        self.rope_parameters = rope_parameters
         # for backward compatibility
         partial_rotary_factor = (
             kwargs.get("rope_percent")
             or kwargs.get("rope_percentage")
-            or partial_rotary_factor
+            or kwargs.get("partial_rotary_factor")
+            or 0.5
         )
-        self.partial_rotary_factor = partial_rotary_factor
+        if "partial_rotary_factor" not in rope_parameters:
+            rope_parameters["partial_rotary_factor"] = partial_rotary_factor
+        self.rope_parameters = rope_parameters
         self._rope_parameters_validation()
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
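`NemotronConfig` still honors the legacy spellings `rope_percent` and `rope_percentage` plus a bare `partial_rotary_factor` kwarg, resolving them in that order before the 0.5 default, and the winner now lands in `rope_parameters`. The precedence chain in isolation (the kwargs value is made up):

    kwargs = {"rope_percentage": 0.4}
    partial_rotary_factor = (
        kwargs.get("rope_percent")
        or kwargs.get("rope_percentage")
        or kwargs.get("partial_rotary_factor")
        or 0.5
    )
    rope_parameters = {"rope_type": "default", "rope_theta": 10000.0}
    if "partial_rotary_factor" not in rope_parameters:
        rope_parameters["partial_rotary_factor"] = partial_rotary_factor
    assert rope_parameters["partial_rotary_factor"] == 0.4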
diff --git a/vllm/transformers_utils/configs/qwen3_next.py b/vllm/transformers_utils/configs/qwen3_next.py
index fd36b49245f56..8230a18343c5e 100644
--- a/vllm/transformers_utils/configs/qwen3_next.py
+++ b/vllm/transformers_utils/configs/qwen3_next.py
@@ -103,8 +103,8 @@ class Qwen3NextConfig(PretrainedConfig):
                 Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
             `high_freq_factor` (`float`, *optional*):
                 Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
-        partial_rotary_factor (`float`, *optional*, defaults to 0.25):
-            Percentage of the query and keys which will have rotary embedding.
+            `partial_rotary_factor` (`float`, *optional*, defaults to 0.25):
+                Percentage of the query and keys which will have rotary embedding.
         attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output projection
             layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -198,7 +198,6 @@ class Qwen3NextConfig(PretrainedConfig):
         use_cache=True,
         tie_word_embeddings=False,
         rope_parameters=None,
-        partial_rotary_factor=0.25,
         attention_bias=False,
         attention_dropout=0.0,
         head_dim=256,
@@ -239,6 +238,9 @@ class Qwen3NextConfig(PretrainedConfig):
         rope_theta = kwargs.pop("rope_theta", 10000.0)
         if "rope_theta" not in rope_parameters:
             rope_parameters["rope_theta"] = rope_theta
+        partial_rotary_factor = kwargs.pop("partial_rotary_factor", 0.25)
+        if "partial_rotary_factor" not in rope_parameters:
+            rope_parameters["partial_rotary_factor"] = partial_rotary_factor
         self.rope_parameters = rope_parameters
         self.partial_rotary_factor = partial_rotary_factor
         self.attention_bias = attention_bias
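`Qwen3NextConfig` applies the same recipe with a 0.25 default: pop the kwarg, then write it into `rope_parameters` only when the dict does not already carry a value. A condensed, self-contained sketch of that behavior (the helper name is hypothetical):

    def resolve_rope_parameters(rope_parameters=None, **kwargs):
        if rope_parameters is None:
            rope_parameters = {"rope_type": "default"}
        rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", 10000.0))
        partial = kwargs.pop("partial_rotary_factor", 0.25)
        rope_parameters.setdefault("partial_rotary_factor", partial)
        return rope_parameters

    assert resolve_rope_parameters()["partial_rotary_factor"] == 0.25
    assert resolve_rope_parameters(partial_rotary_factor=0.5)["partial_rotary_factor"] == 0.5
    assert resolve_rope_parameters({"partial_rotary_factor": 0.125})["partial_rotary_factor"] == 0.125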