mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-11 22:41:29 +08:00
Add truncate arg to yarn to match openai implementation of gpt-oss (#28244)
Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
66483a9d00
commit
6eb745d9bd
@ -197,6 +197,7 @@ def get_rope(
|
|||||||
"beta_fast",
|
"beta_fast",
|
||||||
"beta_slow",
|
"beta_slow",
|
||||||
"apply_yarn_scaling",
|
"apply_yarn_scaling",
|
||||||
|
"truncate",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if "mrope_section" in rope_parameters:
|
if "mrope_section" in rope_parameters:
|
||||||
|
|||||||
@ -117,13 +117,13 @@ def yarn_find_correction_range(
|
|||||||
dim: int,
|
dim: int,
|
||||||
base: float = 10000,
|
base: float = 10000,
|
||||||
max_position_embeddings: int = 2048,
|
max_position_embeddings: int = 2048,
|
||||||
) -> tuple[int, int]:
|
truncate: bool = True,
|
||||||
low = math.floor(
|
) -> tuple[float | int, float | int]:
|
||||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||||
)
|
high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||||
high = math.ceil(
|
if truncate:
|
||||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
low = math.floor(low)
|
||||||
)
|
high = math.ceil(high)
|
||||||
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -28,12 +28,14 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
beta_fast: int = 32,
|
beta_fast: int = 32,
|
||||||
beta_slow: int = 1,
|
beta_slow: int = 1,
|
||||||
apply_yarn_scaling: bool = True,
|
apply_yarn_scaling: bool = True,
|
||||||
|
truncate: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.scaling_factor = scaling_factor
|
self.scaling_factor = scaling_factor
|
||||||
self.extrapolation_factor = extrapolation_factor
|
self.extrapolation_factor = extrapolation_factor
|
||||||
self.attn_factor = attn_factor
|
self.attn_factor = attn_factor
|
||||||
self.beta_fast = beta_fast
|
self.beta_fast = beta_fast
|
||||||
self.beta_slow = beta_slow
|
self.beta_slow = beta_slow
|
||||||
|
self.truncate = truncate
|
||||||
# Get n-d magnitude scaling corrected for interpolation
|
# Get n-d magnitude scaling corrected for interpolation
|
||||||
self.mscale = (
|
self.mscale = (
|
||||||
float(yarn_get_mscale(self.scaling_factor) * attn_factor)
|
float(yarn_get_mscale(self.scaling_factor) * attn_factor)
|
||||||
@ -57,6 +59,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
self.rotary_dim,
|
self.rotary_dim,
|
||||||
self.base,
|
self.base,
|
||||||
self.max_position_embeddings,
|
self.max_position_embeddings,
|
||||||
|
self.truncate,
|
||||||
)
|
)
|
||||||
# Get n-d rotational scaling corrected for extrapolation
|
# Get n-d rotational scaling corrected for extrapolation
|
||||||
inv_freq_mask = (
|
inv_freq_mask = (
|
||||||
|
|||||||
@ -78,6 +78,7 @@ class OAIAttention(nn.Module):
|
|||||||
],
|
],
|
||||||
"beta_fast": config.rope_parameters["beta_fast"],
|
"beta_fast": config.rope_parameters["beta_fast"],
|
||||||
"beta_slow": config.rope_parameters["beta_slow"],
|
"beta_slow": config.rope_parameters["beta_slow"],
|
||||||
|
"truncate": config.rope_parameters.get("truncate", True),
|
||||||
},
|
},
|
||||||
is_neox_style=True,
|
is_neox_style=True,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user