mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:35:00 +08:00
[Model] Add support for YARN in NemotronNAS models (#18427)
Signed-off-by: Nave Assaf <nassaf@nvidia.com>
This commit is contained in:
parent
5a2c76cbe1
commit
6d68030f1c
@ -162,20 +162,9 @@ class LlamaAttention(nn.Module):
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
is_neox_style = True
|
||||
is_gguf = quant_config and quant_config.get_name() == "gguf"
|
||||
if is_gguf and config.model_type == "llama":
|
||||
is_neox_style = False
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=is_neox_style,
|
||||
partial_rotary_factor=self.partial_rotary_factor,
|
||||
)
|
||||
self._init_rotary_emb(config,
|
||||
rope_scaling=rope_scaling,
|
||||
quant_config=quant_config)
|
||||
|
||||
if hasattr(config, "interleaved_sliding_window"):
|
||||
interleaved_sliding_window = config.interleaved_sliding_window
|
||||
@ -214,6 +203,24 @@ class LlamaAttention(nn.Module):
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def _init_rotary_emb(self, config: LlamaConfig,
|
||||
rope_scaling: Optional[dict[str, Any]],
|
||||
quant_config: Optional[QuantizationConfig]) -> None:
|
||||
is_neox_style = True
|
||||
is_gguf = quant_config and quant_config.get_name() == "gguf"
|
||||
if is_gguf and self.config.model_type == "llama":
|
||||
is_neox_style = False
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=is_neox_style,
|
||||
partial_rotary_factor=self.partial_rotary_factor,
|
||||
)
|
||||
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
|
||||
@ -23,18 +23,20 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only deci model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
@ -62,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int:
|
||||
return n + k - (n % k)
|
||||
|
||||
|
||||
class DeciLMAttention(LlamaAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
bias_o_proj: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> None:
|
||||
super().__init__(config, hidden_size, num_heads, num_kv_heads,
|
||||
rope_theta, rope_scaling, max_position_embeddings,
|
||||
quant_config, bias, bias_o_proj, cache_config, prefix,
|
||||
attn_type)
|
||||
|
||||
def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
|
||||
quant_config: Optional[QuantizationConfig]) -> None:
|
||||
# Enables YARN for Mistral and LLaMA4 derivatives.
|
||||
is_neox_style = True
|
||||
if hasattr(config, "position_embedding_type"):
|
||||
is_neox_style = config.position_embedding_type not in [
|
||||
"mistral_yarn", "rope_llama4"
|
||||
]
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=is_neox_style,
|
||||
partial_rotary_factor=self.partial_rotary_factor)
|
||||
|
||||
|
||||
class DeciLMDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -98,7 +142,7 @@ class DeciLMDecoderLayer(nn.Module):
|
||||
if not self._is_no_op_attention:
|
||||
num_kv_heads = (config.num_attention_heads //
|
||||
block_config.attention.n_heads_in_group)
|
||||
self.self_attn = LlamaAttention(
|
||||
self.self_attn = DeciLMAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user