[BugFix] Fix RoPE error in Llama 3.1 (#6693)

This commit is contained in:
Woosuk Kwon 2024-07-23 09:46:05 -07:00 committed by GitHub
parent 461089a21a
commit a112a84aad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 30 deletions

View File

@ -154,15 +154,6 @@ class ModelConfig:
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
and getattr(self.hf_config, "rope_scaling", None) is None):
# Note(simon): this is a special case for a model that doesn't
# supply rope_scaling. We should remove this once the model is
# updated.
self.hf_config.update({"rope_scaling": {
"type": "extended",
}})
if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None):
@ -1492,24 +1483,32 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
# The correct one should be "longrope", kept "su" here
# to be backward compatible
if rope_scaling is not None and rope_scaling["type"] not in {
"su", "longrope", "extended"
}:
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can "
"investigate.")
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor
if rope_scaling is not None:
if "type" in rope_scaling:
rope_type = rope_scaling["type"]
elif "rope_type" in rope_scaling:
rope_type = rope_scaling["rope_type"]
else:
raise ValueError(
"rope_scaling must have a 'type' or 'rope_type' key.")
# The correct one should be "longrope", kept "su" here
# to be backward compatible
if rope_type not in ("su", "longrope", "llama3"):
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can "
"investigate.")
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_type == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor
# If the user specified a max length, make sure it is smaller than the
# derived length from the HF model config.

View File

@ -794,12 +794,13 @@ def get_rope(
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
else:
scaling_type = rope_scaling["type"]
scaling_type = rope_scaling[
"type"] if "type" in rope_scaling else rope_scaling["rope_type"]
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type not in {"su", "longrope", "extended"}:
if scaling_type not in {"su", "longrope", "llama3"}:
scaling_factor = rope_scaling["factor"]
if scaling_type == "extended":
if scaling_type == "llama3":
rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype)