Add truncate arg to yarn to match openai implementation of gpt-oss (#28244)

Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Anna Shors 2025-11-20 02:53:50 -08:00 committed by GitHub
parent 66483a9d00
commit 6eb745d9bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 7 deletions

View File

@ -197,6 +197,7 @@ def get_rope(
"beta_fast", "beta_fast",
"beta_slow", "beta_slow",
"apply_yarn_scaling", "apply_yarn_scaling",
"truncate",
) )
} }
if "mrope_section" in rope_parameters: if "mrope_section" in rope_parameters:

View File

@ -117,13 +117,13 @@ def yarn_find_correction_range(
dim: int, dim: int,
base: float = 10000, base: float = 10000,
max_position_embeddings: int = 2048, max_position_embeddings: int = 2048,
) -> tuple[int, int]: truncate: bool = True,
low = math.floor( ) -> tuple[float | int, float | int]:
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
) high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
high = math.ceil( if truncate:
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) low = math.floor(low)
) high = math.ceil(high)
return max(low, 0), min(high, dim - 1) # Clamp values just in case return max(low, 0), min(high, dim - 1) # Clamp values just in case

View File

@ -28,12 +28,14 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
beta_fast: int = 32, beta_fast: int = 32,
beta_slow: int = 1, beta_slow: int = 1,
apply_yarn_scaling: bool = True, apply_yarn_scaling: bool = True,
truncate: bool = True,
) -> None: ) -> None:
self.scaling_factor = scaling_factor self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor self.attn_factor = attn_factor
self.beta_fast = beta_fast self.beta_fast = beta_fast
self.beta_slow = beta_slow self.beta_slow = beta_slow
self.truncate = truncate
# Get n-d magnitude scaling corrected for interpolation # Get n-d magnitude scaling corrected for interpolation
self.mscale = ( self.mscale = (
float(yarn_get_mscale(self.scaling_factor) * attn_factor) float(yarn_get_mscale(self.scaling_factor) * attn_factor)
@ -57,6 +59,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
self.rotary_dim, self.rotary_dim,
self.base, self.base,
self.max_position_embeddings, self.max_position_embeddings,
self.truncate,
) )
# Get n-d rotational scaling corrected for extrapolation # Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = ( inv_freq_mask = (

View File

@ -78,6 +78,7 @@ class OAIAttention(nn.Module):
], ],
"beta_fast": config.rope_parameters["beta_fast"], "beta_fast": config.rope_parameters["beta_fast"],
"beta_slow": config.rope_parameters["beta_slow"], "beta_slow": config.rope_parameters["beta_slow"],
"truncate": config.rope_parameters.get("truncate", True),
}, },
is_neox_style=True, is_neox_style=True,
) )