Add llama 4 scaling support (#28145)
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
parent 5e0c1fe69c
commit 7a8375f8a0

@@ -191,9 +191,16 @@ def get_rope(
             k: v
             for k, v in rope_scaling.items()
             if k
-            in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
+            in (
+                "extrapolation_factor",
+                "attn_factor",
+                "beta_fast",
+                "beta_slow",
+                "apply_yarn_scaling",
+            )
         }
         if "mrope_section" in rope_scaling:
+            extra_kwargs.pop("apply_yarn_scaling", None)
             rotary_emb = MRotaryEmbedding(
                 head_size,
                 rotary_dim,
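
For reference, a minimal sketch of what the filter above keeps and what the mrope branch strips out; the rope_scaling dict below is illustrative, not taken from any real checkpoint:

# Illustrative rope_scaling dict as it could reach get_rope().
rope_scaling = {
    "rope_type": "yarn",
    "factor": 8.0,
    "original_max_position_embeddings": 32768,
    "beta_fast": 32,
    "beta_slow": 1,
    "apply_yarn_scaling": False,
}

allowed = (
    "extrapolation_factor",
    "attn_factor",
    "beta_fast",
    "beta_slow",
    "apply_yarn_scaling",
)
extra_kwargs = {k: v for k, v in rope_scaling.items() if k in allowed}
# {'beta_fast': 32, 'beta_slow': 1, 'apply_yarn_scaling': False}

# The new flag is dropped before constructing MRotaryEmbedding, presumably
# because that class does not accept it as a keyword argument.
extra_kwargs.pop("apply_yarn_scaling", None)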

@@ -27,6 +27,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         attn_factor: float = 1,
         beta_fast: int = 32,
         beta_slow: int = 1,
+        apply_yarn_scaling: bool = True,
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -34,7 +35,11 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         self.beta_fast = beta_fast
         self.beta_slow = beta_slow
         # Get n-d magnitude scaling corrected for interpolation
-        self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
+        self.mscale = (
+            float(yarn_get_mscale(self.scaling_factor) * attn_factor)
+            if apply_yarn_scaling
+            else float(attn_factor)
+        )
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
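
The new flag only toggles the YaRN magnitude correction; attn_factor still applies in both branches. A sketch of the two possible values of self.mscale, assuming the standard YaRN helper yarn_get_mscale(s) = 0.1 * ln(s) + 1.0 for s > 1 (the call site matches the code above; the helper body shown here is an assumption):

import math

def yarn_get_mscale(scale: float = 1.0) -> float:
    # Assumed standard YaRN magnitude correction.
    if scale <= 1.0:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

scaling_factor, attn_factor = 8.0, 1.0
mscale_with_yarn = float(yarn_get_mscale(scaling_factor) * attn_factor)  # ~1.208
mscale_without = float(attn_factor)  # 1.0 when apply_yarn_scaling=False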

@@ -160,6 +160,14 @@ class LlamaAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
+        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
+        self.do_llama_4_scaling = llama_4_scaling_config is not None
+        if self.do_llama_4_scaling:
+            self.llama_4_scaling_original_max_position_embeddings = (
+                llama_4_scaling_config["original_max_position_embeddings"]
+            )
+            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]
+
         self.qkv_proj = QKVParallelLinear(
             hidden_size=hidden_size,
             head_size=self.head_dim,

@@ -221,6 +229,17 @@ class LlamaAttention(nn.Module):
             prefix=f"{prefix}.attn",
         )
 
+    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
+        # Llama4 scaling
+        scaling = 1 + self.llama_4_scaling_beta * torch.log(
+            1
+            + torch.floor(
+                positions / self.llama_4_scaling_original_max_position_embeddings
+            )
+        )
+        # Broadcast over head_dim
+        return scaling.unsqueeze(-1)
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -229,6 +248,9 @@ class LlamaAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
+        if self.do_llama_4_scaling:
+            attn_scale = self._get_llama_4_attn_scale(positions)
+            q = (q * attn_scale).to(q.dtype)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
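
The query scale computed in _get_llama_4_attn_scale is 1 + beta * log(1 + floor(position / original_max_position_embeddings)): it stays at 1.0 inside the original context window and grows logarithmically with each additional window. A small numeric sketch with made-up values for beta and the window size:

import torch

beta = 0.1           # illustrative; the real value comes from llama_4_scaling["beta"]
original_max = 8192  # illustrative original_max_position_embeddings

positions = torch.tensor([0, 8191, 8192, 65536])
scale = 1 + beta * torch.log(1 + torch.floor(positions / original_max))
# ~ tensor([1.0000, 1.0000, 1.0693, 1.2197])
# scale.unsqueeze(-1) then broadcasts over the flattened head dimension of q.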

@@ -24,6 +24,18 @@ def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig
     if bool(config_dict.get("yarn")):
         config_dict = _remap_mistral_yarn_args(config_dict)
 
+    if bool(config_dict.get("llama_4_scaling")):
+        llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
+        assert all(
+            [
+                key in config_dict["llama_4_scaling"]
+                for key in llama_4_scaling_config_keys
+            ]
+        ), (
+            "llama_4_scaling config should define the keys: "
+            f"{','.join(llama_4_scaling_config_keys)}"
+        )
+
     is_vision = (config_dict.get("multimodal") or {}).get(
         "vision_encoder_args"
     ) or config_dict.get("vision_encoder")
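
A minimal Mistral-format config fragment that satisfies the new check (the values are illustrative):

config_dict = {
    # ... other Mistral args ...
    "llama_4_scaling": {
        "original_max_position_embeddings": 8192,  # illustrative
        "beta": 0.1,                               # illustrative
    },
}
# Both required keys are present, so the assertion passes. Dropping either key raises:
#   AssertionError: llama_4_scaling config should define the keys:
#   original_max_position_embeddings,beta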

@@ -66,19 +78,24 @@ def _remap_mistral_vision_args(config: dict) -> dict:
 
 
 def _remap_mistral_yarn_args(config: dict) -> dict:
-    # Direct remaps: yarn.X -> rope_scaling.Y
-    # Source keys are from mistral.model.args.YarnArgs
-    _map = {
+    yarn_config_map = {
+        "factor": "factor",
+        "original_max_position_embeddings": "original_max_position_embeddings",
         "beta": "beta_fast",
         "alpha": "beta_slow",
+        "apply_scale": "apply_yarn_scaling",
     }
     yarn_config = config.get("yarn") or {}
-    renamed_yarn_config = {_map.get(k, k): v for k, v in yarn_config.items()}
     config["rope_scaling"] = {
         "rope_type": "yarn",
-        "mscale_all_dim": 1,  # We hardcoded this to 1
-        **renamed_yarn_config,
+        "mscale_all_dim": 1,
     }
+    for old_name, new_name in yarn_config_map.items():
+        if old_name in yarn_config:
+            config["rope_scaling"][new_name] = yarn_config.pop(old_name)
+
+    assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"
+
     return config
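
An illustrative before/after for the remap, assuming _remap_mistral_yarn_args from the hunk above is in scope; the yarn values are made up and the key renames follow yarn_config_map:

config = {
    "yarn": {
        "factor": 8.0,
        "original_max_position_embeddings": 32768,
        "beta": 32,
        "alpha": 1,
        "apply_scale": False,
    },
}

config = _remap_mistral_yarn_args(config)
# config["rope_scaling"] ==
# {
#     "rope_type": "yarn",
#     "mscale_all_dim": 1,
#     "factor": 8.0,
#     "original_max_position_embeddings": 32768,
#     "beta_fast": 32,
#     "beta_slow": 1,
#     "apply_yarn_scaling": False,
# }
# Every recognized key is popped out of config["yarn"], so the "Unparsed yarn config"
# assertion passes; an unknown key would be left behind and trip it.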