Let GraniteMoeAttention use YaRN (#21174)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author: Thomas Parnell, 2025-07-18 14:52:52 +02:00 (committed by GitHub)
parent 45badd05d0
commit ed8cbfedf8
2 changed files with 7 additions and 1 deletion
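GraniteMoeAttention previously ignored any rope_scaling section in the checkpoint's config, so YaRN-style context extension could not take effect; this change reads rope_scaling in the decoder layers and forwards it into the rotary embedding. For context, a minimal sketch of the pattern being added, using an illustrative stand-in config (the YaRN values below are made up, not taken from any released GraniteMoe checkpoint):

from types import SimpleNamespace

# Illustrative stand-in for a Hugging Face model config that enables YaRN;
# keys follow the HF rope_scaling convention, values are invented.
config = SimpleNamespace(
    rope_theta=10000,
    rope_scaling={
        "rope_type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 4096,
    },
)

# Mirrors the getattr pattern added in the decoder layers: checkpoints
# without a rope_scaling entry fall back to None and behave as before.
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)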


@@ -24,7 +24,7 @@
 # limitations under the License.
 """Inference-only GraniteMoe model."""
 from collections.abc import Iterable
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from torch import nn
@@ -113,6 +113,7 @@ class GraniteMoeAttention(nn.Module):
         num_kv_heads: int,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         attention_multiplier: Optional[float] = None,
@@ -163,6 +164,7 @@ class GraniteMoeAttention(nn.Module):
             max_position=max_position,
             base=int(self.rope_theta),
             is_neox_style=True,
+            rope_scaling=rope_scaling,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -198,12 +200,14 @@ class GraniteMoeDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.self_attn = GraniteMoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",


@@ -81,12 +81,14 @@ class GraniteMoeSharedDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.self_attn = GraniteMoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
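Downstream, GraniteMoeAttention simply forwards the dictionary to vLLM's get_rope helper, which builds a YaRN rotary embedding when the scaling dict requests it. A rough sketch of that call, assuming get_rope's keyword interface; the head size, max_position, and YaRN values below are illustrative only and not taken from the model:

from vllm.model_executor.layers.rotary_embedding import get_rope

# Hypothetical sizes for illustration only.
head_dim = 128
rotary_emb = get_rope(
    head_size=head_dim,
    rotary_dim=head_dim,
    max_position=16384,        # e.g. original length * factor
    base=10000,                # int(self.rope_theta) in the model code
    is_neox_style=True,
    rope_scaling={
        "rope_type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 4096,
    },
)

With rope_scaling=None (the default, and the behavior for checkpoints that do not define it), get_rope returns the plain rotary embedding used before this change.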