[Model] Add support for YARN in NemotronNAS models (#18427)

Signed-off-by: Nave Assaf <nassaf@nvidia.com>
Naveassaf authored 2025-05-26 13:31:49 +03:00, committed by GitHub
parent 5a2c76cbe1
commit 6d68030f1c
2 changed files with 67 additions and 16 deletions


@@ -162,20 +162,9 @@ class LlamaAttention(nn.Module):
             prefix=f"{prefix}.o_proj",
         )
 
-        is_neox_style = True
-        is_gguf = quant_config and quant_config.get_name() == "gguf"
-        if is_gguf and config.model_type == "llama":
-            is_neox_style = False
-
-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
-            partial_rotary_factor=self.partial_rotary_factor,
-        )
+        self._init_rotary_emb(config,
+                              rope_scaling=rope_scaling,
+                              quant_config=quant_config)
 
         if hasattr(config, "interleaved_sliding_window"):
             interleaved_sliding_window = config.interleaved_sliding_window
@@ -214,6 +203,24 @@ class LlamaAttention(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    def _init_rotary_emb(self, config: LlamaConfig,
+                         rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        is_neox_style = True
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and config.model_type == "llama":
+            is_neox_style = False
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor,
+        )
+
 
 class LlamaDecoderLayer(nn.Module):
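
With the rotary-embedding setup factored out into _init_rotary_emb, a subclass of LlamaAttention can change how RoPE is constructed without copying the rest of __init__. Below is a minimal sketch of that pattern, assuming a hypothetical MyAttention subclass; the class name and the hard-coded is_neox_style=False choice are illustrative and not part of this commit.

    # Hypothetical subclass: only the rotary-embedding construction differs
    # from LlamaAttention; everything else is inherited unchanged.
    from typing import Any, Optional

    from vllm.model_executor.layers.quantization import QuantizationConfig
    from vllm.model_executor.layers.rotary_embedding import get_rope
    from vllm.model_executor.models.llama import LlamaAttention


    class MyAttention(LlamaAttention):

        def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
                             quant_config: Optional[QuantizationConfig]) -> None:
            # Same get_rope() call as the base class, but with a different
            # is_neox_style policy (illustrative only).
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
                rope_scaling=rope_scaling,
                is_neox_style=False,  # e.g. GPT-J-style interleaved rotary
                partial_rotary_factor=self.partial_rotary_factor,
            )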


@@ -23,18 +23,20 @@
 # limitations under the License.
 """Inference-only deci model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn
 from transformers import LlamaConfig
 
 from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -62,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int:
     return n + k - (n % k)
 
 
+class DeciLMAttention(LlamaAttention):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
+        super().__init__(config, hidden_size, num_heads, num_kv_heads,
+                         rope_theta, rope_scaling, max_position_embeddings,
+                         quant_config, bias, bias_o_proj, cache_config, prefix,
+                         attn_type)
+
+    def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        # Enables YARN for Mistral and LLaMA4 derivatives.
+        is_neox_style = True
+        if hasattr(config, "position_embedding_type"):
+            is_neox_style = config.position_embedding_type not in [
+                "mistral_yarn", "rope_llama4"
+            ]
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor)
+
+
 class DeciLMDecoderLayer(nn.Module):
 
     def __init__(
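
The only decision the override makes is whether the rotary embedding is NeoX-style; that predicate can be exercised on its own. Here is a small standalone sketch of the same check, using SimpleNamespace as a stand-in for the HuggingFace config object (real models pass a transformers config carrying position_embedding_type):

    # Standalone reproduction of the is_neox_style decision from
    # DeciLMAttention._init_rotary_emb; SimpleNamespace stands in for the config.
    from types import SimpleNamespace


    def uses_neox_style(config) -> bool:
        if hasattr(config, "position_embedding_type"):
            return config.position_embedding_type not in [
                "mistral_yarn", "rope_llama4"
            ]
        return True


    assert uses_neox_style(SimpleNamespace())  # no field set: NeoX-style RoPE
    assert not uses_neox_style(SimpleNamespace(position_embedding_type="mistral_yarn"))
    assert not uses_neox_style(SimpleNamespace(position_embedding_type="rope_llama4"))
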
@@ -98,7 +142,7 @@ class DeciLMDecoderLayer(nn.Module):
         if not self._is_no_op_attention:
             num_kv_heads = (config.num_attention_heads //
                             block_config.attention.n_heads_in_group)
-            self.self_attn = LlamaAttention(
+            self.self_attn = DeciLMAttention(
                 config=config,
                 hidden_size=self.hidden_size,
                 num_heads=config.num_attention_heads,
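
For the new code path to actually yield a YaRN rotary embedding, the model config still has to carry a rope_scaling block that get_rope resolves to its YaRN implementation. A hedged sketch of what such a block typically looks like follows; the key names mirror the common HuggingFace/vLLM convention, and the values are placeholders, not taken from this commit or any NemotronNAS checkpoint.

    # Illustrative rope_scaling dict for a YaRN-extended model (placeholder values).
    rope_scaling = {
        "rope_type": "yarn",                       # selects the YaRN implementation
        "factor": 4.0,                             # context-length extension factor
        "original_max_position_embeddings": 8192,  # pre-extension context length
    }

    # DeciLMAttention._init_rotary_emb forwards this dict unchanged to get_rope(),
    # with is_neox_style=False whenever config.position_embedding_type is
    # "mistral_yarn" or "rope_llama4".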