diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index baeb901bbb05a..2d9dfbd3e7688 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -83,6 +83,7 @@ class NemotronHMLP(nn.Module):
     def __init__(
         self,
         config: NemotronHConfig,
+        hidden_size: int,
         intermediate_size: int,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
@@ -93,7 +94,7 @@ class NemotronHMLP(nn.Module):
         super().__init__()
 
         self.up_proj = ColumnParallelLinear(
-            input_size=config.hidden_size,
+            input_size=hidden_size,
             output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
@@ -102,7 +103,7 @@ class NemotronHMLP(nn.Module):
         )
         self.down_proj = RowParallelLinear(
             input_size=intermediate_size,
-            output_size=config.hidden_size,
+            output_size=hidden_size,
             bias=bias,
             quant_config=quant_config,
             reduce_results=reduce_results,
@@ -135,6 +136,10 @@ class NemotronHMoE(nn.Module):
         self.ep_size = self.ep_group.size()
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts
+        self.use_latent_moe: bool = getattr(config, "moe_latent_size", None) is not None
+        self.moe_hidden_size: int = (
+            config.moe_latent_size if self.use_latent_moe else config.hidden_size
+        )
 
         self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
 
@@ -172,6 +177,7 @@ class NemotronHMoE(nn.Module):
 
             self.shared_experts = NemotronHMLP(
                 config=config,
+                hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
                 quant_config=quant_config,
                 reduce_results=False,
@@ -180,10 +186,12 @@ class NemotronHMoE(nn.Module):
             )
 
         self.experts = SharedFusedMoE(
-            shared_experts=self.shared_experts,
+            # TODO: make it possible for shared experts to have
+            # different input in SharedFusedMoE
+            shared_experts=self.shared_experts if not self.use_latent_moe else None,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
+            hidden_size=self.moe_hidden_size,
             intermediate_size=config.moe_intermediate_size,
             reduce_results=False,
             renormalize=config.norm_topk_prob,
@@ -201,6 +209,32 @@ class NemotronHMoE(nn.Module):
             is_sequence_parallel=self.is_sequence_parallel,
         )
 
+        if self.use_latent_moe:
+            # TODO: check if using ReplicatedLinear is better than
+            # ColumnParallelLinear + all_gather
+            self.fc1_latent_proj = ColumnParallelLinear(
+                input_size=config.hidden_size,
+                output_size=self.moe_hidden_size,
+                bias=config.mlp_bias,
+                quant_config=quant_config,
+                disable_tp=self.is_sequence_parallel,
+                # We need to gather the output to prepare input for moe
+                gather_output=True,
+                prefix=f"{prefix}.fc1_latent_proj",
+            )
+            self.fc2_latent_proj = ReplicatedLinear(
+                input_size=self.moe_hidden_size,
+                output_size=config.hidden_size,
+                bias=config.mlp_bias,
+                quant_config=quant_config,
+                disable_tp=self.is_sequence_parallel,
+                prefix=f"{prefix}.fc2_latent_proj",
+            )
+
+        else:
+            self.fc1_latent_proj = None
+            self.fc2_latent_proj = None
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -210,12 +244,20 @@ class NemotronHMoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
 
+        shared_output = None
+        if self.use_latent_moe:
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)
+            hidden_states, _ = self.fc1_latent_proj(hidden_states)
+
         fused_moe_out = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
 
-        shared_output, final_hidden_states = fused_moe_out
+        if self.use_latent_moe:
+            _, final_hidden_states = fused_moe_out
+        else:
+            shared_output, final_hidden_states = fused_moe_out
 
         # Fix FP16 overflow
         # See DeepseekV2DecoderLayer for more details.
@@ -225,6 +267,13 @@ class NemotronHMoE(nn.Module):
             assert shared_output is not None
             shared_output *= 1.0 / self.routed_scaling_factor
 
+        # TODO: currently latent up_proj is done before all-reduce for simplicity.
+        # if and when shared experts will be part of SharedFusedMoE,
+        # we should do the up_proj after all-reduce,
+        # to have the all-reduce in the smaller latent dimension.
+        if self.use_latent_moe:
+            final_hidden_states, _ = self.fc2_latent_proj(final_hidden_states)
+
         if self.shared_experts is not None:
             assert shared_output is not None
             final_hidden_states += shared_output
@@ -268,6 +317,7 @@ class NemotronHMLPDecoderLayer(nn.Module):
 
         self.mixer = NemotronHMLP(
             config,
+            hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             quant_config=quant_config,
             bias=config.mlp_bias,
@@ -846,5 +896,5 @@ class NemotronHForCausalLM(
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        loader = AutoWeightsLoader(self, skip_prefixes=["mtp"])
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py
index 68c40002098c8..86c117fd9d59f 100644
--- a/vllm/transformers_utils/configs/nemotron_h.py
+++ b/vllm/transformers_utils/configs/nemotron_h.py
@@ -189,6 +189,7 @@ class NemotronHConfig(PretrainedConfig):
         n_shared_experts=1,
         moe_intermediate_size=7688,
         moe_shared_expert_intermediate_size=7688,
+        moe_latent_size=None,
         num_experts_per_tok=2,
         routed_scaling_factor=1.0,
         n_group=1,
@@ -254,6 +255,7 @@ class NemotronHConfig(PretrainedConfig):
         self.n_shared_experts = n_shared_experts
         self.moe_intermediate_size = moe_intermediate_size
         self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size  # noqa: E501
+        self.moe_latent_size = moe_latent_size
         self.num_experts_per_tok = num_experts_per_tok
         self.routed_scaling_factor = routed_scaling_factor
         self.n_group = n_group
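
For orientation, here is a minimal sketch of the latent-MoE data flow the diff enables when `moe_latent_size` is set: the router gate and the shared experts keep operating at `config.hidden_size`, while the routed experts run at the smaller latent width between a down-projection (`fc1_latent_proj`) and an up-projection (`fc2_latent_proj`). This is not the vLLM implementation: plain `nn.Linear` modules stand in for `ColumnParallelLinear`/`ReplicatedLinear`, and `routed_experts`/`shared_experts` are hypothetical stand-ins for `SharedFusedMoE`/`NemotronHMLP`.

```python
import torch
import torch.nn as nn


class LatentMoESketch(nn.Module):
    """Illustrative sketch of the latent-MoE path above (not vLLM code)."""

    def __init__(
        self,
        hidden_size: int,
        moe_latent_size: int,
        n_routed_experts: int,
        routed_experts: nn.Module,  # stand-in for SharedFusedMoE, runs at latent width
        shared_experts: nn.Module,  # stand-in for NemotronHMLP, runs at full width
    ):
        super().__init__()
        # The router gate stays at the full hidden size, as in the diff.
        self.gate = nn.Linear(hidden_size, n_routed_experts, bias=False)
        # Down-projection into the latent space (fc1_latent_proj in the diff)
        # and the matching up-projection back (fc2_latent_proj).
        self.fc1_latent_proj = nn.Linear(hidden_size, moe_latent_size)
        self.fc2_latent_proj = nn.Linear(moe_latent_size, hidden_size)
        self.routed_experts = routed_experts
        self.shared_experts = shared_experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Router logits and the shared-expert branch see full-width activations.
        router_logits = self.gate(hidden_states.float())
        shared_output = self.shared_experts(hidden_states)
        # Routed experts run in the smaller latent dimension.
        latent = self.fc1_latent_proj(hidden_states)
        routed_output = self.routed_experts(latent, router_logits)
        # Project back to the hidden size and add the shared-expert branch.
        return self.fc2_latent_proj(routed_output) + shared_output
```

As the TODO in the diff notes, the up-projection currently happens before the all-reduce; once shared experts can take a different input inside `SharedFusedMoE`, the up-projection could move after it so the all-reduce runs in the smaller latent dimension.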