From a4528f0cac5d2857ccc56d2a2e1a1c43142643ce Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Tue, 29 Jul 2025 18:13:27 +0800
Subject: [PATCH] [Model]: Fused MoE for nomic-embed-text-v2-moe (#18321)

Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py
---
 .../layers/fused_moe/fused_moe.py             |  47 +++-
 vllm/model_executor/models/bert_with_rope.py  | 204 +++++++++---------
 2 files changed, 140 insertions(+), 111 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 1985e8612da35..227aacf25c0b0 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -7,6 +7,7 @@ import os
 from typing import Any, Callable, Optional
 
 import torch
+import torch.nn.functional as F
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -1001,6 +1002,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
                           activation: str = "silu",
+                          is_act_and_mul: bool = True,
                           apply_router_weight_on_input: bool = False,
                           use_fp8_w8a8: bool = False,
                           use_int8_w8a8: bool = False,
@@ -1018,7 +1020,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           a2_scale: Optional[torch.Tensor] = None,
                           block_shape: Optional[list[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
-                       activation, apply_router_weight_on_input, use_fp8_w8a8,
+                       activation, is_act_and_mul,
+                       apply_router_weight_on_input, use_fp8_w8a8,
                        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
                        use_mxfp4_w4a4, per_channel_quant, global_num_experts,
                        expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
@@ -1032,6 +1035,7 @@ def inplace_fused_experts_fake(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1167,6 +1171,7 @@ def outplace_fused_experts(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1183,13 +1188,12 @@ def outplace_fused_experts(
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[list[int]] = None) -> torch.Tensor:
-    return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
-                              False, activation, apply_router_weight_on_input,
-                              use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
-                              use_int4_w4a16, use_mxfp4_w4a4,
-                              per_channel_quant, global_num_experts,
-                              expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
-                              a1_scale, a2_scale, block_shape)
+    return fused_experts_impl(
+        hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
+        is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
+        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
+        per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
+        w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
 
 
 def outplace_fused_experts_fake(
@@ -1199,6 +1203,7 @@ def outplace_fused_experts_fake(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
@@ -1253,6 +1258,7 @@ def fused_experts(
                   topk_ids: torch.Tensor,
                   inplace: bool = False,
                   activation: str = "silu",
+                  is_act_and_mul: bool = True,
                   apply_router_weight_on_input: bool = False,
                   use_fp8_w8a8: bool = False,
                   use_int8_w8a8: bool = False,
@@ -1283,6 +1289,8 @@ def fused_experts(
             or is_blackwell_deep_gemm_used())
     if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
         assert apply_router_weight_on_input is False
+        assert is_act_and_mul, (
+            "DeepGemm only supports is_act_and_mul=True for now.")
         return deep_gemm_moe_fp8(
             hidden_states=hidden_states,
             w1=w1,
@@ -1319,6 +1327,7 @@ def fused_experts(
         topk_weights=topk_weights,
         topk_ids=topk_ids,
         activation=activation,
+        is_act_and_mul=is_act_and_mul,
         apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
@@ -1345,6 +1354,7 @@ def fused_experts_impl(
     topk_ids: torch.Tensor,
     inplace: bool = False,
     activation: str = "silu",
+    is_act_and_mul: bool = True,
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -1503,14 +1513,21 @@ def fused_experts_impl(
             per_channel_quant=per_channel_quant,
             block_shape=block_shape)
 
-        if activation == "silu":
+        # Activation function with multiplication
+        if activation == "silu" and is_act_and_mul:
             torch.ops._C.silu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
-        elif activation == "gelu":
+        elif activation == "gelu" and is_act_and_mul:
             torch.ops._C.gelu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
+        # Activation function without multiplication
+        elif activation == "silu":
+            intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
         else:
-            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
+            raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
+                             f"with is_act_and_mul={is_act_and_mul}.")
 
         qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
             A=intermediate_cache2,
@@ -1555,6 +1572,7 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     activation: str = "silu",
+    is_act_and_mul: bool = True,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -1591,6 +1609,9 @@ def fused_moe(
         Defaults to False.
     - activation (str): The activation function to apply after the first
         MoE layer.
+    - is_act_and_mul (bool): If True, use activation-and-mul function for
+        activation (self-gated activation), otherwise use activation function
+        for activation (ungated activation).
     - num_expert_group: Optional[int]: additional parameter for grouped_topk
     - topk_group: Optional[int]: additional parameter for grouped_topk
     - use_grouped_topk: If True, use grouped_topk instead of fused_topk
@@ -1627,6 +1648,9 @@ def fused_moe(
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
+    if not is_act_and_mul:
+        assert inplace is False, (
+            "is_act_and_mul=False is not supported with inplace=True")
 
     if use_grouped_topk:
         assert num_expert_group is not None and topk_group is not None
@@ -1647,6 +1671,7 @@ def fused_moe(
                          topk_ids,
                          inplace=inplace,
                          activation=activation,
+                         is_act_and_mul=is_act_and_mul,
                          use_fp8_w8a8=use_fp8_w8a8,
                          use_int8_w8a8=use_int8_w8a8,
                          use_int8_w8a16=use_int8_w8a16,
diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index 0b7350f07d3f6..5249acbd84a56 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -10,9 +10,12 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
                                                    get_act_fn)
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
@@ -26,6 +29,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsV0Only
 from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.model_executor.models.utils import WeightsMapper
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
@@ -201,114 +206,101 @@ class BertWithRopeMLP(nn.Module):
         return hidden_states
 
 
-class NomicRouter(nn.Module):
+class NomicMoE(nn.Module):
 
-    def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int):
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
+    ):
         super().__init__()
-        self.moe_top_k = moe_top_k
-        self.layer = ReplicatedLinear(hidden_size, moe_num_experts, bias=False)
-
-    def forward(
-        self, x: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
-        weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax(
-            dim=-1, dtype=torch.float32)
-        top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
-        weights = weights.to(x.dtype)
-        top_weights = top_weights.to(x.dtype)
-        return weights, top_weights, top_experts  # type: ignore
-
-
-class NomicExpertMLP(nn.Module):
-
-    def __init__(self, hidden_size: int, ffn_hidden_size: int,
-                 moe_num_experts: int, ffn_act_fn: str):
-        super().__init__()
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.ffn_hidden_size = ffn_hidden_size
-        self.moe_num_experts = moe_num_experts
+        self.total_intermediate_size = intermediate_size
+        self.intermediate_size = divide(intermediate_size, self.tp_size)
+        self.hidden_act = hidden_act
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.router = ReplicatedLinear(self.hidden_size,
+                                       self.num_total_experts,
+                                       bias=False)
 
         self.w1 = nn.Parameter(
-            torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+            torch.empty(self.num_total_experts,
+                        self.intermediate_size,
+                        self.hidden_size,
+                        device=current_platform.device_type,
+                        dtype=self.params_dtype))
         self.w2 = nn.Parameter(
-            torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
-        self.activation_fn = get_act_fn(ffn_act_fn)
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        device=current_platform.device_type,
+                        dtype=self.params_dtype))
+        self.bias = nn.Parameter(torch.zeros(self.hidden_size))
+        set_weight_attrs(self.w1, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2, {
+            "weight_loader": self.weight_loader,
+        })
 
-    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
-        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
-                                 self.hidden_size)[expert_idx]
-        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
-                                 self.hidden_size)[expert_idx]
+    def weight_loader(
+        self,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+    ):
+        # NOTE: Nomic-MoE has fused experts weights with shape
+        # (num_experts * intermediate_size, hidden_size)
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1"):
+            loaded_weight = loaded_weight.reshape(
+                self.num_total_experts,
+                self.total_intermediate_size,
+                self.hidden_size,
+            )[:, shard]
+        if weight_name.endswith("w2"):
+            loaded_weight = loaded_weight.reshape(
+                self.num_total_experts,
+                self.total_intermediate_size,
+                self.hidden_size,
+            )[:, shard].transpose(1, 2)
+        param_data.copy_(loaded_weight)
 
-        x1 = x.matmul(expert_w1.t())
-        act_out = self.activation_fn(x1)
-        x2 = act_out.matmul(expert_w2)
-        return x2
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.router(hidden_states)
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.w1,
+                                        self.w2,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=False,
+                                        inplace=False,
+                                        activation=self.hidden_act,
+                                        is_act_and_mul=False)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
 
-
-class NomicExperts(nn.Module):
-
-    def __init__(self, config, hidden_size: int, ffn_hidden_size: int,
-                 moe_num_experts: int):
-        super().__init__()
-        self.moe_num_experts = moe_num_experts
-
-        self.mlp = NomicExpertMLP(hidden_size=config.n_embd,
-                                  ffn_hidden_size=config.n_inner,
-                                  moe_num_experts=moe_num_experts,
-                                  ffn_act_fn=config.hidden_act)
-        self.bias = nn.Parameter(torch.zeros(config.n_embd))
-
-    def forward(self, x: torch.Tensor, weights: torch.Tensor,
-                top_weights: torch.Tensor,
-                top_experts: torch.LongTensor) -> torch.Tensor:
-        q_len, hidden_size = x.shape
-        x = x.view(-1, hidden_size)
-        out = torch.zeros_like(x)
-
-        expert_mask = nn.functional.one_hot(
-            top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
-        for expert_idx in range(0, self.moe_num_experts):
-            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
-            if token_idx.shape[0] == 0:
-                continue
-
-            token_list = token_idx.tolist()
-            topk_list = topk_idx.tolist()
-
-            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
-            expert_out = self.mlp(
-                expert_tokens, expert_idx) * top_weights[token_list, topk_list,
-                                                         None]
-
-            out.index_add_(0, token_idx, expert_out)
-
-        out = out.reshape(q_len, hidden_size)
-        return out + self.bias
-
-
-class NomicMoELayer(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-
-        self.router = NomicRouter(
-            config.n_embd,
-            moe_num_experts=config.num_experts,
-            moe_top_k=config.moe_top_k,
-        )
-
-        self.experts = NomicExperts(
-            config,
-            hidden_size=config.n_embd,
-            ffn_hidden_size=config.n_inner,
-            moe_num_experts=config.num_experts,
-        )
-
-    def forward(self, x: torch.Tensor):
-        weights, top_weights, top_experts = self.router(x)
-        out = self.experts(x, weights, top_weights, top_experts)
-        return out
+        return final_hidden_states.view(num_tokens, hidden_size) + self.bias
 
 
 class BertWithRopeBlock(nn.Module):
@@ -332,7 +324,11 @@ class BertWithRopeBlock(nn.Module):
                                           prefix=f"{prefix}.attention")
 
         if moe:
-            self.mlp = NomicMoELayer(config=config, )
+            self.mlp = NomicMoE(num_experts=config.num_experts,
+                                top_k=config.moe_top_k,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.intermediate_size,
+                                hidden_act=config.hidden_act)
         else:
             if config.hidden_act in ["silu", "geglu"]:
                 self.mlp = BertWithRopeGatedMLP(
@@ -463,7 +459,11 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
-                weight_loader(param, loaded_weight)
+                if name.endswith((".w1", ".w2")):
+                    # Nomic-MoE has fused experts weights
+                    weight_loader(param, loaded_weight, name)
+                else:
+                    weight_loader(param, loaded_weight)
                 loaded_params.add(name)
         return loaded_params
@@ -481,6 +481,10 @@ class NomicBertModel(BertWithRope):
             "mlp.fc12": "mlp.gate_proj",
             "mlp.fc2": "mlp.down_proj",
             "norm2": "mlp_ln",
+            # MoE mapping
+            "experts.mlp.": "",
+            "experts.": "",
+            "router.layer": "router",
         })
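
Note on the is_act_and_mul flag introduced in fused_moe.py: with the default True, w1 packs a gate and an up projection, so the kernel applies silu_and_mul / gelu_and_mul (activate the first half of the columns, multiply by the second half); with False, which is what nomic-embed-text-v2-moe needs, the activation is applied elementwise over the full w1 output. A minimal sketch of the numerical difference, with made-up tensor sizes (illustrative only, not vLLM code):

    import torch
    import torch.nn.functional as F

    num_tokens, N = 4, 16                 # N = per-expert w1 output width
    cache1 = torch.randn(num_tokens, N)   # stands in for intermediate_cache1

    # is_act_and_mul=True (default, gated): activate the first half and
    # multiply by the second half, halving the hidden width.
    gated = F.silu(cache1[:, :N // 2]) * cache1[:, N // 2:]   # (4, 8)

    # is_act_and_mul=False (ungated, Nomic MoE): apply the activation
    # elementwise and keep the full width.
    ungated = F.gelu(cache1)                                  # (4, 16)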
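
Note on NomicMoE.weight_loader: the checkpoint stores each expert weight fused as (num_experts * intermediate_size, hidden_size); the loader reshapes it, slices the intermediate dimension for the local tensor-parallel rank, and transposes w2 into the (num_experts, hidden_size, shard) layout that the fused kernel expects. A standalone sketch of that reshape-and-shard step with made-up sizes (illustrative only, not the vLLM loader itself):

    import torch

    num_experts, intermediate_size, hidden_size, tp_size = 8, 32, 16, 2
    shard_size = intermediate_size // tp_size

    # Fused checkpoint layout for both w1 and w2.
    w1_ckpt = torch.randn(num_experts * intermediate_size, hidden_size)
    w2_ckpt = torch.randn(num_experts * intermediate_size, hidden_size)

    for tp_rank in range(tp_size):
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # w1 shard: (num_experts, shard_size, hidden_size)
        w1 = w1_ckpt.reshape(num_experts, intermediate_size,
                             hidden_size)[:, shard]
        # w2 shard: transposed to (num_experts, hidden_size, shard_size)
        w2 = w2_ckpt.reshape(num_experts, intermediate_size,
                             hidden_size)[:, shard].transpose(1, 2)
        assert w1.shape == (num_experts, shard_size, hidden_size)
        assert w2.shape == (num_experts, hidden_size, shard_size)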
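
Note on renormalize=False in NomicMoE.forward: the NomicRouter being replaced softmaxed over all experts and then took the top-k weights without renormalizing them, so the fused path keeps the same weighting by passing renormalize=False. A small sketch of that routing behaviour with made-up shapes (illustrative only, not the fused kernel):

    import torch

    num_tokens, num_experts, top_k = 4, 8, 2
    router_logits = torch.randn(num_tokens, num_experts)

    weights = router_logits.softmax(dim=-1, dtype=torch.float32)
    top_weights, top_experts = torch.topk(weights, top_k, dim=-1)

    # renormalize=False: the selected weights are used as-is and generally
    # sum to less than 1 across the top-k experts.
    assert torch.all(top_weights.sum(dim=-1) <= 1.0 + 1e-6)

    # renormalize=True would instead rescale them to sum to 1:
    # top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)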