[Feature]: Allow for Granite MoE Hybrid models with _only_ shared experts. (#19652)

Signed-off-by: Shawn Tan <shawntan@ibm.com>
Shawn Tan 2025-06-16 07:14:18 -04:00 committed by GitHub
parent 3e7506975c
commit 4d5424029b

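The change below makes the routed-expert block (`block_sparse_moe`) optional in both Granite MoE Hybrid decoder layer types, so a checkpoint that defines only shared experts can still be constructed. As a minimal sketch (not vLLM API) of the kind of configuration this enables, using only the attribute names the constructors actually read:

# Illustrative config stub: the class and values are hypothetical, but the
# attribute names match those read in the constructors below.
from types import SimpleNamespace

shared_experts_only = SimpleNamespace(
    num_local_experts=0,            # no routed experts -> block_sparse_moe stays None
    shared_intermediate_size=8192,  # > 0 -> shared_mlp is built
)

With such a config, each decoder layer builds only `shared_mlp`, and the forward pass reduces to the shared-MLP-only branch added below.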

@@ -67,13 +67,15 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
             activation=config.hidden_act,
             quant_config=quant_config)
-        self.block_sparse_moe = GraniteMoeMoE(
-            num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            quant_config=quant_config,
-            prefix=f"{prefix}.block_sparse_moe")
+        self.block_sparse_moe = None
+        if getattr(config, "num_local_experts", 0) > 0:
+            self.block_sparse_moe = GraniteMoeMoE(
+                num_experts=config.num_local_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.block_sparse_moe")
         self.shared_mlp = None if \
             getattr(config, 'shared_intermediate_size', 0) == 0 \
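Together with the existing `shared_mlp` guard, the constructor now supports four layer layouts. A hypothetical helper (not part of vLLM; the name and return strings are illustrative) that mirrors the two guards makes them explicit:

# Hypothetical helper mirroring the construction guards above.
def moe_layout(config) -> str:
    has_routed = getattr(config, "num_local_experts", 0) > 0
    has_shared = getattr(config, "shared_intermediate_size", 0) > 0
    if has_routed and has_shared:
        return "routed experts + shared MLP"
    if has_routed:
        return "routed experts only"
    if has_shared:
        return "shared MLP only (the case this commit enables)"
    return "neither (hidden states pass through the MLP block unchanged)"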
@@ -105,13 +107,19 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         if self.shared_mlp is None:
-            hidden_states = self.block_sparse_moe(hidden_states)
+            if self.block_sparse_moe is not None:
+                hidden_states = self.block_sparse_moe(hidden_states)
+            # else: skip
         else:
             # create a copy since block_sparse_moe modifies in-place
-            moe_hidden_states = hidden_states.clone()
-            moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
-            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
-            del moe_hidden_states
+            if self.block_sparse_moe is not None:
+                moe_hidden_states = hidden_states.clone()
+                moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
+                hidden_states = moe_hidden_states + self.shared_mlp(
+                    hidden_states)
+                del moe_hidden_states
+            else:
+                hidden_states = self.shared_mlp(hidden_states)
         hidden_states = residual + hidden_states * self.residual_multiplier
         return hidden_states, residual
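The forward pass now handles all four combinations of optional submodules. A self-contained sketch of the equivalent control flow, with plain callables standing in for the layer's submodules (this restates the diff's logic; it is not the vLLM module itself):

import torch

def mlp_block(hidden_states: torch.Tensor,
              block_sparse_moe=None,
              shared_mlp=None) -> torch.Tensor:
    # block_sparse_moe / shared_mlp are stand-ins for the optional submodules.
    if shared_mlp is None:
        if block_sparse_moe is not None:
            hidden_states = block_sparse_moe(hidden_states)
        # else: no MoE and no shared MLP -> pass through unchanged
    else:
        if block_sparse_moe is not None:
            # clone first because block_sparse_moe modifies its input in-place
            moe_hidden_states = block_sparse_moe(hidden_states.clone())
            hidden_states = moe_hidden_states + shared_mlp(hidden_states)
        else:
            # shared-experts-only path enabled by this commit
            hidden_states = shared_mlp(hidden_states)
    return hidden_states

The same constructor and forward changes are applied to the attention decoder layer in the hunks that follow.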
@@ -137,13 +145,15 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn")
-        self.block_sparse_moe = GraniteMoeMoE(
-            num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            quant_config=quant_config,
-            prefix=f"{prefix}.block_sparse_moe")
+        self.block_sparse_moe = None
+        if getattr(config, "num_local_experts", 0) > 0:
+            self.block_sparse_moe = GraniteMoeMoE(
+                num_experts=config.num_local_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.block_sparse_moe")
         self.shared_mlp = None if \
             getattr(config, 'shared_intermediate_size', 0) == 0 \
@@ -178,13 +188,19 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         if self.shared_mlp is None:
-            hidden_states = self.block_sparse_moe(hidden_states)
+            if self.block_sparse_moe is not None:
+                hidden_states = self.block_sparse_moe(hidden_states)
+            # else: skip
         else:
             # create a copy since block_sparse_moe modifies in-place
-            moe_hidden_states = hidden_states.clone()
-            moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
-            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
-            del moe_hidden_states
+            if self.block_sparse_moe is not None:
+                moe_hidden_states = hidden_states.clone()
+                moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
+                hidden_states = moe_hidden_states + self.shared_mlp(
+                    hidden_states)
+                del moe_hidden_states
+            else:
+                hidden_states = self.shared_mlp(hidden_states)
         hidden_states = residual + hidden_states * self.residual_multiplier
         return hidden_states, residual