mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:05:02 +08:00
[Model]: Fused MoE for nomic-embed-text-v2-moe (#18321)
Signed-off-by: isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
a2480251ec
commit
a4528f0cac
@ -7,6 +7,7 @@ import os
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
@ -1001,6 +1002,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@ -1018,7 +1020,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None) -> None:
|
||||
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
|
||||
activation, apply_router_weight_on_input, use_fp8_w8a8,
|
||||
activation, is_act_and_mul,
|
||||
apply_router_weight_on_input, use_fp8_w8a8,
|
||||
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
|
||||
use_mxfp4_w4a4, per_channel_quant, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
|
||||
@ -1032,6 +1035,7 @@ def inplace_fused_experts_fake(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@ -1167,6 +1171,7 @@ def outplace_fused_experts(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@ -1183,13 +1188,12 @@ def outplace_fused_experts(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None) -> torch.Tensor:
|
||||
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
False, activation, apply_router_weight_on_input,
|
||||
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
|
||||
use_int4_w4a16, use_mxfp4_w4a4,
|
||||
per_channel_quant, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
|
||||
a1_scale, a2_scale, block_shape)
|
||||
return fused_experts_impl(
|
||||
hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
|
||||
is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
|
||||
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
|
||||
per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
|
||||
w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
@ -1199,6 +1203,7 @@ def outplace_fused_experts_fake(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@ -1253,6 +1258,7 @@ def fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@ -1283,6 +1289,8 @@ def fused_experts(
|
||||
or is_blackwell_deep_gemm_used())
|
||||
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
|
||||
assert apply_router_weight_on_input is False
|
||||
assert is_act_and_mul, (
|
||||
"DeepGemm only supports is_act_and_mul=True for now.")
|
||||
return deep_gemm_moe_fp8(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
@ -1319,6 +1327,7 @@ def fused_experts(
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
is_act_and_mul=is_act_and_mul,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
@ -1345,6 +1354,7 @@ def fused_experts_impl(
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@ -1503,14 +1513,21 @@ def fused_experts_impl(
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
if activation == "silu":
|
||||
# Activation function with multiplication
|
||||
if activation == "silu" and is_act_and_mul:
|
||||
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
elif activation == "gelu":
|
||||
elif activation == "gelu" and is_act_and_mul:
|
||||
torch.ops._C.gelu_and_mul(intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
# Activation function without multiplication
|
||||
elif activation == "silu":
|
||||
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
|
||||
elif activation == "gelu":
|
||||
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
|
||||
f"with is_act_and_mul={is_act_and_mul}.")
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
A=intermediate_cache2,
|
||||
@ -1555,6 +1572,7 @@ def fused_moe(
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
@ -1591,6 +1609,9 @@ def fused_moe(
|
||||
Defaults to False.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
- is_act_and_mul (bool): If True, use activation-and-mul function for
|
||||
activation (self-gated activation), otherwise use activation function
|
||||
for activation (ungated activation).
|
||||
- num_expert_group: Optional[int]: additional parameter for grouped_topk
|
||||
- topk_group: Optional[int]: additional parameter for grouped_topk
|
||||
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
|
||||
@ -1627,6 +1648,9 @@ def fused_moe(
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
if not is_act_and_mul:
|
||||
assert inplace is False, (
|
||||
"is_act_and_mul=False is not supported with inplace=True")
|
||||
|
||||
if use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
@ -1647,6 +1671,7 @@ def fused_moe(
|
||||
topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
is_act_and_mul=is_act_and_mul,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
|
||||
@ -10,9 +10,12 @@ from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
|
||||
get_act_fn)
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@ -26,6 +29,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import SupportsV0Only
|
||||
from vllm.model_executor.models.interfaces import SupportsQuant
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
@ -201,114 +206,101 @@ class BertWithRopeMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NomicRouter(nn.Module):
|
||||
class NomicMoE(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int):
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.moe_top_k = moe_top_k
|
||||
self.layer = ReplicatedLinear(hidden_size, moe_num_experts, bias=False)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
|
||||
weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax(
|
||||
dim=-1, dtype=torch.float32)
|
||||
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
|
||||
weights = weights.to(x.dtype)
|
||||
top_weights = top_weights.to(x.dtype)
|
||||
return weights, top_weights, top_experts # type: ignore
|
||||
|
||||
|
||||
class NomicExpertMLP(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, ffn_hidden_size: int,
|
||||
moe_num_experts: int, ffn_act_fn: str):
|
||||
super().__init__()
|
||||
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
|
||||
self.num_total_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.hidden_size = hidden_size
|
||||
self.ffn_hidden_size = ffn_hidden_size
|
||||
self.moe_num_experts = moe_num_experts
|
||||
self.total_intermediate_size = intermediate_size
|
||||
self.intermediate_size = divide(intermediate_size, self.tp_size)
|
||||
self.hidden_act = hidden_act
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
self.router = ReplicatedLinear(self.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False)
|
||||
self.w1 = nn.Parameter(
|
||||
torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
|
||||
torch.empty(self.num_total_experts,
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
device=current_platform.device_type,
|
||||
dtype=self.params_dtype))
|
||||
self.w2 = nn.Parameter(
|
||||
torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
|
||||
self.activation_fn = get_act_fn(ffn_act_fn)
|
||||
torch.empty(self.num_total_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
device=current_platform.device_type,
|
||||
dtype=self.params_dtype))
|
||||
self.bias = nn.Parameter(torch.zeros(self.hidden_size))
|
||||
set_weight_attrs(self.w1, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
|
||||
def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
|
||||
expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
|
||||
self.hidden_size)[expert_idx]
|
||||
expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
|
||||
self.hidden_size)[expert_idx]
|
||||
def weight_loader(
|
||||
self,
|
||||
param: nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str,
|
||||
):
|
||||
# NOTE: Nomic-MoE has fused experts weights with shape
|
||||
# (num_experts * intermediate_size, hidden_size)
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
param_data = param.data
|
||||
shard_size = self.intermediate_size
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
if weight_name.endswith("w1"):
|
||||
loaded_weight = loaded_weight.reshape(
|
||||
self.num_total_experts,
|
||||
self.total_intermediate_size,
|
||||
self.hidden_size,
|
||||
)[:, shard]
|
||||
if weight_name.endswith("w2"):
|
||||
loaded_weight = loaded_weight.reshape(
|
||||
self.num_total_experts,
|
||||
self.total_intermediate_size,
|
||||
self.hidden_size,
|
||||
)[:, shard].transpose(1, 2)
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
x1 = x.matmul(expert_w1.t())
|
||||
act_out = self.activation_fn(x1)
|
||||
x2 = act_out.matmul(expert_w2)
|
||||
return x2
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.router(hidden_states)
|
||||
final_hidden_states = fused_moe(hidden_states,
|
||||
self.w1,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=False,
|
||||
inplace=False,
|
||||
activation=self.hidden_act,
|
||||
is_act_and_mul=False)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
class NomicExperts(nn.Module):
|
||||
|
||||
def __init__(self, config, hidden_size: int, ffn_hidden_size: int,
|
||||
moe_num_experts: int):
|
||||
super().__init__()
|
||||
self.moe_num_experts = moe_num_experts
|
||||
|
||||
self.mlp = NomicExpertMLP(hidden_size=config.n_embd,
|
||||
ffn_hidden_size=config.n_inner,
|
||||
moe_num_experts=moe_num_experts,
|
||||
ffn_act_fn=config.hidden_act)
|
||||
self.bias = nn.Parameter(torch.zeros(config.n_embd))
|
||||
|
||||
def forward(self, x: torch.Tensor, weights: torch.Tensor,
|
||||
top_weights: torch.Tensor,
|
||||
top_experts: torch.LongTensor) -> torch.Tensor:
|
||||
q_len, hidden_size = x.shape
|
||||
x = x.view(-1, hidden_size)
|
||||
out = torch.zeros_like(x)
|
||||
|
||||
expert_mask = nn.functional.one_hot(
|
||||
top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
|
||||
for expert_idx in range(0, self.moe_num_experts):
|
||||
topk_idx, token_idx = torch.where(expert_mask[expert_idx])
|
||||
if token_idx.shape[0] == 0:
|
||||
continue
|
||||
|
||||
token_list = token_idx.tolist()
|
||||
topk_list = topk_idx.tolist()
|
||||
|
||||
expert_tokens = x[None, token_list].reshape(-1, hidden_size)
|
||||
expert_out = self.mlp(
|
||||
expert_tokens, expert_idx) * top_weights[token_list, topk_list,
|
||||
None]
|
||||
|
||||
out.index_add_(0, token_idx, expert_out)
|
||||
|
||||
out = out.reshape(q_len, hidden_size)
|
||||
return out + self.bias
|
||||
|
||||
|
||||
class NomicMoELayer(nn.Module):
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
|
||||
self.router = NomicRouter(
|
||||
config.n_embd,
|
||||
moe_num_experts=config.num_experts,
|
||||
moe_top_k=config.moe_top_k,
|
||||
)
|
||||
|
||||
self.experts = NomicExperts(
|
||||
config,
|
||||
hidden_size=config.n_embd,
|
||||
ffn_hidden_size=config.n_inner,
|
||||
moe_num_experts=config.num_experts,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
weights, top_weights, top_experts = self.router(x)
|
||||
out = self.experts(x, weights, top_weights, top_experts)
|
||||
return out
|
||||
return final_hidden_states.view(num_tokens, hidden_size) + self.bias
|
||||
|
||||
|
||||
class BertWithRopeBlock(nn.Module):
|
||||
@ -332,7 +324,11 @@ class BertWithRopeBlock(nn.Module):
|
||||
prefix=f"{prefix}.attention")
|
||||
|
||||
if moe:
|
||||
self.mlp = NomicMoELayer(config=config, )
|
||||
self.mlp = NomicMoE(num_experts=config.num_experts,
|
||||
top_k=config.moe_top_k,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act)
|
||||
else:
|
||||
if config.hidden_act in ["silu", "geglu"]:
|
||||
self.mlp = BertWithRopeGatedMLP(
|
||||
@ -463,7 +459,11 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
if name.endswith((".w1", ".w2")):
|
||||
# Nomic-MoE has fused experts weights
|
||||
weight_loader(param, loaded_weight, name)
|
||||
else:
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -481,6 +481,10 @@ class NomicBertModel(BertWithRope):
|
||||
"mlp.fc12": "mlp.gate_proj",
|
||||
"mlp.fc2": "mlp.down_proj",
|
||||
"norm2": "mlp_ln",
|
||||
# MoE mapping
|
||||
"experts.mlp.": "",
|
||||
"experts.": "",
|
||||
"router.layer": "router",
|
||||
})
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user