mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 00:45:24 +08:00
[BugFix] Fix RoPE error in Llama 3.1 (#6693)
This commit is contained in:
parent
461089a21a
commit
a112a84aad
@@ -154,15 +154,6 @@ class ModelConfig:
|
|||||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||||
|
|
||||||
if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
|
|
||||||
and getattr(self.hf_config, "rope_scaling", None) is None):
|
|
||||||
# Note(simon): this is a special case for a model that doesn't
|
|
||||||
# supply rope_scaling. We should remove this once the model is
|
|
||||||
# updated.
|
|
||||||
self.hf_config.update({"rope_scaling": {
|
|
||||||
"type": "extended",
|
|
||||||
}})
|
|
||||||
|
|
||||||
if (not self.disable_sliding_window
|
if (not self.disable_sliding_window
|
||||||
and self.hf_text_config.model_type == "gemma2"
|
and self.hf_text_config.model_type == "gemma2"
|
||||||
and self.hf_text_config.sliding_window is not None):
|
and self.hf_text_config.sliding_window is not None):
|
||||||
@@ -1492,24 +1483,32 @@ def _get_and_verify_max_len(
|
|||||||
derived_max_model_len = default_max_len
|
derived_max_model_len = default_max_len
|
||||||
|
|
||||||
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
||||||
# The correct one should be "longrope", kept "su" here
|
if rope_scaling is not None:
|
||||||
# to be backward compatible
|
if "type" in rope_scaling:
|
||||||
if rope_scaling is not None and rope_scaling["type"] not in {
|
rope_type = rope_scaling["type"]
|
||||||
"su", "longrope", "extended"
|
elif "rope_type" in rope_scaling:
|
||||||
}:
|
rope_type = rope_scaling["rope_type"]
|
||||||
if disable_sliding_window:
|
else:
|
||||||
# TODO(robertgshaw): Find a model that supports rope_scaling
|
raise ValueError(
|
||||||
# with sliding window to see if this case should be allowed.
|
"rope_scaling must have a 'type' or 'rope_type' key.")
|
||||||
raise NotImplementedError(
|
|
||||||
"Disabling sliding window is not supported for models "
|
# The correct one should be "longrope", kept "su" here
|
||||||
"with rope_scaling. Please raise an issue so we can "
|
# to be backward compatible
|
||||||
"investigate.")
|
if rope_type not in ("su", "longrope", "llama3"):
|
||||||
assert "factor" in rope_scaling
|
if disable_sliding_window:
|
||||||
scaling_factor = rope_scaling["factor"]
|
# TODO(robertgshaw): Find a model that supports rope_scaling
|
||||||
if rope_scaling["type"] == "yarn":
|
# with sliding window to see if this case should be allowed.
|
||||||
derived_max_model_len = rope_scaling[
|
raise NotImplementedError(
|
||||||
"original_max_position_embeddings"]
|
"Disabling sliding window is not supported for models "
|
||||||
derived_max_model_len *= scaling_factor
|
"with rope_scaling. Please raise an issue so we can "
|
||||||
|
"investigate.")
|
||||||
|
|
||||||
|
assert "factor" in rope_scaling
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
|
if rope_type == "yarn":
|
||||||
|
derived_max_model_len = rope_scaling[
|
||||||
|
"original_max_position_embeddings"]
|
||||||
|
derived_max_model_len *= scaling_factor
|
||||||
|
|
||||||
# If the user specified a max length, make sure it is smaller than the
|
# If the user specified a max length, make sure it is smaller than the
|
||||||
# derived length from the HF model config.
|
# derived length from the HF model config.
|
||||||
|
|||||||
@@ -794,12 +794,13 @@ def get_rope(
|
|||||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||||
is_neox_style, dtype)
|
is_neox_style, dtype)
|
||||||
else:
|
else:
|
||||||
scaling_type = rope_scaling["type"]
|
scaling_type = rope_scaling[
|
||||||
|
"type"] if "type" in rope_scaling else rope_scaling["rope_type"]
|
||||||
# The correct one should be "longrope" but keep "su" here
|
# The correct one should be "longrope" but keep "su" here
|
||||||
# for backward compatible
|
# for backward compatible
|
||||||
if scaling_type not in {"su", "longrope", "extended"}:
|
if scaling_type not in {"su", "longrope", "llama3"}:
|
||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
if scaling_type == "extended":
|
if scaling_type == "llama3":
|
||||||
rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
|
rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
|
||||||
max_position, base,
|
max_position, base,
|
||||||
is_neox_style, dtype)
|
is_neox_style, dtype)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user