[Model] Add support for YARN in NemotronNAS models (#18427)

Signed-off-by: Nave Assaf <nassaf@nvidia.com>
Authored by Naveassaf on 2025-05-26 13:31:49 +03:00, committed by GitHub
parent 5a2c76cbe1
commit 6d68030f1c
2 changed files with 67 additions and 16 deletions
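
For context, the new DeciLMAttention subclass added below keys the rotary-embedding style off the checkpoint's config: when `position_embedding_type` is "mistral_yarn" or "rope_llama4", neox-style rotation is disabled and the `rope_scaling` dict is forwarded to `get_rope`. A minimal, standalone sketch of that selection logic follows; the YaRN `rope_scaling` keys shown are illustrative HF-style values and are not taken from this commit.

    # Hedged sketch of the config-driven switch introduced below; runnable on its own.
    from types import SimpleNamespace

    config = SimpleNamespace(
        position_embedding_type="mistral_yarn",  # field checked by DeciLMAttention._init_rotary_emb
        rope_scaling={                           # assumed HF-style YaRN parameters (illustrative)
            "rope_type": "yarn",
            "factor": 4.0,
            "original_max_position_embeddings": 8192,
        },
    )

    is_neox_style = True
    if hasattr(config, "position_embedding_type"):
        is_neox_style = config.position_embedding_type not in [
            "mistral_yarn", "rope_llama4"
        ]

    print(is_neox_style)  # False -> get_rope() is called with is_neox_style=False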


@@ -162,20 +162,9 @@ class LlamaAttention(nn.Module):
             prefix=f"{prefix}.o_proj",
         )
 
-        is_neox_style = True
-        is_gguf = quant_config and quant_config.get_name() == "gguf"
-        if is_gguf and config.model_type == "llama":
-            is_neox_style = False
-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
-            partial_rotary_factor=self.partial_rotary_factor,
-        )
+        self._init_rotary_emb(config,
+                              rope_scaling=rope_scaling,
+                              quant_config=quant_config)
 
         if hasattr(config, "interleaved_sliding_window"):
             interleaved_sliding_window = config.interleaved_sliding_window
@@ -214,6 +203,24 @@ class LlamaAttention(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    def _init_rotary_emb(self, config: LlamaConfig,
+                         rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        is_neox_style = True
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and self.config.model_type == "llama":
+            is_neox_style = False
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor,
+        )
+
 
 class LlamaDecoderLayer(nn.Module):


@@ -23,18 +23,20 @@
 # limitations under the License.
 """Inference-only deci model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -62,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int:
     return n + k - (n % k)
 
 
+class DeciLMAttention(LlamaAttention):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
+        super().__init__(config, hidden_size, num_heads, num_kv_heads,
+                         rope_theta, rope_scaling, max_position_embeddings,
+                         quant_config, bias, bias_o_proj, cache_config, prefix,
+                         attn_type)
+
+    def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        # Enables YARN for Mistral and LLaMA4 derivatives.
+        is_neox_style = True
+        if hasattr(config, "position_embedding_type"):
+            is_neox_style = config.position_embedding_type not in [
+                "mistral_yarn", "rope_llama4"
+            ]
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor)
+
+
 class DeciLMDecoderLayer(nn.Module):
 
     def __init__(
@@ -98,7 +142,7 @@ class DeciLMDecoderLayer(nn.Module):
         if not self._is_no_op_attention:
             num_kv_heads = (config.num_attention_heads //
                             block_config.attention.n_heads_in_group)
-            self.self_attn = LlamaAttention(
+            self.self_attn = DeciLMAttention(
                 config=config,
                 hidden_size=self.hidden_size,
                 num_heads=config.num_attention_heads,
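
The llama.py change above turns rotary-embedding construction into an overridable `_init_rotary_emb` hook, which is what lets the NemotronNAS model swap in YaRN without duplicating `LlamaAttention.__init__`. Below is a minimal sketch of reusing that hook from another Llama-derived attention class; `MyCustomAttention` is hypothetical and the `get_rope` arguments simply mirror the diff.

    # Minimal sketch, assuming the _init_rotary_emb hook added to LlamaAttention above.
    # MyCustomAttention is a hypothetical subclass, not part of this commit.
    from typing import Any, Optional

    from vllm.model_executor.layers.quantization import QuantizationConfig
    from vllm.model_executor.layers.rotary_embedding import get_rope
    from vllm.model_executor.models.llama import LlamaAttention


    class MyCustomAttention(LlamaAttention):

        def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
                             quant_config: Optional[QuantizationConfig]) -> None:
            # Always build a GPT-J-style (non-neox) rotary embedding,
            # forwarding the checkpoint's rope_scaling (e.g. YaRN) unchanged.
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
                rope_scaling=rope_scaling,
                is_neox_style=False,
                partial_rotary_factor=self.partial_rotary_factor,
            )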