[BugFix] Fix RoPE error in Llama 3.1 (#6693)

This commit is contained in:
Woosuk Kwon 2024-07-23 09:46:05 -07:00 committed by GitHub
parent 461089a21a
commit a112a84aad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 30 deletions

View File

@@ -154,15 +154,6 @@ class ModelConfig:
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
-        if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
-                and getattr(self.hf_config, "rope_scaling", None) is None):
-            # Note(simon): this is a special case for a model that doesn't
-            # supply rope_scaling. We should remove this once the model is
-            # updated.
-            self.hf_config.update({"rope_scaling": {
-                "type": "extended",
-            }})
         if (not self.disable_sliding_window
                 and self.hf_text_config.model_type == "gemma2"
                 and self.hf_text_config.sliding_window is not None):
@@ -1492,24 +1483,32 @@ def _get_and_verify_max_len(
         derived_max_model_len = default_max_len
     rope_scaling = getattr(hf_config, "rope_scaling", None)
-    # The correct one should be "longrope", kept "su" here
-    # to be backward compatible
-    if rope_scaling is not None and rope_scaling["type"] not in {
-            "su", "longrope", "extended"
-    }:
-        if disable_sliding_window:
-            # TODO(robertgshaw): Find a model that supports rope_scaling
-            # with sliding window to see if this case should be allowed.
-            raise NotImplementedError(
-                "Disabling sliding window is not supported for models "
-                "with rope_scaling. Please raise an issue so we can "
-                "investigate.")
-        assert "factor" in rope_scaling
-        scaling_factor = rope_scaling["factor"]
-        if rope_scaling["type"] == "yarn":
-            derived_max_model_len = rope_scaling[
-                "original_max_position_embeddings"]
-        derived_max_model_len *= scaling_factor
+    if rope_scaling is not None:
+        if "type" in rope_scaling:
+            rope_type = rope_scaling["type"]
+        elif "rope_type" in rope_scaling:
+            rope_type = rope_scaling["rope_type"]
+        else:
+            raise ValueError(
+                "rope_scaling must have a 'type' or 'rope_type' key.")
+
+        # The correct one should be "longrope", kept "su" here
+        # to be backward compatible
+        if rope_type not in ("su", "longrope", "llama3"):
+            if disable_sliding_window:
+                # TODO(robertgshaw): Find a model that supports rope_scaling
+                # with sliding window to see if this case should be allowed.
+                raise NotImplementedError(
+                    "Disabling sliding window is not supported for models "
+                    "with rope_scaling. Please raise an issue so we can "
+                    "investigate.")
+            assert "factor" in rope_scaling
+            scaling_factor = rope_scaling["factor"]
+            if rope_type == "yarn":
+                derived_max_model_len = rope_scaling[
+                    "original_max_position_embeddings"]
+            derived_max_model_len *= scaling_factor
     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.

View File

@@ -794,12 +794,13 @@ def get_rope(
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
     else:
-        scaling_type = rope_scaling["type"]
+        scaling_type = rope_scaling[
+            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
-        if scaling_type not in {"su", "longrope", "extended"}:
+        if scaling_type not in {"su", "longrope", "llama3"}:
             scaling_factor = rope_scaling["factor"]
-        if scaling_type == "extended":
+        if scaling_type == "llama3":
             rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
                                                  max_position, base,
                                                  is_neox_style, dtype)