From 4d5424029b7a664f1373fcdf26d97148ba5c3507 Mon Sep 17 00:00:00 2001
From: Shawn Tan
Date: Mon, 16 Jun 2025 07:14:18 -0400
Subject: [PATCH] [Feature]: Allow for Granite MoE Hybrid models with _only_
 shared experts. (#19652)

Signed-off-by: Shawn Tan
---
 .../model_executor/models/granitemoehybrid.py | 64 ++++++++++++-------
 1 file changed, 40 insertions(+), 24 deletions(-)

diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index f434b7a74e486..26b5b3ac15345 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -67,13 +67,15 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
                                  activation=config.hidden_act,
                                  quant_config=quant_config)
 
-        self.block_sparse_moe = GraniteMoeMoE(
-            num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            quant_config=quant_config,
-            prefix=f"{prefix}.block_sparse_moe")
+        self.block_sparse_moe = None
+        if getattr(config, "num_local_experts", 0) > 0:
+            self.block_sparse_moe = GraniteMoeMoE(
+                num_experts=config.num_local_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.block_sparse_moe")
 
         self.shared_mlp = None if \
             getattr(config, 'shared_intermediate_size', 0) == 0 \
@@ -105,13 +107,19 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         if self.shared_mlp is None:
-            hidden_states = self.block_sparse_moe(hidden_states)
+            if self.block_sparse_moe is not None:
+                hidden_states = self.block_sparse_moe(hidden_states)
+            # else: skip
         else:
             # create a copy since block_sparse_moe modifies in-place
-            moe_hidden_states = hidden_states.clone()
-            moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
-            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
-            del moe_hidden_states
+            if self.block_sparse_moe is not None:
+                moe_hidden_states = hidden_states.clone()
+                moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
+                hidden_states = moe_hidden_states + self.shared_mlp(
+                    hidden_states)
+                del moe_hidden_states
+            else:
+                hidden_states = self.shared_mlp(hidden_states)
         hidden_states = residual + hidden_states * self.residual_multiplier
 
         return hidden_states, residual
@@ -137,13 +145,15 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn")
 
-        self.block_sparse_moe = GraniteMoeMoE(
-            num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            quant_config=quant_config,
-            prefix=f"{prefix}.block_sparse_moe")
+        self.block_sparse_moe = None
+        if getattr(config, "num_local_experts", 0) > 0:
+            self.block_sparse_moe = GraniteMoeMoE(
+                num_experts=config.num_local_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.block_sparse_moe")
 
         self.shared_mlp = None if \
             getattr(config, 'shared_intermediate_size', 0) == 0 \
@@ -178,13 +188,19 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         if self.shared_mlp is None:
-            hidden_states = self.block_sparse_moe(hidden_states)
+            if self.block_sparse_moe is not None:
+                hidden_states = self.block_sparse_moe(hidden_states)
+            # else: skip
         else:
             # create a copy since block_sparse_moe modifies in-place
-            moe_hidden_states = hidden_states.clone()
-            moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
-            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
-            del moe_hidden_states
+            if self.block_sparse_moe is not None:
+                moe_hidden_states = hidden_states.clone()
+                moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
+                hidden_states = moe_hidden_states + self.shared_mlp(
+                    hidden_states)
+                del moe_hidden_states
+            else:
+                hidden_states = self.shared_mlp(hidden_states)
         hidden_states = residual + hidden_states * self.residual_multiplier
 
         return hidden_states, residual
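
Reviewer note: the sketch below restates, outside the vLLM classes, the four
feed-forward cases the patched decoder layers now handle (routed experts only,
no experts at all, routed plus shared, and the newly supported shared-only
case). The helper name ffn_dispatch is illustrative and not part of the vLLM
API; it only assumes the two optional submodules behave like the layer
attributes of the same names.

import torch
import torch.nn as nn
from typing import Optional


def ffn_dispatch(hidden_states: torch.Tensor,
                 block_sparse_moe: Optional[nn.Module],
                 shared_mlp: Optional[nn.Module]) -> torch.Tensor:
    """Mirror the dispatch logic added to both decoder layers."""
    if shared_mlp is None:
        if block_sparse_moe is not None:
            # Routed experts only: the original behaviour.
            hidden_states = block_sparse_moe(hidden_states)
        # Neither module present: hidden_states passes through unchanged.
    else:
        if block_sparse_moe is not None:
            # Routed + shared experts: clone first because the MoE module
            # modifies its input in place.
            moe_hidden_states = hidden_states.clone()
            moe_hidden_states = block_sparse_moe(moe_hidden_states)
            hidden_states = moe_hidden_states + shared_mlp(hidden_states)
        else:
            # Shared experts only: the configuration this patch enables
            # (num_local_experts == 0, shared_intermediate_size > 0).
            hidden_states = shared_mlp(hidden_states)
    return hidden_states

In the shared-only case the layer therefore reduces to a dense MLP inside the
residual branch, matching the else-branch added in both decoder layers.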