From ed8cbfedf84f1b1fc1d3eadf3622d9903e906cb0 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Fri, 18 Jul 2025 14:52:52 +0200
Subject: [PATCH] Let GraniteMoeAttention use YaRN (#21174)

Signed-off-by: Thomas Parnell
---
 vllm/model_executor/models/granitemoe.py       | 6 +++++-
 vllm/model_executor/models/granitemoeshared.py | 2 ++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index 142b0e967295..7d31854dce8d 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -24,7 +24,7 @@
 # limitations under the License.
 """Inference-only GraniteMoe model."""
 from collections.abc import Iterable
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from torch import nn
@@ -113,6 +113,7 @@ class GraniteMoeAttention(nn.Module):
         num_kv_heads: int,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         attention_multiplier: Optional[float] = None,
@@ -163,6 +164,7 @@ class GraniteMoeAttention(nn.Module):
             max_position=max_position,
             base=int(self.rope_theta),
             is_neox_style=True,
+            rope_scaling=rope_scaling,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -198,12 +200,14 @@ class GraniteMoeDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.self_attn = GraniteMoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py
index 7303f4853782..1e2e8544179c 100644
--- a/vllm/model_executor/models/granitemoeshared.py
+++ b/vllm/model_executor/models/granitemoeshared.py
@@ -81,12 +81,14 @@ class GraniteMoeSharedDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.self_attn = GraniteMoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
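
Illustrative note (not part of the patch): the sketch below shows the kind of YaRN
rope_scaling entry a checkpoint's config could carry and that the decoder layers now
pick up via getattr(config, "rope_scaling", None). The key names follow the usual
Hugging Face convention, but the factor and original context length are placeholder
assumptions, not values taken from this change.

    # Hypothetical config-side YaRN entry (values are placeholders).
    # Older configs may use the key "type" instead of "rope_type".
    rope_scaling = {
        "rope_type": "yarn",
        "factor": 8.0,                             # assumed scaling factor
        "original_max_position_embeddings": 4096,  # assumed pre-extension context
    }
    # With this patch, GraniteMoeDecoderLayer / GraniteMoeSharedDecoderLayer forward
    # this dict into GraniteMoeAttention, which passes it to get_rope(...), so the
    # rotary embedding is built with YaRN scaling instead of plain RoPE.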