[Model] NemotronH support (#18863)

Signed-off-by: Luis Vega <2478335+vegaluisjose@users.noreply.github.com>
Co-authored-by: Luis Vega <2478335+vegaluisjose@users.noreply.github.com>
This commit is contained in:
Luis Vega 2025-06-05 14:29:28 -07:00 committed by GitHub
parent 87360308b7
commit cb6d572e85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 829 additions and 0 deletions

View File

@ -346,6 +346,7 @@ Specified using `--task generate`.
| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ |
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ |
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ |
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ |
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ |

View File

@ -212,6 +212,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
"NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
"NemotronHForCausalLM": _HfExamplesInfo("nvidia/Nemotron-H-8B-Base-8K",
trust_remote_code=True),
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
"Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),

View File

@ -0,0 +1,565 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only NemotronH model."""
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsLoRA, SupportsPP,
SupportsQuant,
SupportsV0Only)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import (
AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig
from vllm.utils import LayerBlockType
class NemotronHMLP(nn.Module):
def __init__(
self,
config: NemotronHConfig,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
) -> None:
super().__init__()
self.up_proj = MergedColumnParallelLinear(
input_size=config.hidden_size,
output_sizes=[config.intermediate_size],
bias=bias,
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=bias,
quant_config=quant_config,
)
self.act_fn = ReLUSquaredActivation()
def forward(self, x: torch.Tensor):
x, _ = self.up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
class NemotronHMLPDecoderLayer(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.mixer = NemotronHMLP(config,
quant_config=quant_config,
bias=config.mlp_bias)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states)
return hidden_states, residual
class NemotronHMambaDecoderLayer(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.mixer = MambaMixer2(
hidden_size=config.hidden_size,
ssm_state_size=config.ssm_state_size,
conv_kernel_size=config.conv_kernel,
intermediate_size=config.expand * config.hidden_size,
use_conv_bias=config.use_conv_bias,
use_bias=config.use_bias,
n_groups=config.n_groups,
num_heads=config.mamba_num_heads,
head_dim=config.mamba_head_dim,
rms_norm_eps=config.rms_norm_eps,
activation=config.mamba_hidden_act,
quant_config=quant_config,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, mamba_cache_params,
mamba2_metadata)
return hidden_states, residual
class NemotronHAttention(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
prefix=f"{prefix}.attn",
)
def forward(
self,
hidden_states: torch.Tensor,
**kwargs,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class NemotronHAttentionDecoderLayer(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.mixer = NemotronHAttention(
config,
layer_idx,
cache_config,
quant_config,
prefix,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states=hidden_states)
return hidden_states, residual
ALL_DECODER_LAYER_TYPES = {
"M": NemotronHMambaDecoderLayer,
"-": NemotronHMLPDecoderLayer,
"*": NemotronHAttentionDecoderLayer,
}
class NemotronHModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: NemotronHConfig = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1])
layer_class = ALL_DECODER_LAYER_TYPES[
config.hybrid_override_pattern[layer_idx]]
return layer_class(
config,
layer_idx,
cache_config,
quant_config=quant_config,
prefix=prefix,
)
self.start_layer, self.end_layer, self.layers = make_layers(
len(config.hybrid_override_pattern),
get_layer,
prefix=f"{prefix}.layers")
self.make_empty_intmd_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)
self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
residual = None
num_non_mamba_layers = 0
for i in range(len(self.layers)):
layer = self.layers[i]
layer_mamba_cache_params = None
if isinstance(layer, NemotronHMambaDecoderLayer):
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_non_mamba_layers)
else:
num_non_mamba_layers += 1
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
attb_params_mapping = {
"q_proj": "q",
"k_proj": "k",
"v_proj": "v",
}
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "embeddings" in name:
name = name.replace("embeddings", "embed_tokens")
if "A_log" in name:
name = name.replace("A_log", "A")
loaded_weight = loaded_weight.to(torch.float32)
if "D" in name:
loaded_weight = loaded_weight.to(torch.float32)
if "dt_bias" in name:
loaded_weight = loaded_weight.to(torch.float32)
# load attn params
if any(proj in name for proj in ["q_proj", "k_proj", "v_proj"]):
weight_name = next(proj
for proj in ["q_proj", "k_proj", "v_proj"]
if proj in name)
name = name.replace(weight_name, "qkv_proj")
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight,
attb_params_mapping[weight_name])
# load other params
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsV0Only, SupportsQuant):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"]
}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"NemotronH currently does not support prefix caching"
self.quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.model = NemotronHModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.make_empty_intmd_tensors = (self.model.make_empty_intmd_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
conv_state_shape, temporal_state_shape = None, None
intermediate_size = self.config.expand * hidden_size
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (
self.config.n_groups +
extra_groups_for_head_shards(self.config.n_groups, world_size))
# - heads and n_groups are TP-ed
conv_dim = (intermediate_size +
2 * n_groups * self.config.ssm_state_size)
conv_state_shape = (
divide(conv_dim, world_size),
self.config.conv_kernel - 1,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.config.mamba_num_heads, world_size),
self.config.mamba_head_dim,
self.config.ssm_state_size,
)
return conv_state_shape, temporal_state_shape
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
# update name in weights before passing to loader
updated_weights = []
for name, loaded_weight in weights:
name = name.replace("backbone", "model")
updated_weights.append((name, loaded_weight))
loader = AutoWeightsLoader(self)
return loader.load_weights(updated_weights)

View File

@ -92,6 +92,7 @@ _TEXT_GENERATION_MODELS = {
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),

View File

@ -23,6 +23,7 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.ovis import OvisConfig
from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
@ -50,6 +51,7 @@ __all__ = [
"MoonViTConfig",
"KimiVLConfig",
"NemotronConfig",
"NemotronHConfig",
"NVLM_D_Config",
"OvisConfig",
"SkyworkR1VChatConfig",

View File

@ -0,0 +1,258 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NemotronH model configuration"""
import regex as re
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class NemotronHConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a
[`NemotronHModel`]. It is used to instantiate a NemotronH model according
to the specified arguments, defining the model architecture. Instantiating
a configuration with the defaults will yield a similar configuration to
that of the NemotronH-v0.1 model.
Args:
vocab_size (`int`, *optional*, defaults to 131072):
Vocabulary size of the NemotronH model. Defines the number of
different tokens that can be represented by the `inputs_ids`
passed when calling [`NemotronHModel`]
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be
tied. Note that this is only relevant if the model has a output
word embedding layer.
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 21504):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 52):
Number of hidden layers in the Transformer encoder.
hybrid_override_pattern (`str`, *optional*, defaults to
`"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
The pattern of the hybrid model. The pattern is a string of
characters where each character represents
M: Mamba2, *: Attention, -: MLP
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the
Transformer encoder.
attention_head_dim (`int`, *optional*, defaults to 128):
Dimension of each attention head.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to
implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use
Multi Head Attention (MHA), if `num_key_value_heads=1` the model
will use Multi Query Attention (MQA) otherwise GQA is used.
mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
The non-linear activation function in the MLP layers.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in attention layers.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in MLP layers.
use_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the model.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
residual_in_fp32 (`bool`, *optional*, defaults to `False`):
Whether or not residuals should be in `float32`. If set to `False`
residuals will keep the same `dtype` as the rest of the model.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values
attentions (not used by all models). Only relevant if
`config.is_decoder=True`.
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
Number of prompt logits to calculate during generation. If `None`,
all logits will be calculated. If an integer value, only last
`num_logits_to_keep` logits will be calculated.
pad_token_id (`int`, *optional*, defaults to 0):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
sliding_window (`int`, *optional*, defaults to None):
Sliding window attention window size.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used
with.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
hidden_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the hidden states.
use_mamba_kernels (`bool`, *optional*, defaults to `True`):
Flag indicating whether or not to use the fast mamba kernels.
These are available only if `mamba-ssm` and `causal-conv1d`
are installed, and the mamba modules are running on a CUDA device.
ssm_state_size (`int`, *optional*, defaults to 128):
The dimension of the mamba state space latents.
mamba_num_heads (`int`, *optional*, defaults to 128):
Number of heads in Mamba layers.
mamba_n_groups (`int`, *optional*, defaults to 8):
Number of groups in Mamba layers.
mamba_head_dim (`int`, *optional*, defaults to 64):
Dimension of each Mamba head.
mamba_d_conv (`int`, *optional*, defaults to 4):
The size of the mamba convolution kernel.
mamba_expand (`int`, *optional*, defaults to 2):
Expanding factor used to determine the mamba intermediate size.
mamba_hidden_act (`str`, *optional*, defaults to "silu"):
The non-linear activation function in the Mamba layers.
mamba_dt_min (`float`, *optional*, defaults to 0.001):
Minimum value for the time step in Mamba.
mamba_dt_max (`float`, *optional*, defaults to 0.1):
Maximum value for the time step in Mamba.
mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
Limits for the time step in Mamba.
mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
Floor value for time step initialization in Mamba.
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the convolution layer of the mamba mixer
block.
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the input and output projections of the
mamba mixer block.
mamba_chunk_size (`int`, *optional*, defaults to 256):
Size of chunks for Mamba processing.
rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
Whether to rescale the pre-normalization residual connections.
"""
model_type = "nemotron_h"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=131072,
tie_word_embeddings=False,
hidden_size=4096,
intermediate_size=21504,
num_hidden_layers=52,
hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
num_attention_heads=32,
attention_head_dim=128,
num_key_value_heads=8, # nemo: num_query_groups
mlp_hidden_act="relu2",
attention_bias=False,
mlp_bias=False,
use_bias=False,
initializer_range=0.02, # nemo: init_method_std
layer_norm_epsilon=1e-5, # nemo: layernorm_epsilon
residual_in_fp32=False, # Megatron Core default value
use_cache=True,
num_logits_to_keep=1,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
sliding_window=None,
max_position_embeddings=4096,
attention_dropout=0.0,
hidden_dropout=0.0, # * ADDED
use_mamba_kernels=True,
ssm_state_size=128, # mamba_state_size
mamba_num_heads=128,
mamba_n_groups=8, # nemo: mamba_ssm_ngroups = num_heads
mamba_head_dim=64,
mamba_d_conv=4,
mamba_expand=2,
mamba_hidden_act="silu",
mamba_dt_min=0.001,
mamba_dt_max=0.1,
mamba_dt_limit=(0.0, float("inf")),
mamba_dt_init_floor=1e-4,
mamba_conv_bias=True,
mamba_proj_bias=False,
mamba_chunk_size=256,
rescale_prenorm_residual=True,
**kwargs,
):
self.vocab_size = vocab_size
self.tie_word_embeddings = tie_word_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.hybrid_override_pattern = hybrid_override_pattern
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.sliding_window = sliding_window
self.max_position_embeddings = max_position_embeddings
self.attention_dropout = attention_dropout
self.hidden_dropout = hidden_dropout
# Validate hybrid_override_pattern
# M: Mamba2, *: Attention, -: MLP
assert len(self.hybrid_override_pattern) == self.num_hidden_layers, (
"hybrid_override_pattern must have same length as "
"num_hidden_layers")
assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), (
"hybrid_override_pattern must only contain characters "
"'M', '*', or '-'")
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.mlp_hidden_act = mlp_hidden_act
self.attention_bias = attention_bias
self.mlp_bias = mlp_bias
self.use_bias = use_bias
self.initializer_range = initializer_range
self.layer_norm_epsilon = layer_norm_epsilon
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
self.num_logits_to_keep = num_logits_to_keep
self.use_mamba_kernels = use_mamba_kernels
self.n_groups = mamba_n_groups
self.mamba_head_dim = mamba_head_dim
self.ssm_state_size = ssm_state_size
self.mamba_num_heads = mamba_num_heads
self.conv_kernel = mamba_d_conv
self.expand = mamba_expand
self.mamba_hidden_act = mamba_hidden_act
self.time_step_min = mamba_dt_min
self.time_step_max = mamba_dt_max
self.time_step_limit = mamba_dt_limit
self.time_step_floor = mamba_dt_init_floor
self.use_conv_bias = mamba_conv_bias
self.mamba_proj_bias = mamba_proj_bias
self.chunk_size = mamba_chunk_size
self.rescale_prenorm_residual = rescale_prenorm_residual
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def layers_block_type(self):
return [
"mamba" if self.hybrid_override_pattern[i] == "M" else
"attention" if self.hybrid_override_pattern[i] == "*" else "mlp"
for i in range(self.num_hidden_layers)
]