[Model] Add MoE support for NemotronH (#25863)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
tomeras91 2025-10-23 13:27:23 +03:00 committed by GitHub
parent 88afa11010
commit 61089465a6
7 changed files with 413 additions and 39 deletions

View File

@ -823,6 +823,8 @@ class FusedMoEConfig:
has_bias: bool = False
is_act_and_mul: bool = True
def __post_init__(self):
if self.dp_size > 1:
logger.debug_once(

View File

@ -1647,6 +1647,7 @@ def fused_experts(
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")
def _get_config_quant_dtype(
@ -1914,7 +1915,8 @@ def fused_experts_impl(
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == GELU_NO_MUL:
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
elif activation == RELU2_NO_MUL:
intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N)))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")

View File

@ -411,11 +411,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
w13_up_dim,
hidden_size,
dtype=params_dtype,
),
@ -425,9 +429,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts, 2 * intermediate_size_per_partition, dtype=params_dtype
),
torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
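A short illustration of the resulting parameter shapes, using hypothetical sizes (4 experts, hidden_size 16, intermediate_size_per_partition 32):

import torch

num_experts, hidden_size, intermediate = 4, 16, 32

# Gated MoE (is_act_and_mul=True): w13 fuses gate_proj and up_proj,
# so the output dimension doubles.
w13_gated = torch.empty(num_experts, 2 * intermediate, hidden_size)

# Non-gated MoE (is_act_and_mul=False): w13 holds only up_proj.
w13_plain = torch.empty(num_experts, intermediate, hidden_size)

# w2 (down_proj) keeps the same shape in both cases.
w2 = torch.empty(num_experts, hidden_size, intermediate)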
@ -1073,6 +1075,7 @@ class FusedMoE(CustomOp):
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
@ -1263,6 +1266,7 @@ class FusedMoE(CustomOp):
in_dtype=moe_in_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
is_act_and_mul=is_act_and_mul,
)
self.moe_config = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
@ -1283,6 +1287,24 @@ class FusedMoE(CustomOp):
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if not self.moe_config.is_act_and_mul:
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8MoEMethod,
)
if not isinstance(
quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
):
raise NotImplementedError(
"is_act_and_mul=False is supported only for unquantized "
"and ModelOpt FP8 moe for now"
)
if not current_platform.is_cuda():
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA for now"
)
if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod
@ -1531,7 +1553,10 @@ class FusedMoE(CustomOp):
):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
if self.moe_config.is_act_and_mul:
shard_size = expert_data.shape[shard_dim] // 2
else:
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
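The sharding above can be pictured with a small standalone sketch; sizes below are hypothetical (per-rank intermediate 8, hidden 4, tp_size 2, rank 1):

import torch

def narrow_for_rank(loaded_weight, expert_data, tp_rank, is_act_and_mul, shard_dim=0):
    # Mirror of the logic above: the checkpoint tensor is narrowed along the
    # output dim. For gated MoE the per-rank w13 buffer stacks [w1 | w3], so a
    # single checkpoint shard fills only half of it; for non-gated MoE the
    # buffer is exactly one shard wide.
    if is_act_and_mul:
        shard_size = expert_data.shape[shard_dim] // 2
    else:
        shard_size = expert_data.shape[shard_dim]
    return loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)

full_up_proj = torch.randn(16, 4)   # unsharded checkpoint up_proj
buf_gated = torch.empty(16, 4)      # per-rank [w1 | w3] buffer (2 * 8 rows)
buf_plain = torch.empty(8, 4)       # per-rank up-only buffer (8 rows)
assert narrow_for_rank(full_up_proj, buf_gated, 1, True).shape == (8, 4)
assert narrow_for_rank(full_up_proj, buf_plain, 1, False).shape == (8, 4)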

View File

@ -354,7 +354,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
if (
envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
and self.moe.is_act_and_mul
):
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
@ -405,10 +409,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
weight_loader = extra_weight_attrs.get("weight_loader")
if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition
w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
w13_up_dim,
hidden_size,
dtype=weight_dtype,
),
@ -433,11 +442,16 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALES - Per-tensor scaling for ModelOpts
# Allocate 2 scales for w1 and w3 respectively.
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
# For non-gated MoE, allocate 1 scale for w13.
if self.moe.is_act_and_mul:
w13_weight_scale_shape = (num_experts, 2)
else:
w13_weight_scale_shape = (num_experts, 1)
w13_weight_scale = PerTensorScaleParameter(
data=torch.full(
(num_experts, 2),
w13_weight_scale_shape,
1.0,
dtype=torch.float32,
),
@ -485,7 +499,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max of the w1 and w3 scales
# then dequant and requant each expert.
if layer.w13_weight_scale.dim() == 2:
if (
layer.w13_weight_scale.dim() == 2
and layer.w13_weight_scale.shape[1] == 2
):
assert self.moe.is_act_and_mul, (
"w13_weight_scale should have 2 elements per expert "
"only for gated MoE"
)
# Get the maximum scale across w1 and w3 for each expert
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
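A simplified sketch of this dequant-and-requant step for the gated case, emulating FP8 weights as floats with hypothetical shapes (the real code operates on float8 tensors). In the non-gated case there is a single w13 scale per expert, so the combine step is skipped:

import torch

def requant_with_max_scale(w13_q, w13_scale):
    # w13_q: (num_experts, 2 * inter, hidden) quantized weights (floats here)
    # w13_scale: (num_experts, 2) per-tensor scales for the w1 / w3 halves
    num_experts, two_inter, _ = w13_q.shape
    inter = two_inter // 2
    max_scales = w13_scale.max(dim=1).values            # one scale per expert
    out = torch.empty_like(w13_q)
    for e in range(num_experts):
        for i, rows in enumerate((slice(0, inter), slice(inter, two_inter))):
            dequant = w13_q[e, rows] * w13_scale[e, i]   # back to full precision
            out[e, rows] = dequant / max_scales[e]       # requantize with the max
    return out, max_scales

w13_q = torch.randn(2, 8, 4)
w13_scale = torch.tensor([[0.5, 1.0], [2.0, 0.25]])
_, scales = requant_with_max_scale(w13_q, w13_scale)
assert scales.tolist() == [1.0, 2.0]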

View File

@ -673,7 +673,9 @@ class MixtureOfExperts(Protocol):
def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
return isinstance(model, MixtureOfExperts)
return (
isinstance(model, MixtureOfExperts) and getattr(model, "num_moe_layers", 0) > 0
)
@runtime_checkable

View File

@ -18,7 +18,8 @@
# limitations under the License.
"""Inference-only NemotronH model."""
from collections.abc import Iterable
import typing
from collections.abc import Callable, Iterable
import torch
from torch import nn
@ -26,13 +27,18 @@ from torch import nn
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.config.parallel import ParallelConfig
from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -54,6 +60,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.interfaces import (
HasInnerState,
IsHybrid,
MixtureOfExperts,
SupportsLoRA,
SupportsPP,
SupportsQuant,
@ -61,9 +68,11 @@ from vllm.model_executor.models.interfaces import (
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
WeightsMapper,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
sequence_parallel_chunk,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig
@ -73,28 +82,21 @@ class NemotronHMLP(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
intermediate_size: int,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
reduce_results: bool = True,
is_sequence_parallel: bool = False,
prefix: str = "",
) -> None:
super().__init__()
hybrid_override_pattern = config.hybrid_override_pattern
mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
if isinstance(config.intermediate_size, list):
if len(config.intermediate_size) == 1:
intermediate_size = config.intermediate_size[0]
else:
intermediate_size = config.intermediate_size[mlp_index]
else:
intermediate_size = config.intermediate_size
self.up_proj = ColumnParallelLinear(
input_size=config.hidden_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.up_proj",
)
self.down_proj = RowParallelLinear(
@ -102,6 +104,8 @@ class NemotronHMLP(nn.Module):
output_size=config.hidden_size,
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj",
)
self.act_fn = ReLUSquaredActivation()
@ -113,6 +117,130 @@ class NemotronHMLP(nn.Module):
return x
class NemotronHMoE(nn.Module):
def __init__(
self,
config: NemotronHConfig,
quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
self.gate = ReplicatedLinear(
config.hidden_size,
config.n_routed_experts,
bias=False,
params_dtype=torch.float32,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts, dtype=torch.float32)
)
# Load balancing settings.
self.enable_eplb = parallel_config.enable_eplb
self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts # noqa: E501
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
self.physical_expert_end = (
self.physical_expert_start + self.n_local_physical_experts
)
if config.n_shared_experts is None or config.n_shared_experts == 0:
self.shared_experts = None
else:
intermediate_size = (
config.moe_shared_expert_intermediate_size * config.n_shared_experts
)
self.shared_experts = NemotronHMLP(
config=config,
intermediate_size=intermediate_size,
quant_config=quant_config,
reduce_results=False,
is_sequence_parallel=self.is_sequence_parallel,
prefix=f"{prefix}.shared_experts",
)
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func="sigmoid",
e_score_correction_bias=self.gate.e_score_correction_bias,
activation=activation_without_mul(config.mlp_hidden_act),
is_act_and_mul=False, # non-gated MoE
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
shared_output, final_hidden_states = fused_moe_out
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim)
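A dense, single-device reference of what the routed path above computes per token, as a sketch under simplifying assumptions (no grouped top-k, no EP/TP sharding, no FP16 overflow handling): sigmoid scoring, an additive e_score_correction_bias used only for expert selection, renormalized top-k weights, non-gated relu2 experts, a routed_scaling_factor, and an optional shared expert.

import torch
import torch.nn.functional as F

def nemotron_h_moe_reference(x, gate_w, correction_bias, up_w, down_w,
                             top_k, routed_scaling_factor, shared_mlp=None):
    # x: (T, H); gate_w: (E, H); correction_bias: (E,)
    # up_w: (E, I, H); down_w: (E, H, I)
    scores = torch.sigmoid(x.float() @ gate_w.float().t())         # (T, E)
    topk_ids = (scores + correction_bias).topk(top_k, dim=-1).indices
    topk_scores = scores.gather(-1, topk_ids)
    weights = topk_scores / topk_scores.sum(dim=-1, keepdim=True)  # norm_topk_prob
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for w, e in zip(weights[t], topk_ids[t]):
            h = torch.square(F.relu(x[t] @ up_w[e].t()))            # relu2, no gate
            out[t] += w.to(x.dtype) * (h @ down_w[e].t())
    out = out * routed_scaling_factor
    if shared_mlp is not None:
        out = out + shared_mlp(x)                                   # shared expert, unscaled
    return out

T, H, I, E = 3, 16, 32, 8
ref = nemotron_h_moe_reference(
    torch.randn(T, H), torch.randn(E, H), torch.zeros(E),
    torch.randn(E, I, H), torch.randn(E, H, I),
    top_k=2, routed_scaling_factor=1.0,
)
assert ref.shape == (T, H)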
class NemotronHMLPDecoderLayer(nn.Module):
def __init__(
self,
@ -121,20 +249,70 @@ class NemotronHMLPDecoderLayer(nn.Module):
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
hybrid_override_pattern = config.hybrid_override_pattern
mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
if isinstance(config.intermediate_size, list):
if len(config.intermediate_size) == 1:
intermediate_size = config.intermediate_size[0]
else:
intermediate_size = config.intermediate_size[mlp_index]
else:
intermediate_size = config.intermediate_size
self.mixer = NemotronHMLP(
config,
intermediate_size=intermediate_size,
quant_config=quant_config,
bias=config.mlp_bias,
prefix=f"{prefix}.mixer",
layer_idx=layer_idx,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states)
return hidden_states, residual
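How the per-layer MLP width is looked up can be seen with a small example; the pattern below is hypothetical ("M" Mamba, "*" attention, "-" dense MLP, "E" MoE), and mlp_index counts the dense-MLP layers up to and including the current one:

hybrid_override_pattern = "M*-M-E-"
intermediate_size_list = [1024, 2048, 4096]

def mlp_intermediate_size(layer_idx: int) -> int:
    mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
    if len(intermediate_size_list) == 1:
        return intermediate_size_list[0]
    return intermediate_size_list[mlp_index]

assert mlp_intermediate_size(2) == 1024   # first "-" layer
assert mlp_intermediate_size(4) == 2048   # second "-" layer
assert mlp_intermediate_size(6) == 4096   # third "-" layer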
class NemotronHMoEDecoderLayer(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.mixer = NemotronHMoE(
config,
quant_config=quant_config,
parallel_config=parallel_config,
prefix=f"{prefix}.mixer",
)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
@ -160,6 +338,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
@ -174,7 +353,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
n_groups=config.n_groups,
num_heads=config.mamba_num_heads,
head_dim=config.mamba_head_dim,
rms_norm_eps=config.rms_norm_eps,
rms_norm_eps=config.layer_norm_epsilon,
activation=config.mamba_hidden_act,
model_config=model_config,
cache_config=cache_config,
@ -182,7 +361,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
prefix=f"{prefix}.mixer",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
@ -281,6 +460,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
@ -294,7 +474,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
prefix=f"{prefix}.mixer",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
@ -317,6 +497,7 @@ ALL_DECODER_LAYER_TYPES = {
"M": NemotronHMambaDecoderLayer,
"-": NemotronHMLPDecoderLayer,
"*": NemotronHAttentionDecoderLayer,
"E": NemotronHMoEDecoderLayer,
}
@ -329,6 +510,7 @@ class NemotronHModel(nn.Module):
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
lora_config = vllm_config.lora_config
self.config = config
@ -346,17 +528,20 @@ class NemotronHModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.has_moe = "E" in config.hybrid_override_pattern
def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1])
layer_class = ALL_DECODER_LAYER_TYPES[
config.hybrid_override_pattern[layer_idx]
]
return layer_class(
config,
layer_idx,
model_config,
cache_config,
config=config,
layer_idx=layer_idx,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
parallel_config=parallel_config,
prefix=prefix,
)
@ -367,7 +552,7 @@ class NemotronHModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@ -413,6 +598,22 @@ class NemotronHModel(nn.Module):
("qkv_proj", "v_proj", "v"),
]
if self.has_moe:
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
# - FusedMoE.w1 (aka gate_proj) should be up_proj since that's
# what the activation is applied to
# - FusedMoE.w3 (aka up_proj) should be ignored since we're
# using non-gated MoE
ckpt_gate_proj_name="up_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="",
num_experts=self.config.n_routed_experts,
num_redundant_experts=getattr(self, "num_redundant_experts", 0),
)
else:
expert_params_mapping = []
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
@ -438,16 +639,62 @@ class NemotronHModel(nn.Module):
# load other params
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
# This is an expert weight; don't fall back to loading it
# as a regular weight later
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or not
# here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
success = weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
name = name_mapped
break
else:
if is_expert_weight:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
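For orientation, the tuples produced by make_expert_params_mapping are expected to look roughly like the following for a two-expert, non-gated checkpoint; the exact prefix strings are an internal detail, and the empty ckpt_up_proj_name means no separate w3 entries are needed:

# (param_name, weight_name, expert_id, shard_id); illustrative values only.
expert_params_mapping_example = [
    # checkpoint up_proj fills the w1 slot: the slot the activation is applied to
    ("experts.w13_", "experts.0.up_proj.", 0, "w1"),
    ("experts.w13_", "experts.1.up_proj.", 1, "w1"),
    # down_proj fills w2 as usual
    ("experts.w2_", "experts.0.down_proj.", 0, "w2"),
    ("experts.w2_", "experts.1.down_proj.", 1, "w2"),
]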
class NemotronHForCausalLM(
nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
nn.Module,
HasInnerState,
SupportsLoRA,
SupportsPP,
IsHybrid,
SupportsQuant,
MixtureOfExperts,
):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"backbone": "model"},
@ -545,6 +792,61 @@ class NemotronHForCausalLM(
self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors
# Set MoE hyperparameters
if self.model.has_moe:
self.expert_weights = []
self.num_expert_groups = config.n_group
self.moe_layers: list[SharedFusedMoE] = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, NemotronHMoEDecoderLayer):
# Pick the last MoE layer, since the first layers
# may be dense.
example_moe = layer.mixer
self.moe_layers.append(layer.mixer.experts)
self.num_moe_layers = len(self.moe_layers)
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts # noqa: E501
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.model.layers:
if isinstance(layer, NemotronHMoEDecoderLayer):
moe = layer.mixer
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
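The expert-count bookkeeping used here and in NemotronHMoE.__init__ is plain arithmetic; a worked example with hypothetical numbers (8 logical experts, 2 redundant replicas, expert-parallel size 2, rank 1):

n_logical_experts = 8
n_redundant_experts = 2
ep_size, ep_rank = 2, 1

n_physical_experts = n_logical_experts + n_redundant_experts       # 10
n_local_physical_experts = n_physical_experts // ep_size           # 5 per rank
physical_expert_start = ep_rank * n_local_physical_experts         # rank 1 owns experts 5..9
physical_expert_end = physical_expert_start + n_local_physical_experts
assert (n_physical_experts, n_local_physical_experts) == (10, 5)
assert (physical_expert_start, physical_expert_end) == (5, 10)

# After EPLB rescaling, num_redundant_experts is recovered as
# num_physical_experts - num_logical_experts, as in the method above.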
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

View File

@ -185,6 +185,15 @@ class NemotronHConfig(PretrainedConfig):
mamba_proj_bias=False,
mamba_chunk_size=256,
rescale_prenorm_residual=True,
n_routed_experts=8,
n_shared_experts=1,
moe_intermediate_size=7688,
moe_shared_expert_intermediate_size=7688,
num_experts_per_tok=2,
routed_scaling_factor=1.0,
n_group=1,
topk_group=1,
norm_topk_prob=True,
**kwargs,
):
self.vocab_size = vocab_size
@ -241,6 +250,15 @@ class NemotronHConfig(PretrainedConfig):
self.mamba_proj_bias = mamba_proj_bias
self.chunk_size = mamba_chunk_size
self.rescale_prenorm_residual = rescale_prenorm_residual
self.n_routed_experts = n_routed_experts
self.n_shared_experts = n_shared_experts
self.moe_intermediate_size = moe_intermediate_size
self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size # noqa: E501
self.num_experts_per_tok = num_experts_per_tok
self.routed_scaling_factor = routed_scaling_factor
self.n_group = n_group
self.topk_group = topk_group
self.norm_topk_prob = norm_topk_prob
super().__init__(
pad_token_id=pad_token_id,
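An illustrative set of the new MoE kwargs (values hypothetical, not taken from a released checkpoint); "E" entries in hybrid_override_pattern select the new NemotronHMoEDecoderLayer:

moe_config_kwargs = dict(
    hybrid_override_pattern="M*-M-E-E*",  # "E" marks MoE layers
    n_routed_experts=8,                   # routed experts per MoE layer
    n_shared_experts=1,                   # 0/None disables the shared expert
    moe_intermediate_size=7688,           # routed expert width
    moe_shared_expert_intermediate_size=7688,
    num_experts_per_tok=2,                # top-k
    routed_scaling_factor=1.0,
    n_group=1,
    topk_group=1,
    norm_topk_prob=True,                  # renormalize top-k weights
)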
@ -258,5 +276,7 @@ class NemotronHConfig(PretrainedConfig):
else "attention"
if self.hybrid_override_pattern[i] == "*"
else "mlp"
if self.hybrid_override_pattern[i] == "-"
else "moe"
for i in range(self.num_hidden_layers)
]
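Equivalently, and assuming the elided first branch maps "M" to "mamba", the layers_block_type expression reduces to a character-to-type lookup:

_BLOCK_TYPE = {"M": "mamba", "*": "attention", "-": "mlp", "E": "moe"}
hybrid_override_pattern = "M*-E"  # hypothetical
layers_block_type = [_BLOCK_TYPE[c] for c in hybrid_override_pattern]
assert layers_block_type == ["mamba", "attention", "mlp", "moe"]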