Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 09:06:03 +08:00)
[Model] Add support for YARN in NemotronNAS models (#18427)
Signed-off-by: Nave Assaf <nassaf@nvidia.com>
This commit is contained in:
parent
5a2c76cbe1
commit
6d68030f1c
@@ -162,20 +162,9 @@ class LlamaAttention(nn.Module):
             prefix=f"{prefix}.o_proj",
         )
 
-        is_neox_style = True
-        is_gguf = quant_config and quant_config.get_name() == "gguf"
-        if is_gguf and config.model_type == "llama":
-            is_neox_style = False
-
-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
-            partial_rotary_factor=self.partial_rotary_factor,
-        )
+        self._init_rotary_emb(config,
+                              rope_scaling=rope_scaling,
+                              quant_config=quant_config)
 
         if hasattr(config, "interleaved_sliding_window"):
             interleaved_sliding_window = config.interleaved_sliding_window
@@ -214,6 +203,24 @@ class LlamaAttention(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    def _init_rotary_emb(self, config: LlamaConfig,
+                         rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        is_neox_style = True
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and self.config.model_type == "llama":
+            is_neox_style = False
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor,
+        )
+
 
 class LlamaDecoderLayer(nn.Module):
 
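The LlamaAttention hunks above are a pure refactor: the rotary-embedding construction moves into an overridable `_init_rotary_emb` hook, so a derived attention class can swap the RoPE setup without duplicating the QKV/output projections or the attention call. The following minimal sketch illustrates that pattern; the subclass name and the "always non-NeoX" policy are hypothetical and not part of this commit, while the imports and attributes (`head_dim`, `max_position_embeddings`, `rope_theta`, `partial_rotary_factor`) follow what the diff itself uses.

from typing import Any, Optional

from transformers import LlamaConfig

from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.llama import LlamaAttention


class MyCustomAttention(LlamaAttention):
    """Hypothetical subclass: everything except the RoPE setup is inherited."""

    def _init_rotary_emb(self, config: LlamaConfig,
                         rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        # Illustrative policy only: always use GPT-J style (non-NeoX) rotation.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            partial_rotary_factor=self.partial_rotary_factor)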
@@ -23,18 +23,20 @@
 # limitations under the License.
 """Inference-only deci model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -62,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int:
     return n + k - (n % k)
 
 
+class DeciLMAttention(LlamaAttention):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
+        super().__init__(config, hidden_size, num_heads, num_kv_heads,
+                         rope_theta, rope_scaling, max_position_embeddings,
+                         quant_config, bias, bias_o_proj, cache_config, prefix,
+                         attn_type)
+
+    def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        # Enables YARN for Mistral and LLaMA4 derivatives.
+        is_neox_style = True
+        if hasattr(config, "position_embedding_type"):
+            is_neox_style = config.position_embedding_type not in [
+                "mistral_yarn", "rope_llama4"
+            ]
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor)
+
+
 class DeciLMDecoderLayer(nn.Module):
 
     def __init__(
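The override above only decides the rotation style: `is_neox_style` becomes False when the model config declares `position_embedding_type` as "mistral_yarn" or "rope_llama4", while the actual scaling parameters still arrive through the `rope_scaling` dict that is forwarded to `get_rope`. For illustration, a YaRN-style entry might look like the sketch below; the field names follow the common Hugging Face convention and are an assumption for this example, not taken from this commit.

# Hypothetical YaRN rope_scaling entry (HF-style field names assumed):
rope_scaling = {
    "rope_type": "yarn",                        # select the YaRN implementation
    "factor": 4.0,                              # context-length extension factor
    "original_max_position_embeddings": 8192,   # pre-scaling context window
}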
@@ -98,7 +142,7 @@ class DeciLMDecoderLayer(nn.Module):
         if not self._is_no_op_attention:
             num_kv_heads = (config.num_attention_heads //
                             block_config.attention.n_heads_in_group)
-            self.self_attn = LlamaAttention(
+            self.self_attn = DeciLMAttention(
                 config=config,
                 hidden_size=self.hidden_size,
                 num_heads=config.num_attention_heads,