Add latent MoE support (#30203)

Signed-off-by: Shahar Mor <smor@nvidia.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
commit fcd5306f65 (parent 398a596ed2)
Authored by shaharmor98 on 2025-12-08 19:35:01 +02:00; committed by GitHub
2 changed files with 58 additions and 6 deletions
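
For context on the change itself: when moe_latent_size is set on the config, the MoE block runs its routed experts in a lower-dimensional latent space. Hidden states are down-projected from hidden_size to moe_latent_size (fc1_latent_proj) before the fused experts, the experts operate at the latent width, and the result is projected back up afterwards (fc2_latent_proj); the shared experts still consume the full-width input, which is why they are detached from SharedFusedMoE in this mode. Below is a minimal, single-process sketch of that dataflow. It is not vLLM code: the class name and sizes are made up, plain nn.Linear stands in for ColumnParallelLinear/ReplicatedLinear, a naive loop stands in for the fused-MoE kernel, and top-k weights are not renormalized.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyLatentMoE(nn.Module):
    """Toy latent-MoE block: routed-expert math happens at moe_latent_size."""

    def __init__(self, hidden_size=64, moe_latent_size=16, n_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, n_experts, bias=False)       # router sees the full hidden dim
        self.fc1_latent_proj = nn.Linear(hidden_size, moe_latent_size)  # down-project before the experts
        self.experts = nn.ModuleList(
            nn.Linear(moe_latent_size, moe_latent_size) for _ in range(n_experts)
        )
        self.fc2_latent_proj = nn.Linear(moe_latent_size, hidden_size)  # up-project after the experts
        self.shared_expert = nn.Linear(hidden_size, hidden_size)        # shared path stays full-width

    def forward(self, x):                                    # x: [num_tokens, hidden_size]
        weights = F.softmax(self.gate(x.float()), dim=-1)
        topk_w, topk_idx = weights.topk(self.top_k, dim=-1)
        shared_out = self.shared_expert(x)                   # computed on the unprojected input
        z = self.fc1_latent_proj(x)                          # [num_tokens, moe_latent_size]
        routed = torch.zeros_like(z)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = topk_idx[:, k] == e
                if mask.any():
                    routed[mask] += topk_w[mask, k, None].to(z.dtype) * expert(z[mask])
        return self.fc2_latent_proj(routed) + shared_out     # back to [num_tokens, hidden_size]


out = ToyLatentMoE()(torch.randn(8, 64))
print(out.shape)  # torch.Size([8, 64])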


@@ -83,6 +83,7 @@ class NemotronHMLP(nn.Module):
     def __init__(
         self,
         config: NemotronHConfig,
+        hidden_size: int,
         intermediate_size: int,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
@@ -93,7 +94,7 @@ class NemotronHMLP(nn.Module):
         super().__init__()
         self.up_proj = ColumnParallelLinear(
-            input_size=config.hidden_size,
+            input_size=hidden_size,
             output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
@@ -102,7 +103,7 @@ class NemotronHMLP(nn.Module):
         )
         self.down_proj = RowParallelLinear(
             input_size=intermediate_size,
-            output_size=config.hidden_size,
+            output_size=hidden_size,
             bias=bias,
             quant_config=quant_config,
             reduce_results=reduce_results,
@@ -135,6 +136,10 @@ class NemotronHMoE(nn.Module):
         self.ep_size = self.ep_group.size()
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts
+        self.use_latent_moe: bool = getattr(config, "moe_latent_size", None) is not None
+        self.moe_hidden_size: int = (
+            config.moe_latent_size if self.use_latent_moe else config.hidden_size
+        )
         self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
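
The getattr(config, "moe_latent_size", None) guard keeps older NemotronH configs working: a config object that predates the new field falls back to running the MoE at the full hidden size. A small illustration of the fallback, using SimpleNamespace as a stand-in for the real config (field values are arbitrary):

from types import SimpleNamespace

# Stand-in configs; real NemotronHConfig objects carry many more fields.
old_cfg = SimpleNamespace(hidden_size=4096)                        # no moe_latent_size attribute
new_cfg = SimpleNamespace(hidden_size=4096, moe_latent_size=1024)

for cfg in (old_cfg, new_cfg):
    use_latent_moe = getattr(cfg, "moe_latent_size", None) is not None
    moe_hidden_size = cfg.moe_latent_size if use_latent_moe else cfg.hidden_size
    print(use_latent_moe, moe_hidden_size)
# prints: False 4096
#         True 1024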
@@ -172,6 +177,7 @@ class NemotronHMoE(nn.Module):
         self.shared_experts = NemotronHMLP(
             config=config,
+            hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             quant_config=quant_config,
             reduce_results=False,
@@ -180,10 +186,12 @@ class NemotronHMoE(nn.Module):
         )

         self.experts = SharedFusedMoE(
-            shared_experts=self.shared_experts,
+            # TODO: make it possible for shared experts to have
+            # different input in SharedFusedMoE
+            shared_experts=self.shared_experts if not self.use_latent_moe else None,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
+            hidden_size=self.moe_hidden_size,
             intermediate_size=config.moe_intermediate_size,
             reduce_results=False,
             renormalize=config.norm_topk_prob,
@@ -201,6 +209,32 @@ class NemotronHMoE(nn.Module):
             is_sequence_parallel=self.is_sequence_parallel,
         )
+
+        if self.use_latent_moe:
+            # TODO: check if using ReplicatedLinear is better than
+            # ColumnParallelLinear + all_gather
+            self.fc1_latent_proj = ColumnParallelLinear(
+                input_size=config.hidden_size,
+                output_size=self.moe_hidden_size,
+                bias=config.mlp_bias,
+                quant_config=quant_config,
+                disable_tp=self.is_sequence_parallel,
+                # We need to gather the output to prepare input for moe
+                gather_output=True,
+                prefix=f"{prefix}.fc1_latent_proj",
+            )
+            self.fc2_latent_proj = ReplicatedLinear(
+                input_size=self.moe_hidden_size,
+                output_size=config.hidden_size,
+                bias=config.mlp_bias,
+                quant_config=quant_config,
+                disable_tp=self.is_sequence_parallel,
+                prefix=f"{prefix}.fc2_latent_proj",
+            )
+        else:
+            self.fc1_latent_proj = None
+            self.fc2_latent_proj = None

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
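
On the TODO above (ColumnParallelLinear vs ReplicatedLinear): a column-parallel layer shards its output dimension across tensor-parallel ranks, so without gather_output=True each rank would hold only a slice of the latent vector, while the fused MoE expects the full [num_tokens, moe_latent_size] tensor. A single-process sketch of what the gather corresponds to (shapes, tp, and seed are arbitrary):

import torch

torch.manual_seed(0)
num_tokens, hidden_size, moe_latent_size, tp = 4, 32, 8, 2

x = torch.randn(num_tokens, hidden_size)
w = torch.randn(moe_latent_size, hidden_size)         # full fc1_latent_proj weight

# Column parallelism: each of the tp ranks owns moe_latent_size // tp output rows.
shards = w.chunk(tp, dim=0)
per_rank = [x @ shard.t() for shard in shards]         # each [num_tokens, moe_latent_size // tp]

# gather_output=True corresponds to concatenating the per-rank slices back together,
# producing the full latent tensor the fused-MoE kernel consumes.
gathered = torch.cat(per_rank, dim=-1)
assert torch.allclose(gathered, x @ w.t(), atol=1e-5)

A ReplicatedLinear would avoid the gather entirely, at the cost of every rank storing and computing the full projection, which is the trade-off the TODO leaves open.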
@@ -210,12 +244,20 @@ class NemotronHMoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))

+        shared_output = None
+        if self.use_latent_moe:
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)
+            hidden_states, _ = self.fc1_latent_proj(hidden_states)
+
         fused_moe_out = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )

-        shared_output, final_hidden_states = fused_moe_out
+        if self.use_latent_moe:
+            _, final_hidden_states = fused_moe_out
+        else:
+            shared_output, final_hidden_states = fused_moe_out

         # Fix FP16 overflow
         # See DeepseekV2DecoderLayer for more details.
@@ -225,6 +267,13 @@ class NemotronHMoE(nn.Module):
             assert shared_output is not None
             shared_output *= 1.0 / self.routed_scaling_factor

+        # TODO: currently latent up_proj is done before all-reduce for simplicity.
+        # if and when shared experts will be part of SharedFusedMoE,
+        # we should do the up_proj after all-reduce,
+        # to have the all-reduce in the smaller latent dimension.
+        if self.use_latent_moe:
+            final_hidden_states, _ = self.fc2_latent_proj(final_hidden_states)
+
         if self.shared_experts is not None:
             assert shared_output is not None
             final_hidden_states += shared_output
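
The TODO in this hunk is about communication volume: the tensor-parallel all-reduce scales with the width of the tensor being reduced, so reducing in the latent dimension and up-projecting afterwards would shrink the payload by roughly hidden_size / moe_latent_size. A back-of-the-envelope illustration with made-up sizes and fp16 elements:

num_tokens, hidden_size, moe_latent_size, bytes_per_elem = 8192, 4096, 1024, 2

reduce_in_hidden = num_tokens * hidden_size * bytes_per_elem      # current ordering
reduce_in_latent = num_tokens * moe_latent_size * bytes_per_elem  # ordering the TODO suggests
print(reduce_in_hidden // 2**20, "MiB vs", reduce_in_latent // 2**20, "MiB")
# 64 MiB vs 16 MiB per all-reduce payload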
@@ -268,6 +317,7 @@ class NemotronHMLPDecoderLayer(nn.Module):
         self.mixer = NemotronHMLP(
             config,
+            hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             quant_config=quant_config,
             bias=config.mlp_bias,
@@ -846,5 +896,5 @@ class NemotronHForCausalLM(
         return logits

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        loader = AutoWeightsLoader(self, skip_prefixes=["mtp"])
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
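
The skip_prefixes=["mtp"] addition tells AutoWeightsLoader to ignore checkpoint tensors whose names start with mtp (presumably a multi-token-prediction head that this implementation does not instantiate), so checkpoints that ship such tensors can still be loaded. In effect it amounts to a name filter along these lines (illustrative only, not the loader's actual code; the weight names are made up):

# Illustrative name filter only; AutoWeightsLoader also handles mapping, sharding, etc.
skip_prefixes = ["mtp"]

def keep(name: str) -> bool:
    return not any(name.startswith(p) for p in skip_prefixes)

names = ["backbone.layers.0.mixer.up_proj.weight", "mtp.norm.weight"]
print([n for n in names if keep(n)])  # ['backbone.layers.0.mixer.up_proj.weight']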


@@ -189,6 +189,7 @@ class NemotronHConfig(PretrainedConfig):
         n_shared_experts=1,
         moe_intermediate_size=7688,
         moe_shared_expert_intermediate_size=7688,
+        moe_latent_size=None,
         num_experts_per_tok=2,
         routed_scaling_factor=1.0,
         n_group=1,
@@ -254,6 +255,7 @@ class NemotronHConfig(PretrainedConfig):
         self.n_shared_experts = n_shared_experts
         self.moe_intermediate_size = moe_intermediate_size
         self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size  # noqa: E501
+        self.moe_latent_size = moe_latent_size
         self.num_experts_per_tok = num_experts_per_tok
         self.routed_scaling_factor = routed_scaling_factor
         self.n_group = n_group
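
Since moe_latent_size defaults to None, latent MoE is strictly opt-in and existing configs are unaffected. A hedged usage sketch; the import path is assumed from this repo's layout and 2048 is an arbitrary value, not a recommendation:

# In a checkpoint this would normally come from config.json, e.g.
#     "moe_latent_size": 2048
# Programmatically (import path assumed):
from vllm.transformers_utils.configs import NemotronHConfig

cfg = NemotronHConfig(moe_latent_size=2048)  # other fields keep their defaults
print(cfg.moe_latent_size)                                 # 2048
print(getattr(cfg, "moe_latent_size", None) is not None)   # True -> latent MoE path enabled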