Add truncate arg to yarn to match openai implementation of gpt-oss (#28244)

Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Anna Shors 2025-11-20 02:53:50 -08:00 committed by GitHub
parent 66483a9d00
commit 6eb745d9bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 7 deletions

View File

@ -197,6 +197,7 @@ def get_rope(
"beta_fast",
"beta_slow",
"apply_yarn_scaling",
"truncate",
)
}
if "mrope_section" in rope_parameters:

View File

@ -117,13 +117,13 @@ def yarn_find_correction_range(
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048,
) -> tuple[int, int]:
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
truncate: bool = True,
) -> tuple[float | int, float | int]:
low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0), min(high, dim - 1) # Clamp values just in case

View File

@ -28,12 +28,14 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
beta_fast: int = 32,
beta_slow: int = 1,
apply_yarn_scaling: bool = True,
truncate: bool = True,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.truncate = truncate
# Get n-d magnitude scaling corrected for interpolation
self.mscale = (
float(yarn_get_mscale(self.scaling_factor) * attn_factor)
@ -57,6 +59,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
self.rotary_dim,
self.base,
self.max_position_embeddings,
self.truncate,
)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (

View File

@ -78,6 +78,7 @@ class OAIAttention(nn.Module):
],
"beta_fast": config.rope_parameters["beta_fast"],
"beta_slow": config.rope_parameters["beta_slow"],
"truncate": config.rope_parameters.get("truncate", True),
},
is_neox_style=True,
)