[BugFix] Fix RoPE error in Llama 3.1 (#6693)

Author: Woosuk Kwon (committed via GitHub)
Date: 2024-07-23 09:46:05 -07:00
Commit: a112a84aad (parent: 461089a21a)
2 changed files with 30 additions and 30 deletions
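
Context for the fix: Llama 3.1 checkpoints ship a rope_scaling block in config.json keyed by the newer "rope_type" field and using the "llama3" scaling type, while vLLM had been special-casing these models by injecting a placeholder {"type": "extended"} entry. This commit removes the placeholder, accepts either key when reading the config, and renames the scaling type from "extended" to "llama3". For reference, a sketch of the rope_scaling block written as a Python dict (field values quoted from the Llama 3.1 8B release; treat them as illustrative):

rope_scaling = {
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3",  # newer key; older configs use "type"
}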

File: vllm/config.py

@@ -154,15 +154,6 @@ class ModelConfig:
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
-        if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
-                and getattr(self.hf_config, "rope_scaling", None) is None):
-            # Note(simon): this is a special case for a model that doesn't
-            # supply rope_scaling. We should remove this once the model is
-            # updated.
-            self.hf_config.update({"rope_scaling": {
-                "type": "extended",
-            }})
         if (not self.disable_sliding_window
                 and self.hf_text_config.model_type == "gemma2"
                 and self.hf_text_config.sliding_window is not None):
@@ -1492,11 +1483,18 @@ def _get_and_verify_max_len(
         derived_max_model_len = default_max_len
     rope_scaling = getattr(hf_config, "rope_scaling", None)
-    # The correct one should be "longrope", kept "su" here
-    # to be backward compatible
-    if rope_scaling is not None and rope_scaling["type"] not in {
-            "su", "longrope", "extended"
-    }:
+    if rope_scaling is not None:
+        if "type" in rope_scaling:
+            rope_type = rope_scaling["type"]
+        elif "rope_type" in rope_scaling:
+            rope_type = rope_scaling["rope_type"]
+        else:
+            raise ValueError(
+                "rope_scaling must have a 'type' or 'rope_type' key.")
+        # The correct one should be "longrope", kept "su" here
+        # to be backward compatible
+        if rope_type not in ("su", "longrope", "llama3"):
             if disable_sliding_window:
                 # TODO(robertgshaw): Find a model that supports rope_scaling
                 # with sliding window to see if this case should be allowed.

@@ -1504,9 +1502,10 @@ def _get_and_verify_max_len(
                 "Disabling sliding window is not supported for models "
                 "with rope_scaling. Please raise an issue so we can "
                 "investigate.")
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
-        if rope_scaling["type"] == "yarn":
+        if rope_type == "yarn":
             derived_max_model_len = rope_scaling[
                 "original_max_position_embeddings"]
             derived_max_model_len *= scaling_factor
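
A minimal sketch of the max-length derivation above, condensed into a hypothetical helper (derive_max_len is not a real vLLM function): with the fix, "llama3" joins "su" and "longrope" as scaling types whose factor does not rescale the derived length, while "yarn" and the other factor-based types still do.

from typing import Optional


def derive_max_len(default_max_len: int,
                   rope_scaling: Optional[dict]) -> float:
    # Condensed from _get_and_verify_max_len's rope_scaling handling.
    derived = float(default_max_len)
    if rope_scaling is not None:
        # Accept both the legacy "type" key and the newer "rope_type" key.
        if "type" in rope_scaling:
            rope_type = rope_scaling["type"]
        elif "rope_type" in rope_scaling:
            rope_type = rope_scaling["rope_type"]
        else:
            raise ValueError(
                "rope_scaling must have a 'type' or 'rope_type' key.")
        if rope_type not in ("su", "longrope", "llama3"):
            scaling_factor = rope_scaling["factor"]
            if rope_type == "yarn":
                derived = rope_scaling["original_max_position_embeddings"]
            derived *= scaling_factor
    return derived


# Llama 3.1 style config: the factor is not applied to the derived length.
assert derive_max_len(131072, {"rope_type": "llama3", "factor": 8.0}) == 131072
# YaRN style config: original length times the scaling factor.
assert derive_max_len(4096, {"type": "yarn", "factor": 4.0,
                             "original_max_position_embeddings": 4096}) == 16384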

File: vllm/model_executor/layers/rotary_embedding.py

@@ -794,12 +794,13 @@ def get_rope(
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
     else:
-        scaling_type = rope_scaling["type"]
+        scaling_type = rope_scaling[
+            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
-        if scaling_type not in {"su", "longrope", "extended"}:
+        if scaling_type not in {"su", "longrope", "llama3"}:
             scaling_factor = rope_scaling["factor"]
-        if scaling_type == "extended":
+        if scaling_type == "llama3":
             rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
                                                  max_position, base,
                                                  is_neox_style, dtype)
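
As a usage sketch (assuming get_rope's keyword arguments at this commit; the values are illustrative, not taken from a real checkpoint), a Llama 3.1 style rope_scaling dict now resolves its scaling type through "rope_type" and dispatches to ExtendedRotaryEmbedding instead of failing the old rope_scaling["type"] lookup:

from vllm.model_executor.layers.rotary_embedding import get_rope

rotary_emb = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=131072,
    base=500000,
    is_neox_style=True,
    rope_scaling={"rope_type": "llama3", "factor": 8.0},
)
# scaling_type resolves to "llama3", so the factor-driven branches are
# skipped and an ExtendedRotaryEmbedding instance is returned.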