[Model] Adds support for SlimMoE models Phi-tiny-MoE-instruct (#20286)

Signed-off-by: Zichong Li <t-lizichong@microsoft.com@Reasoning-H100-VM3.drbuo4tcjzruhloch3eo0b25ef.cx.internal.cloudapp.net>
Co-authored-by: Zichong Li <t-lizichong@microsoft.com@Reasoning-H100-VM3.drbuo4tcjzruhloch3eo0b25ef.cx.internal.cloudapp.net>
Co-authored-by: Isotr0py <2037008807@qq.com>
Author: zichongli5
Date: 2025-07-02 05:54:12 -07:00
Committed by: GitHub
parent ccbfb1d1c9
commit 706ff13224

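As a quick sanity check of the new support, a model from this family can be loaded through vLLM's offline API. This is a minimal sketch only; the checkpoint id (assumed here to be microsoft/Phi-tiny-MoE-instruct) and whether trust_remote_code is required depend on how the model is published.

    from vllm import LLM, SamplingParams

    # Assumed checkpoint id; substitute the SlimMoE model you actually want to load.
    llm = LLM(model="microsoft/Phi-tiny-MoE-instruct", trust_remote_code=True)
    params = SamplingParams(temperature=0.0, max_tokens=64)
    outputs = llm.generate(["Explain mixture-of-experts in one sentence."], params)
    print(outputs[0].outputs[0].text)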
@@ -68,6 +68,7 @@ class PhiMoEConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
+        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -101,8 +102,11 @@ class PhiMoEConfig(PretrainedConfig):
         # for backward compatibility
         if num_key_value_heads is None:
             num_key_value_heads = num_attention_heads
+        if head_dim is None:
+            head_dim = hidden_size // num_attention_heads
 
         self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
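The config change above only adds a backward-compatible head_dim field: when a checkpoint does not specify it, the old hidden_size // num_attention_heads value is derived. A minimal standalone sketch of that fallback (not the vLLM class itself; the explicit value 96 is purely illustrative):

    def resolve_head_dim(hidden_size: int, num_attention_heads: int, head_dim=None) -> int:
        # Explicit head_dim wins; otherwise fall back to the pre-change derivation.
        return head_dim if head_dim is not None else hidden_size // num_attention_heads

    assert resolve_head_dim(4096, 32) == 128              # legacy PhiMoE behaviour
    assert resolve_head_dim(4096, 32, head_dim=96) == 96  # explicit head_dim wins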
@@ -294,6 +298,7 @@ class PhiMoEAttention(nn.Module):
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        head_dim: Optional[int] = None,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         cache_config: Optional[CacheConfig] = None,
@@ -317,7 +322,9 @@ class PhiMoEAttention(nn.Module):
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        if head_dim is None:
+            head_dim = hidden_size // num_heads
+        self.head_dim = head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
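With head_dim now decoupled from hidden_size // num_heads, the per-rank projection sizes follow from head_dim directly. A rough sketch of the arithmetic with illustrative numbers (not taken from the actual Phi-tiny-MoE config):

    hidden_size, total_num_heads, total_num_kv_heads, tp_size = 4096, 32, 8, 4
    head_dim = 96  # may now differ from hidden_size // total_num_heads (= 128)

    num_heads = total_num_heads // tp_size                 # 8 query heads per rank
    num_kv_heads = max(1, total_num_kv_heads // tp_size)   # 2 KV heads per rank
    q_size = num_heads * head_dim                          # 768
    kv_size = num_kv_heads * head_dim                      # 192
    scaling = head_dim ** -0.5                             # attention scale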
@@ -387,6 +394,8 @@ class PhiMoEDecoderLayer(nn.Module):
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
+            head_dim=getattr(config, "head_dim",
+                             self.hidden_size // config.num_attention_heads),
             rope_theta=rope_theta,
             cache_config=cache_config,
             quant_config=quant_config,
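The getattr fallback keeps older PhiMoE configs, which never defined head_dim, on the previous derivation. A small illustration with a hypothetical legacy config object:

    class _LegacyConfig:
        # Hypothetical stand-in for an old PhiMoE config without a head_dim attribute.
        hidden_size = 4096
        num_attention_heads = 32

    cfg = _LegacyConfig()
    head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
    assert head_dim == 128  # unchanged behaviour for legacy checkpoints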