Add latent MoE support (#30203)

Signed-off-by: Shahar Mor <smor@nvidia.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>

commit fcd5306f65 (parent 398a596ed2)
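The diff below adds a latent MoE path to NemotronH: when config.moe_latent_size is set, NemotronHMoE projects hidden states down to that latent width (fc1_latent_proj) before the routed experts, runs SharedFusedMoE at the latent width, and projects the result back up (fc2_latent_proj), while the shared experts still see the full hidden size. As a minimal sketch of that data flow only, assuming a single device and a plain top-k softmax router (the class name LatentMoESketch, the toy two-layer experts, and the routing below are illustrative stand-ins, not vLLM's SharedFusedMoE, quantization, or tensor/expert-parallel machinery):

import torch
import torch.nn as nn
import torch.nn.functional as F


class LatentMoESketch(nn.Module):
    """Illustrative latent-MoE block: route on the full width, run routed
    experts at a smaller latent width, project back up, add the shared expert."""

    def __init__(self, hidden_size: int, latent_size: int, n_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        # Router and shared expert operate on the full hidden size, as in the diff.
        self.gate = nn.Linear(hidden_size, n_experts, bias=False)
        self.shared_expert = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.ReLU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        # Down/up projections around the routed experts (fc1/fc2_latent_proj in the diff).
        self.fc1_latent_proj = nn.Linear(hidden_size, latent_size)
        self.fc2_latent_proj = nn.Linear(latent_size, hidden_size)
        # Routed experts work entirely in the latent dimension.
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(latent_size, 4 * latent_size),
                nn.ReLU(),
                nn.Linear(4 * latent_size, latent_size),
            )
            for _ in range(n_experts)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # (num_tokens, hidden_size) in, (num_tokens, hidden_size) out.
        router_logits = self.gate(hidden_states.float())
        weights, ids = torch.topk(F.softmax(router_logits, dim=-1), self.top_k, dim=-1)
        shared_output = self.shared_expert(hidden_states)

        # Routed experts run at the latent width.
        latent = self.fc1_latent_proj(hidden_states)
        routed = torch.zeros_like(latent)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = ids[:, k] == e
                if mask.any():
                    routed[mask] += weights[mask, k, None].to(latent.dtype) * expert(latent[mask])

        # Project back to the full hidden size and add the shared-expert branch.
        return self.fc2_latent_proj(routed) + shared_output

A quick shape check of the sketch: LatentMoESketch(hidden_size=512, latent_size=128, n_experts=8, top_k=2)(torch.randn(16, 512)) returns a (16, 512) tensor.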
@@ -83,6 +83,7 @@ class NemotronHMLP(nn.Module):
     def __init__(
         self,
         config: NemotronHConfig,
+        hidden_size: int,
         intermediate_size: int,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
@@ -93,7 +94,7 @@ class NemotronHMLP(nn.Module):
         super().__init__()

         self.up_proj = ColumnParallelLinear(
-            input_size=config.hidden_size,
+            input_size=hidden_size,
             output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
@@ -102,7 +103,7 @@ class NemotronHMLP(nn.Module):
         )
         self.down_proj = RowParallelLinear(
             input_size=intermediate_size,
-            output_size=config.hidden_size,
+            output_size=hidden_size,
             bias=bias,
             quant_config=quant_config,
             reduce_results=reduce_results,
@@ -135,6 +136,10 @@ class NemotronHMoE(nn.Module):
         self.ep_size = self.ep_group.size()
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts
+        self.use_latent_moe: bool = getattr(config, "moe_latent_size", None) is not None
+        self.moe_hidden_size: int = (
+            config.moe_latent_size if self.use_latent_moe else config.hidden_size
+        )

         self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

@@ -172,6 +177,7 @@ class NemotronHMoE(nn.Module):

         self.shared_experts = NemotronHMLP(
             config=config,
+            hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             quant_config=quant_config,
             reduce_results=False,
@@ -180,10 +186,12 @@ class NemotronHMoE(nn.Module):
         )

         self.experts = SharedFusedMoE(
-            shared_experts=self.shared_experts,
+            # TODO: make it possible for shared experts to have
+            # different input in SharedFusedMoE
+            shared_experts=self.shared_experts if not self.use_latent_moe else None,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
+            hidden_size=self.moe_hidden_size,
             intermediate_size=config.moe_intermediate_size,
             reduce_results=False,
             renormalize=config.norm_topk_prob,
@@ -201,6 +209,32 @@ class NemotronHMoE(nn.Module):
             is_sequence_parallel=self.is_sequence_parallel,
         )

+        if self.use_latent_moe:
+            # TODO: check if using ReplicatedLinear is better than
+            # ColumnParallelLinear + all_gather
+            self.fc1_latent_proj = ColumnParallelLinear(
+                input_size=config.hidden_size,
+                output_size=self.moe_hidden_size,
+                bias=config.mlp_bias,
+                quant_config=quant_config,
+                disable_tp=self.is_sequence_parallel,
+                # We need to gather the output to prepare input for moe
+                gather_output=True,
+                prefix=f"{prefix}.fc1_latent_proj",
+            )
+            self.fc2_latent_proj = ReplicatedLinear(
+                input_size=self.moe_hidden_size,
+                output_size=config.hidden_size,
+                bias=config.mlp_bias,
+                quant_config=quant_config,
+                disable_tp=self.is_sequence_parallel,
+                prefix=f"{prefix}.fc2_latent_proj",
+            )
+
+        else:
+            self.fc1_latent_proj = None
+            self.fc2_latent_proj = None
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -210,12 +244,20 @@ class NemotronHMoE(nn.Module):

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
+        shared_output = None
+        if self.use_latent_moe:
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)
+            hidden_states, _ = self.fc1_latent_proj(hidden_states)

         fused_moe_out = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )

-        shared_output, final_hidden_states = fused_moe_out
+        if self.use_latent_moe:
+            _, final_hidden_states = fused_moe_out
+        else:
+            shared_output, final_hidden_states = fused_moe_out

         # Fix FP16 overflow
         # See DeepseekV2DecoderLayer for more details.
@@ -225,6 +267,13 @@ class NemotronHMoE(nn.Module):
             assert shared_output is not None
             shared_output *= 1.0 / self.routed_scaling_factor

+        # TODO: currently latent up_proj is done before all-reduce for simplicity.
+        # if and when shared experts will be part of SharedFusedMoE,
+        # we should do the up_proj after all-reduce,
+        # to have the all-reduce in the smaller latent dimension.
+        if self.use_latent_moe:
+            final_hidden_states, _ = self.fc2_latent_proj(final_hidden_states)
+
         if self.shared_experts is not None:
             assert shared_output is not None
             final_hidden_states += shared_output
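A rough way to quantify the TODO in the hunk above (numbers hypothetical): the all-reduce it refers to moves one value per token per channel, so performing it on the latent tensor instead of the full-width one would shrink that traffic by about hidden_size / moe_latent_size, e.g. 4096 / 1024 = 4x; that is why the comment proposes moving the up-projection after the all-reduce once shared experts can be given a different input inside SharedFusedMoE.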
@@ -268,6 +317,7 @@ class NemotronHMLPDecoderLayer(nn.Module):

         self.mixer = NemotronHMLP(
             config,
+            hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             quant_config=quant_config,
             bias=config.mlp_bias,
@@ -846,5 +896,5 @@ class NemotronHForCausalLM(
         return logits

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        loader = AutoWeightsLoader(self, skip_prefixes=["mtp"])
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@@ -189,6 +189,7 @@ class NemotronHConfig(PretrainedConfig):
         n_shared_experts=1,
         moe_intermediate_size=7688,
         moe_shared_expert_intermediate_size=7688,
+        moe_latent_size=None,
         num_experts_per_tok=2,
         routed_scaling_factor=1.0,
         n_group=1,
@@ -254,6 +255,7 @@ class NemotronHConfig(PretrainedConfig):
         self.n_shared_experts = n_shared_experts
         self.moe_intermediate_size = moe_intermediate_size
         self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size  # noqa: E501
+        self.moe_latent_size = moe_latent_size
         self.num_experts_per_tok = num_experts_per_tok
         self.routed_scaling_factor = routed_scaling_factor
         self.n_group = n_group
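As a usage sketch (the latent width below is hypothetical, not taken from any released checkpoint), the feature is enabled purely through the config; leaving moe_latent_size at its default None keeps the original non-latent behaviour:

config = NemotronHConfig(
    moe_intermediate_size=7688,
    moe_shared_expert_intermediate_size=7688,
    moe_latent_size=2048,  # hypothetical latent width; None (the default) disables the latent path
    num_experts_per_tok=2,
)

With moe_latent_size set, NemotronHMoE.use_latent_moe evaluates to True and the fc1_latent_proj / fc2_latent_proj layers from the diff above are instantiated.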