diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py
index 64187c97cab7..56c165f9c041 100644
--- a/vllm/model_executor/layers/rotary_embedding/__init__.py
+++ b/vllm/model_executor/layers/rotary_embedding/__init__.py
@@ -191,9 +191,16 @@ def get_rope(
                 k: v
                 for k, v in rope_scaling.items()
                 if k
-                in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
+                in (
+                    "extrapolation_factor",
+                    "attn_factor",
+                    "beta_fast",
+                    "beta_slow",
+                    "apply_yarn_scaling",
+                )
             }
             if "mrope_section" in rope_scaling:
+                extra_kwargs.pop("apply_yarn_scaling", None)
                 rotary_emb = MRotaryEmbedding(
                     head_size,
                     rotary_dim,
diff --git a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py
index 93c92e7801e1..ff46ad74b302 100644
--- a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py
@@ -27,6 +27,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         attn_factor: float = 1,
         beta_fast: int = 32,
         beta_slow: int = 1,
+        apply_yarn_scaling: bool = True,
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -34,7 +35,11 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         self.beta_fast = beta_fast
         self.beta_slow = beta_slow
         # Get n-d magnitude scaling corrected for interpolation
-        self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
+        self.mscale = (
+            float(yarn_get_mscale(self.scaling_factor) * attn_factor)
+            if apply_yarn_scaling
+            else float(attn_factor)
+        )
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 7cc908e52c88..0a08bd376bad 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -160,6 +160,14 @@ class LlamaAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
+        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
+        self.do_llama_4_scaling = llama_4_scaling_config is not None
+        if self.do_llama_4_scaling:
+            self.llama_4_scaling_original_max_position_embeddings = (
+                llama_4_scaling_config["original_max_position_embeddings"]
+            )
+            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]
+
         self.qkv_proj = QKVParallelLinear(
             hidden_size=hidden_size,
             head_size=self.head_dim,
@@ -221,6 +229,17 @@ class LlamaAttention(nn.Module):
             prefix=f"{prefix}.attn",
         )
 
+    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
+        # Llama4 scaling
+        scaling = 1 + self.llama_4_scaling_beta * torch.log(
+            1
+            + torch.floor(
+                positions / self.llama_4_scaling_original_max_position_embeddings
+            )
+        )
+        # Broadcast over head_dim
+        return scaling.unsqueeze(-1)
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -229,6 +248,9 @@ class LlamaAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
+        if self.do_llama_4_scaling:
+            attn_scale = self._get_llama_4_attn_scale(positions)
+            q = (q * attn_scale).to(q.dtype)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py
index d5bf79e01f95..c6f04febe37e 100644
--- a/vllm/transformers_utils/configs/mistral.py
+++ b/vllm/transformers_utils/configs/mistral.py
@@ -24,6 +24,18 @@ def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig
     if bool(config_dict.get("yarn")):
         config_dict = _remap_mistral_yarn_args(config_dict)
 
+    if bool(config_dict.get("llama_4_scaling")):
+        llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
+        assert all(
+            [
+                key in config_dict["llama_4_scaling"]
+                for key in llama_4_scaling_config_keys
+            ]
+        ), (
+            "llama_4_scaling config should define the keys: "
+            f"{','.join(llama_4_scaling_config_keys)}"
+        )
+
     is_vision = (config_dict.get("multimodal") or {}).get(
         "vision_encoder_args"
     ) or config_dict.get("vision_encoder")
@@ -66,19 +78,24 @@ def _remap_mistral_vision_args(config: dict) -> dict:
 
 
 def _remap_mistral_yarn_args(config: dict) -> dict:
-    # Direct remaps: yarn.X -> rope_scaling.Y
-    # Source keys are from mistral.model.args.YarnArgs
-    _map = {
+    yarn_config_map = {
+        "factor": "factor",
+        "original_max_position_embeddings": "original_max_position_embeddings",
         "beta": "beta_fast",
         "alpha": "beta_slow",
+        "apply_scale": "apply_yarn_scaling",
     }
     yarn_config = config.get("yarn") or {}
-    renamed_yarn_config = {_map.get(k, k): v for k, v in yarn_config.items()}
     config["rope_scaling"] = {
         "rope_type": "yarn",
-        "mscale_all_dim": 1,  # We hardcoded this to 1
-        **renamed_yarn_config,
+        "mscale_all_dim": 1,
     }
+    for old_name, new_name in yarn_config_map.items():
+        if old_name in yarn_config:
+            config["rope_scaling"][new_name] = yarn_config.pop(old_name)
+
+    assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"
+
     return config
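
Note on the new `apply_yarn_scaling` flag: it only changes how `self.mscale` is computed, falling back to `attn_factor` alone when the flag is false. The standalone sketch below illustrates the effect, assuming `yarn_get_mscale` follows the YaRN paper's `0.1 * ln(scale) + 1.0` magnitude correction; the helper here is a local stand-in, not vLLM's own function.

import math

def yarn_get_mscale(scale: float = 1.0) -> float:
    # Local stand-in, assumed to match the YaRN paper's magnitude
    # correction (0.1 * ln(s) + 1.0 for s > 1); not copied from vLLM.
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

attn_factor = 1.0
scaling_factor = 16.0  # illustrative value only

# apply_yarn_scaling=True (the default): magnitude correction is applied.
mscale_with_scaling = float(yarn_get_mscale(scaling_factor) * attn_factor)
# apply_yarn_scaling=False: mscale collapses to attn_factor alone.
mscale_without_scaling = float(attn_factor)

print(mscale_with_scaling, mscale_without_scaling)  # ~1.277 and 1.0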
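
The query scaling added to `LlamaAttention` applies `1 + beta * log(1 + floor(pos / original_max_position_embeddings))` per position. The sketch below mirrors `_get_llama_4_attn_scale` outside of vLLM to show that the scale stays at 1.0 inside the original context window and grows logarithmically beyond it; the values `beta=0.1` and `original_max_position_embeddings=8192` are illustrative assumptions, not taken from any shipped config.

import torch

def llama_4_attn_scale(
    positions: torch.Tensor,
    beta: float,
    original_max_position_embeddings: int,
) -> torch.Tensor:
    # Positions inside the original window get a scale of 1.0; beyond it,
    # the scale grows with the log of the number of windows crossed.
    scaling = 1 + beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings)
    )
    # Unsqueeze so the scale broadcasts over the head_dim axis of q.
    return scaling.unsqueeze(-1)

positions = torch.tensor([0.0, 8191.0, 8192.0, 65536.0])
scale = llama_4_attn_scale(
    positions, beta=0.1, original_max_position_embeddings=8192
)
print(scale.squeeze(-1))  # tensor([1.0000, 1.0000, 1.0693, 1.2197])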