[Model]: Fused MoE for nomic-embed-text-v2-moe (#18321)

Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-07-29 18:13:27 +08:00 committed by GitHub
parent a2480251ec
commit a4528f0cac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 140 additions and 111 deletions

View File

@ -7,6 +7,7 @@ import os
from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@ -1001,6 +1002,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@ -1018,7 +1020,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
activation, is_act_and_mul,
apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
use_mxfp4_w4a4, per_channel_quant, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
@ -1032,6 +1035,7 @@ def inplace_fused_experts_fake(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@ -1167,6 +1171,7 @@ def outplace_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@ -1183,13 +1188,12 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, use_mxfp4_w4a4,
per_channel_quant, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape)
return fused_experts_impl(
hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
def outplace_fused_experts_fake(
@ -1199,6 +1203,7 @@ def outplace_fused_experts_fake(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
@ -1253,6 +1258,7 @@ def fused_experts(
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@ -1283,6 +1289,8 @@ def fused_experts(
or is_blackwell_deep_gemm_used())
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
assert apply_router_weight_on_input is False
assert is_act_and_mul, (
"DeepGemm only supports is_act_and_mul=True for now.")
return deep_gemm_moe_fp8(
hidden_states=hidden_states,
w1=w1,
@ -1319,6 +1327,7 @@ def fused_experts(
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
is_act_and_mul=is_act_and_mul,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
@ -1345,6 +1354,7 @@ def fused_experts_impl(
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@ -1503,14 +1513,21 @@ def fused_experts_impl(
per_channel_quant=per_channel_quant,
block_shape=block_shape)
if activation == "silu":
# Activation function with multiplication
if activation == "silu" and is_act_and_mul:
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "gelu":
elif activation == "gelu" and is_act_and_mul:
torch.ops._C.gelu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
# Activation function without multiplication
elif activation == "silu":
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == "gelu":
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
f"with is_act_and_mul={is_act_and_mul}.")
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
@ -1555,6 +1572,7 @@ def fused_moe(
renormalize: bool,
inplace: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
@ -1591,6 +1609,9 @@ def fused_moe(
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- is_act_and_mul (bool): If True, use activation-and-mul function for
activation (self-gated activation), otherwise use activation function
for activation (ungated activation).
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
@ -1627,6 +1648,9 @@ def fused_moe(
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
if not is_act_and_mul:
assert inplace is False, (
"is_act_and_mul=False is not supported with inplace=True")
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
@ -1647,6 +1671,7 @@ def fused_moe(
topk_ids,
inplace=inplace,
activation=activation,
is_act_and_mul=is_act_and_mul,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,

View File

@ -10,9 +10,12 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
@ -26,6 +29,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsV0Only
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@ -201,114 +206,101 @@ class BertWithRopeMLP(nn.Module):
return hidden_states
class NomicRouter(nn.Module):
class NomicMoE(nn.Module):
def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int):
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
):
super().__init__()
self.moe_top_k = moe_top_k
self.layer = ReplicatedLinear(hidden_size, moe_num_experts, bias=False)
def forward(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax(
dim=-1, dtype=torch.float32)
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
weights = weights.to(x.dtype)
top_weights = top_weights.to(x.dtype)
return weights, top_weights, top_experts # type: ignore
class NomicExpertMLP(nn.Module):
def __init__(self, hidden_size: int, ffn_hidden_size: int,
moe_num_experts: int, ffn_act_fn: str):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.moe_num_experts = moe_num_experts
self.total_intermediate_size = intermediate_size
self.intermediate_size = divide(intermediate_size, self.tp_size)
self.hidden_act = hidden_act
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.router = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False)
self.w1 = nn.Parameter(
torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
torch.empty(self.num_total_experts,
self.intermediate_size,
self.hidden_size,
device=current_platform.device_type,
dtype=self.params_dtype))
self.w2 = nn.Parameter(
torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
self.activation_fn = get_act_fn(ffn_act_fn)
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
device=current_platform.device_type,
dtype=self.params_dtype))
self.bias = nn.Parameter(torch.zeros(self.hidden_size))
set_weight_attrs(self.w1, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2, {
"weight_loader": self.weight_loader,
})
def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
self.hidden_size)[expert_idx]
expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
self.hidden_size)[expert_idx]
def weight_loader(
self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
):
# NOTE: Nomic-MoE has fused experts weights with shape
# (num_experts * intermediate_size, hidden_size)
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("w1"):
loaded_weight = loaded_weight.reshape(
self.num_total_experts,
self.total_intermediate_size,
self.hidden_size,
)[:, shard]
if weight_name.endswith("w2"):
loaded_weight = loaded_weight.reshape(
self.num_total_experts,
self.total_intermediate_size,
self.hidden_size,
)[:, shard].transpose(1, 2)
param_data.copy_(loaded_weight)
x1 = x.matmul(expert_w1.t())
act_out = self.activation_fn(x1)
x2 = act_out.matmul(expert_w2)
return x2
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=False,
inplace=False,
activation=self.hidden_act,
is_act_and_mul=False)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
class NomicExperts(nn.Module):
def __init__(self, config, hidden_size: int, ffn_hidden_size: int,
moe_num_experts: int):
super().__init__()
self.moe_num_experts = moe_num_experts
self.mlp = NomicExpertMLP(hidden_size=config.n_embd,
ffn_hidden_size=config.n_inner,
moe_num_experts=moe_num_experts,
ffn_act_fn=config.hidden_act)
self.bias = nn.Parameter(torch.zeros(config.n_embd))
def forward(self, x: torch.Tensor, weights: torch.Tensor,
top_weights: torch.Tensor,
top_experts: torch.LongTensor) -> torch.Tensor:
q_len, hidden_size = x.shape
x = x.view(-1, hidden_size)
out = torch.zeros_like(x)
expert_mask = nn.functional.one_hot(
top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
for expert_idx in range(0, self.moe_num_experts):
topk_idx, token_idx = torch.where(expert_mask[expert_idx])
if token_idx.shape[0] == 0:
continue
token_list = token_idx.tolist()
topk_list = topk_idx.tolist()
expert_tokens = x[None, token_list].reshape(-1, hidden_size)
expert_out = self.mlp(
expert_tokens, expert_idx) * top_weights[token_list, topk_list,
None]
out.index_add_(0, token_idx, expert_out)
out = out.reshape(q_len, hidden_size)
return out + self.bias
class NomicMoELayer(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.router = NomicRouter(
config.n_embd,
moe_num_experts=config.num_experts,
moe_top_k=config.moe_top_k,
)
self.experts = NomicExperts(
config,
hidden_size=config.n_embd,
ffn_hidden_size=config.n_inner,
moe_num_experts=config.num_experts,
)
def forward(self, x: torch.Tensor):
weights, top_weights, top_experts = self.router(x)
out = self.experts(x, weights, top_weights, top_experts)
return out
return final_hidden_states.view(num_tokens, hidden_size) + self.bias
class BertWithRopeBlock(nn.Module):
@ -332,7 +324,11 @@ class BertWithRopeBlock(nn.Module):
prefix=f"{prefix}.attention")
if moe:
self.mlp = NomicMoELayer(config=config, )
self.mlp = NomicMoE(num_experts=config.num_experts,
top_k=config.moe_top_k,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act)
else:
if config.hidden_act in ["silu", "geglu"]:
self.mlp = BertWithRopeGatedMLP(
@ -463,7 +459,11 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if name.endswith((".w1", ".w2")):
# Nomic-MoE has fused experts weights
weight_loader(param, loaded_weight, name)
else:
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@ -481,6 +481,10 @@ class NomicBertModel(BertWithRope):
"mlp.fc12": "mlp.gate_proj",
"mlp.fc2": "mlp.down_proj",
"norm2": "mlp_ln",
# MoE mapping
"experts.mlp.": "",
"experts.": "",
"router.layer": "router",
})