[Feature]: Allow for Granite MoE Hybrid models with _only_ shared experts. (#19652)

Signed-off-by: Shawn Tan <shawntan@ibm.com>
Author: Shawn Tan <shawntan@ibm.com>, 2025-06-16 07:14:18 -04:00 (committed by GitHub)
parent 3e7506975c
commit 4d5424029b


@@ -67,6 +67,8 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
             activation=config.hidden_act,
             quant_config=quant_config)
 
+        self.block_sparse_moe = None
+        if getattr(config, "num_local_experts", 0) > 0:
             self.block_sparse_moe = GraniteMoeMoE(
                 num_experts=config.num_local_experts,
                 top_k=config.num_experts_per_tok,
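
The constructor now builds the routed-expert block only when the config actually declares routed experts. A minimal standalone sketch of that guard, using `SimpleNamespace` stand-ins rather than real Granite/vLLM config objects:

```python
from types import SimpleNamespace

# Stand-in configs (not real vLLM/HF configs) showing when the guard fires.
routed_cfg = SimpleNamespace(num_local_experts=8)       # routed experts present
shared_only_cfg = SimpleNamespace(num_local_experts=0)  # only shared experts
legacy_cfg = SimpleNamespace()                          # attribute missing entirely

for name, cfg in [("routed", routed_cfg), ("shared-only", shared_only_cfg),
                  ("legacy", legacy_cfg)]:
    build_moe = getattr(cfg, "num_local_experts", 0) > 0
    print(f"{name}: build block_sparse_moe = {build_moe}")
# routed: build block_sparse_moe = True
# shared-only: build block_sparse_moe = False
# legacy: build block_sparse_moe = False
```

In the two latter cases `self.block_sparse_moe` stays `None`, which the forward pass below has to handle.
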
@@ -105,13 +107,19 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         if self.shared_mlp is None:
+            if self.block_sparse_moe is not None:
                 hidden_states = self.block_sparse_moe(hidden_states)
+            # else: skip
         else:
             # create a copy since block_sparse_moe modifies in-place
+            if self.block_sparse_moe is not None:
                 moe_hidden_states = hidden_states.clone()
                 moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
-            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+                hidden_states = moe_hidden_states + self.shared_mlp(
+                    hidden_states)
                 del moe_hidden_states
+            else:
+                hidden_states = self.shared_mlp(hidden_states)
 
         hidden_states = residual + hidden_states * self.residual_multiplier
         return hidden_states, residual
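
The hunk above covers four combinations of routed and shared experts. A standalone sketch of that dispatch, with `ffn_dispatch` and its arguments as illustrative stand-ins (not vLLM APIs); the real layer additionally applies the post-attention layernorm and `residual_multiplier` shown in the diff:

```python
from typing import Callable, Optional

import torch

def ffn_dispatch(hidden_states: torch.Tensor,
                 block_sparse_moe: Optional[Callable] = None,
                 shared_mlp: Optional[Callable] = None) -> torch.Tensor:
    if shared_mlp is None:
        # No shared experts: use the routed MoE alone, or pass through.
        if block_sparse_moe is not None:
            hidden_states = block_sparse_moe(hidden_states)
        # else: neither block exists, hidden states flow through unchanged
    else:
        if block_sparse_moe is not None:
            # Clone first because the MoE block modifies its input in-place.
            moe_hidden_states = block_sparse_moe(hidden_states.clone())
            hidden_states = moe_hidden_states + shared_mlp(hidden_states)
        else:
            # Shared-experts-only model: the case this commit enables.
            hidden_states = shared_mlp(hidden_states)
    return hidden_states

# Toy usage: shared-experts-only path with a plain Linear as the stand-in MLP.
x = torch.randn(2, 4, 8)
out = ffn_dispatch(x, block_sparse_moe=None, shared_mlp=torch.nn.Linear(8, 8))
print(out.shape)  # torch.Size([2, 4, 8])
```
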
@@ -137,6 +145,8 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn")
 
+        self.block_sparse_moe = None
+        if getattr(config, "num_local_experts", 0) > 0:
             self.block_sparse_moe = GraniteMoeMoE(
                 num_experts=config.num_local_experts,
                 top_k=config.num_experts_per_tok,
@@ -178,13 +188,19 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         if self.shared_mlp is None:
+            if self.block_sparse_moe is not None:
                 hidden_states = self.block_sparse_moe(hidden_states)
+            # else: skip
         else:
             # create a copy since block_sparse_moe modifies in-place
+            if self.block_sparse_moe is not None:
                 moe_hidden_states = hidden_states.clone()
                 moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
-            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+                hidden_states = moe_hidden_states + self.shared_mlp(
+                    hidden_states)
                 del moe_hidden_states
+            else:
+                hidden_states = self.shared_mlp(hidden_states)
 
         hidden_states = residual + hidden_states * self.residual_multiplier
         return hidden_states, residual
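
Both layers keep the existing comment "create a copy since block_sparse_moe modifies in-place" around the `.clone()`. A toy illustration of why that clone matters; `inplace_moe` is a stand-in, not the real GraniteMoeMoE, whose in-place behaviour is asserted only by that comment:

```python
import torch

def inplace_moe(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a module that writes its result back into its input buffer.
    x.mul_(2.0)
    return x

hidden = torch.ones(2, 3)

# Clone before the MoE call so `hidden` keeps the pre-MoE activations that the
# shared MLP (here just an addition) is supposed to see.
moe_out = inplace_moe(hidden.clone())
combined = moe_out + hidden
print(combined)  # tensor of 3.0 everywhere (2.0 + 1.0), as intended
```

Without the clone, `hidden` would already hold the MoE output when it is reused, and the two branches would silently double-count the routed-expert contribution.
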