From cb6d572e85a34aec7b4409833bff12af28b0d28b Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Thu, 5 Jun 2025 14:29:28 -0700 Subject: [PATCH] [Model] NemotronH support (#18863) Signed-off-by: Luis Vega <2478335+vegaluisjose@users.noreply.github.com> Co-authored-by: Luis Vega <2478335+vegaluisjose@users.noreply.github.com> --- docs/models/supported_models.md | 1 + tests/models/registry.py | 2 + vllm/model_executor/models/nemotron_h.py | 565 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/nemotron_h.py | 258 ++++++++ 6 files changed, 829 insertions(+) create mode 100644 vllm/model_executor/models/nemotron_h.py create mode 100644 vllm/transformers_utils/configs/nemotron_h.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 71414d2aad821..a8a6f3417e546 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -346,6 +346,7 @@ Specified using `--task generate`. | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | +| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | | `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 3e07dc0f322e1..e6543c197348c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -212,6 +212,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"), + "NemotronHForCausalLM": _HfExamplesInfo("nvidia/Nemotron-H-8B-Base-8K", + trust_remote_code=True), "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py new file mode 100644 index 0000000000000..2ef8d31150d5e --- /dev/null +++ b/vllm/model_executor/models/nemotron_h.py @@ -0,0 +1,565 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only NemotronH model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn + +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import ReLUSquaredActivation +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + Mamba2Metadata, prepare_mamba2_metadata) +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, + SupportsLoRA, SupportsPP, + SupportsQuant, + SupportsV0Only) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import NemotronHConfig +from vllm.utils import LayerBlockType + + +class NemotronHMLP(nn.Module): + + def __init__( + self, + config: NemotronHConfig, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size], + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=bias, + quant_config=quant_config, + ) + self.act_fn = ReLUSquaredActivation() + + def forward(self, x: torch.Tensor): + x, _ = self.up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class NemotronHMLPDecoderLayer(nn.Module): + + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.mixer = NemotronHMLP(config, + quant_config=quant_config, + bias=config.mlp_bias) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + 
hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states) + return hidden_states, residual + + +class NemotronHMambaDecoderLayer(nn.Module): + + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.mixer = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.ssm_state_size, + conv_kernel_size=config.conv_kernel, + intermediate_size=config.expand * config.hidden_size, + use_conv_bias=config.use_conv_bias, + use_bias=config.use_bias, + n_groups=config.n_groups, + num_heads=config.mamba_num_heads, + head_dim=config.mamba_head_dim, + rms_norm_eps=config.rms_norm_eps, + activation=config.mamba_hidden_act, + quant_config=quant_config, + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states, mamba_cache_params, + mamba2_metadata) + return hidden_states, residual + + +class NemotronHAttention(nn.Module): + + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
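+            # e.g. with the default num_key_value_heads=8 and tp_size=16,
+            # each KV head is replicated on 2 ranks and num_kv_heads below
+            # evaluates to 1.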
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            prefix=f"{prefix}.attn",
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class NemotronHAttentionDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.mixer = NemotronHAttention(
+            config,
+            layer_idx,
+            cache_config,
+            quant_config,
+            prefix,
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        **kwargs,
+    ):
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer(hidden_states=hidden_states)
+        return hidden_states, residual
+
+
+ALL_DECODER_LAYER_TYPES = {
+    "M": NemotronHMambaDecoderLayer,
+    "-": NemotronHMLPDecoderLayer,
+    "*": NemotronHAttentionDecoderLayer,
+}
+
+
+class NemotronHModel(nn.Module):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
+        config: NemotronHConfig = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        self.config = config
+        lora_vocab = ((lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0)
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+
+        def get_layer(prefix: str):
+            layer_idx = int(prefix.rsplit(".", 1)[1])
+            layer_class = ALL_DECODER_LAYER_TYPES[
+                config.hybrid_override_pattern[layer_idx]]
+            return layer_class(
+                config,
+                layer_idx,
+                cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            len(config.hybrid_override_pattern),
+            get_layer,
+            prefix=f"{prefix}.layers")
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
+        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        
positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + attn_metadata = get_forward_context().attn_metadata + + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.chunk_size, + attn_metadata=attn_metadata, + ) + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + residual = None + num_non_mamba_layers = 0 + for i in range(len(self.layers)): + layer = self.layers[i] + layer_mamba_cache_params = None + if isinstance(layer, NemotronHMambaDecoderLayer): + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + i - num_non_mamba_layers) + else: + num_non_mamba_layers += 1 + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=layer_mamba_cache_params, + mamba2_metadata=mamba2_metadata, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm_f(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + attb_params_mapping = { + "q_proj": "q", + "k_proj": "k", + "v_proj": "v", + } + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "embeddings" in name: + name = name.replace("embeddings", "embed_tokens") + + if "A_log" in name: + name = name.replace("A_log", "A") + loaded_weight = loaded_weight.to(torch.float32) + + if "D" in name: + loaded_weight = loaded_weight.to(torch.float32) + + if "dt_bias" in name: + loaded_weight = loaded_weight.to(torch.float32) + + # load attn params + if any(proj in name for proj in ["q_proj", "k_proj", "v_proj"]): + weight_name = next(proj + for proj in ["q_proj", "k_proj", "v_proj"] + if proj in name) + name = name.replace(weight_name, "qkv_proj") + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, + attb_params_mapping[weight_name]) + # load other params + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + loaded_params.add(name) + return loaded_params + + +class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsV0Only, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"] + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "NemotronH currently does not support prefix caching" + + self.quant_config = 
vllm_config.quant_config
+
+        super().__init__()
+        self.config = config
+        self.scheduler_config = scheduler_config
+        self.model = NemotronHModel(vllm_config=vllm_config,
+                                    prefix=maybe_prefix(prefix, "model"))
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
+        # Used to track and store the Mamba cache between steps.
+        self.mamba_cache: Optional[MambaCacheManager] = None
+
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
+                **kwargs):
+        if self.mamba_cache is None:
+
+            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
+                self.vllm_config.parallel_config, LayerBlockType.mamba)
+
+            self.mamba_cache = MambaCacheManager(
+                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+                *self._get_mamba_cache_shape())
+        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+        hidden_states = self.model(input_ids, positions, mamba_cache_params,
+                                   intermediate_tensors, inputs_embeds)
+
+        return hidden_states
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(
+            input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
+    def _get_mamba_cache_shape(
+            self) -> tuple[tuple[int, int], tuple[int, int]]:
+        world_size = get_tensor_model_parallel_world_size()
+        hidden_size = self.config.hidden_size
+
+        conv_state_shape, temporal_state_shape = None, None
+
+        intermediate_size = self.config.expand * hidden_size
+
+        # if n_groups is not divisible by world_size, need to extend the shards
+        # to ensure all groups needed by a head are sharded along with it
+        n_groups = (
+            self.config.n_groups +
+            extra_groups_for_head_shards(self.config.n_groups, world_size))
+
+        # - heads and n_groups are TP-ed
+        conv_dim = (intermediate_size +
+                    2 * n_groups * self.config.ssm_state_size)
+        conv_state_shape = (
+            divide(conv_dim, world_size),
+            self.config.conv_kernel - 1,
+        )
+
+        # These are not TP-ed as they depend on A, dt_bias, D
+        # - they are typically small
+        #   e.g., (h_heads, d_head, d_state) = (128, 64, 128)
+        temporal_state_shape = (
+            divide(self.config.mamba_num_heads, world_size),
+            self.config.mamba_head_dim,
+            self.config.ssm_state_size,
+        )
+        return conv_state_shape, temporal_state_shape
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        # update name in weights before passing to loader
+        updated_weights = []
+        for name, loaded_weight in 
weights: + name = name.replace("backbone", "model") + updated_weights.append((name, loaded_weight)) + loader = AutoWeightsLoader(self) + return loader.load_weights(updated_weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 57d1b7c53ff60..e82e366380694 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -92,6 +92,7 @@ _TEXT_GENERATION_MODELS = { "MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), + "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 7edff455f2992..97a1b683a9b83 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -23,6 +23,7 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig +from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config from vllm.transformers_utils.configs.ovis import OvisConfig from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig @@ -50,6 +51,7 @@ __all__ = [ "MoonViTConfig", "KimiVLConfig", "NemotronConfig", + "NemotronHConfig", "NVLM_D_Config", "OvisConfig", "SkyworkR1VChatConfig", diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py new file mode 100644 index 0000000000000..9fe75f2dfeea8 --- /dev/null +++ b/vllm/transformers_utils/configs/nemotron_h.py @@ -0,0 +1,258 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""NemotronH model configuration""" + +import regex as re +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class NemotronHConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + [`NemotronHModel`]. It is used to instantiate a NemotronH model according + to the specified arguments, defining the model architecture. Instantiating + a configuration with the defaults will yield a similar configuration to + that of the NemotronH-v0.1 model. + Args: + vocab_size (`int`, *optional*, defaults to 131072): + Vocabulary size of the NemotronH model. 
Defines the number of + different tokens that can be represented by the `inputs_ids` + passed when calling [`NemotronHModel`] + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be + tied. Note that this is only relevant if the model has a output + word embedding layer. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 21504): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 52): + Number of hidden layers in the Transformer encoder. + hybrid_override_pattern (`str`, *optional*, defaults to + `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`): + The pattern of the hybrid model. The pattern is a string of + characters where each character represents + M: Mamba2, *: Attention, -: MLP + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the + Transformer encoder. + attention_head_dim (`int`, *optional*, defaults to 128): + Dimension of each attention head. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use + Multi Head Attention (MHA), if `num_key_value_heads=1` the model + will use Multi Query Attention (MQA) otherwise GQA is used. + mlp_hidden_act (`str`, *optional*, defaults to "relu2"): + The non-linear activation function in the MLP layers. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in attention layers. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in MLP layers. + use_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the model. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + residual_in_fp32 (`bool`, *optional*, defaults to `False`): + Whether or not residuals should be in `float32`. If set to `False` + residuals will keep the same `dtype` as the rest of the model. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). Only relevant if + `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, + all logits will be calculated. If an integer value, only last + `num_logits_to_keep` logits will be calculated. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + sliding_window (`int`, *optional*, defaults to None): + Sliding window attention window size. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used + with. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the hidden states. 
+ use_mamba_kernels (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use the fast mamba kernels. + These are available only if `mamba-ssm` and `causal-conv1d` + are installed, and the mamba modules are running on a CUDA device. + ssm_state_size (`int`, *optional*, defaults to 128): + The dimension of the mamba state space latents. + mamba_num_heads (`int`, *optional*, defaults to 128): + Number of heads in Mamba layers. + mamba_n_groups (`int`, *optional*, defaults to 8): + Number of groups in Mamba layers. + mamba_head_dim (`int`, *optional*, defaults to 64): + Dimension of each Mamba head. + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel. + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor used to determine the mamba intermediate size. + mamba_hidden_act (`str`, *optional*, defaults to "silu"): + The non-linear activation function in the Mamba layers. + mamba_dt_min (`float`, *optional*, defaults to 0.001): + Minimum value for the time step in Mamba. + mamba_dt_max (`float`, *optional*, defaults to 0.1): + Maximum value for the time step in Mamba. + mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))): + Limits for the time step in Mamba. + mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4): + Floor value for time step initialization in Mamba. + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the convolution layer of the mamba mixer + block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the input and output projections of the + mamba mixer block. + mamba_chunk_size (`int`, *optional*, defaults to 256): + Size of chunks for Mamba processing. + rescale_prenorm_residual (`bool`, *optional*, defaults to `True`): + Whether to rescale the pre-normalization residual connections. 
+ """ + + model_type = "nemotron_h" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=131072, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=21504, + num_hidden_layers=52, + hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-", + num_attention_heads=32, + attention_head_dim=128, + num_key_value_heads=8, # nemo: num_query_groups + mlp_hidden_act="relu2", + attention_bias=False, + mlp_bias=False, + use_bias=False, + initializer_range=0.02, # nemo: init_method_std + layer_norm_epsilon=1e-5, # nemo: layernorm_epsilon + residual_in_fp32=False, # Megatron Core default value + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + sliding_window=None, + max_position_embeddings=4096, + attention_dropout=0.0, + hidden_dropout=0.0, # * ADDED + use_mamba_kernels=True, + ssm_state_size=128, # mamba_state_size + mamba_num_heads=128, + mamba_n_groups=8, # nemo: mamba_ssm_ngroups = num_heads + mamba_head_dim=64, + mamba_d_conv=4, + mamba_expand=2, + mamba_hidden_act="silu", + mamba_dt_min=0.001, + mamba_dt_max=0.1, + mamba_dt_limit=(0.0, float("inf")), + mamba_dt_init_floor=1e-4, + mamba_conv_bias=True, + mamba_proj_bias=False, + mamba_chunk_size=256, + rescale_prenorm_residual=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.hybrid_override_pattern = hybrid_override_pattern + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.sliding_window = sliding_window + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + self.hidden_dropout = hidden_dropout + + # Validate hybrid_override_pattern + # M: Mamba2, *: Attention, -: MLP + assert len(self.hybrid_override_pattern) == self.num_hidden_layers, ( + "hybrid_override_pattern must have same length as " + "num_hidden_layers") + assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), ( + "hybrid_override_pattern must only contain characters " + "'M', '*', or '-'") + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.mlp_hidden_act = mlp_hidden_act + self.attention_bias = attention_bias + self.mlp_bias = mlp_bias + self.use_bias = use_bias + self.initializer_range = initializer_range + self.layer_norm_epsilon = layer_norm_epsilon + self.residual_in_fp32 = residual_in_fp32 + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.use_mamba_kernels = use_mamba_kernels + self.n_groups = mamba_n_groups + self.mamba_head_dim = mamba_head_dim + self.ssm_state_size = ssm_state_size + self.mamba_num_heads = mamba_num_heads + self.conv_kernel = mamba_d_conv + self.expand = mamba_expand + self.mamba_hidden_act = mamba_hidden_act + self.time_step_min = mamba_dt_min + self.time_step_max = mamba_dt_max + self.time_step_limit = mamba_dt_limit + self.time_step_floor = mamba_dt_init_floor + self.use_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + self.chunk_size = mamba_chunk_size + self.rescale_prenorm_residual = rescale_prenorm_residual + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + 
+ @property + def layers_block_type(self): + return [ + "mamba" if self.hybrid_override_pattern[i] == "M" else + "attention" if self.hybrid_override_pattern[i] == "*" else "mlp" + for i in range(self.num_hidden_layers) + ]
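
Usage note (outside the diff): a minimal sketch of exercising the new `NemotronHForCausalLM` path once this patch is applied to a vLLM build. The checkpoint name and `trust_remote_code=True` mirror the entries added to `docs/models/supported_models.md` and `tests/models/registry.py` above; the prompt and sampling settings are placeholders.

    # Quick smoke test for NemotronH support (assumes vLLM includes this patch).
    from vllm import LLM, SamplingParams

    llm = LLM(model="nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True)
    params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(
        ["Nemotron-H is a hybrid Mamba-Transformer model that"], params)
    print(outputs[0].outputs[0].text)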