diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index ae8a7d93b50e4..152d9401b8e94 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -197,6 +197,7 @@ def get_rope( "beta_fast", "beta_slow", "apply_yarn_scaling", + "truncate", ) } if "mrope_section" in rope_parameters: diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 196533b617959..13f8d15cc0f72 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -117,13 +117,13 @@ def yarn_find_correction_range( dim: int, base: float = 10000, max_position_embeddings: int = 2048, -) -> tuple[int, int]: - low = math.floor( - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) + truncate: bool = True, +) -> tuple[float | int, float | int]: + low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + if truncate: + low = math.floor(low) + high = math.ceil(high) return max(low, 0), min(high, dim - 1) # Clamp values just in case diff --git a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py index ff46ad74b302e..f01ca1e231211 100644 --- a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py @@ -28,12 +28,14 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): beta_fast: int = 32, beta_slow: int = 1, apply_yarn_scaling: bool = True, + truncate: bool = True, ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow + self.truncate = truncate # Get n-d magnitude scaling corrected for interpolation self.mscale = ( float(yarn_get_mscale(self.scaling_factor) * attn_factor) @@ -57,6 +59,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): self.rotary_dim, self.base, self.max_position_embeddings, + self.truncate, ) # Get n-d rotational scaling corrected for extrapolation inv_freq_mask = ( diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 25048330f7974..8835acb8ec65c 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -78,6 +78,7 @@ class OAIAttention(nn.Module): ], "beta_fast": config.rope_parameters["beta_fast"], "beta_slow": config.rope_parameters["beta_slow"], + "truncate": config.rope_parameters.get("truncate", True), }, is_neox_style=True, )