[Model]: Fused MoE for nomic-embed-text-v2-moe (#18321)

Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

parent a2480251ec
commit a4528f0cac
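In brief, this change teaches the fused Triton MoE path to handle experts with a plain (ungated) activation via a new `is_act_and_mul` flag, and rewrites the Nomic MoE layer to call `fused_moe` directly instead of looping over experts in Python. Below is a minimal sketch of the two expert-MLP variants the flag selects; the shapes and helper names are illustrative and not part of the diff.

import torch
import torch.nn.functional as F

def expert_gated(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    # is_act_and_mul=True: w1 stacks gate and up projections; the kernel
    # computes act(gate) * up before the down projection (SwiGLU-style).
    gate, up = (x @ w1.t()).chunk(2, dim=-1)
    return (F.silu(gate) * up) @ w2.t()

def expert_ungated(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    # is_act_and_mul=False: w1 is a single projection and the activation is
    # applied elementwise, which is what nomic-embed-text-v2-moe's experts use.
    return F.silu(x @ w1.t()) @ w2.t()

x = torch.randn(4, 8)                                                    # 4 tokens, hidden=8
print(expert_gated(x, torch.randn(32, 8), torch.randn(8, 16)).shape)    # (4, 8)
print(expert_ungated(x, torch.randn(16, 8), torch.randn(8, 16)).shape)  # (4, 8)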
@@ -7,6 +7,7 @@ import os
 from typing import Any, Callable, Optional

 import torch
+import torch.nn.functional as F

 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -1001,6 +1002,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
                           activation: str = "silu",
+                          is_act_and_mul: bool = True,
                           apply_router_weight_on_input: bool = False,
                           use_fp8_w8a8: bool = False,
                           use_int8_w8a8: bool = False,
@@ -1018,7 +1020,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           a2_scale: Optional[torch.Tensor] = None,
                           block_shape: Optional[list[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
-                       activation, apply_router_weight_on_input, use_fp8_w8a8,
+                       activation, is_act_and_mul,
+                       apply_router_weight_on_input, use_fp8_w8a8,
                        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
                        use_mxfp4_w4a4, per_channel_quant, global_num_experts,
                        expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
@@ -1032,6 +1035,7 @@ def inplace_fused_experts_fake(
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str = "silu",
+       is_act_and_mul: bool = True,
        apply_router_weight_on_input: bool = False,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
@@ -1167,6 +1171,7 @@ def outplace_fused_experts(
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str = "silu",
+       is_act_and_mul: bool = True,
        apply_router_weight_on_input: bool = False,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
@@ -1183,13 +1188,12 @@ def outplace_fused_experts(
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[list[int]] = None) -> torch.Tensor:
-    return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
-                              False, activation, apply_router_weight_on_input,
-                              use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
-                              use_int4_w4a16, use_mxfp4_w4a4,
-                              per_channel_quant, global_num_experts,
-                              expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
-                              a1_scale, a2_scale, block_shape)
+    return fused_experts_impl(
+        hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
+        is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
+        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
+        per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
+        w1_zp, w2_zp, a1_scale, a2_scale, block_shape)


 def outplace_fused_experts_fake(
@@ -1199,6 +1203,7 @@ def outplace_fused_experts_fake(
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str = "silu",
+       is_act_and_mul: bool = True,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
@@ -1253,6 +1258,7 @@ def fused_experts(
        topk_ids: torch.Tensor,
        inplace: bool = False,
        activation: str = "silu",
+       is_act_and_mul: bool = True,
        apply_router_weight_on_input: bool = False,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
@@ -1283,6 +1289,8 @@ def fused_experts(
                             or is_blackwell_deep_gemm_used())
     if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
         assert apply_router_weight_on_input is False
+        assert is_act_and_mul, (
+            "DeepGemm only supports is_act_and_mul=True for now.")
         return deep_gemm_moe_fp8(
             hidden_states=hidden_states,
             w1=w1,
@@ -1319,6 +1327,7 @@ def fused_experts(
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
+           is_act_and_mul=is_act_and_mul,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
@@ -1345,6 +1354,7 @@ def fused_experts_impl(
        topk_ids: torch.Tensor,
        inplace: bool = False,
        activation: str = "silu",
+       is_act_and_mul: bool = True,
        apply_router_weight_on_input: bool = False,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
@@ -1503,14 +1513,21 @@ def fused_experts_impl(
                                per_channel_quant=per_channel_quant,
                                block_shape=block_shape)

-        if activation == "silu":
+        # Activation function with multiplication
+        if activation == "silu" and is_act_and_mul:
             torch.ops._C.silu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
-        elif activation == "gelu":
+        elif activation == "gelu" and is_act_and_mul:
             torch.ops._C.gelu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
+        # Activation function without multiplication
+        elif activation == "silu":
+            intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
         else:
-            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
+            raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
+                             f"with is_act_and_mul={is_act_and_mul}.")

         qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
             A=intermediate_cache2,
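For orientation (a sketch with assumed sizes, not part of the diff): the flag also changes the expected layout of the first expert weight. With gating enabled, w1 packs the gate and up projections per expert; without it, w1 holds a single projection, which is how the NomicMoE module later in this diff allocates its parameters.

import torch

num_experts, hidden, intermediate = 8, 768, 3072                # illustrative sizes
w1_gated = torch.empty(num_experts, 2 * intermediate, hidden)   # is_act_and_mul=True
w1_plain = torch.empty(num_experts, intermediate, hidden)       # is_act_and_mul=False
w2 = torch.empty(num_experts, hidden, intermediate)             # same layout in both cases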
@@ -1555,6 +1572,7 @@ def fused_moe(
        renormalize: bool,
        inplace: bool = False,
        activation: str = "silu",
+       is_act_and_mul: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
@@ -1591,6 +1609,9 @@ def fused_moe(
        Defaults to False.
    - activation (str): The activation function to apply after the first
        MoE layer.
+   - is_act_and_mul (bool): If True, use activation-and-mul function for
+       activation (self-gated activation), otherwise use activation function
+       for activation (ungated activation).
    - num_expert_group: Optional[int]: additional parameter for grouped_topk
    - topk_group: Optional[int]: additional parameter for grouped_topk
    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
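A hypothetical call matching the docstring above and the NomicMoE usage later in this diff: ungated activation, no router-weight renormalization, and inplace=False, which the assertion added below requires when is_act_and_mul=False. The tensor names here are placeholders.

out = fused_moe(hidden_states,      # (num_tokens, hidden_size)
                w1,                 # (num_experts, intermediate_size, hidden_size)
                w2,                 # (num_experts, hidden_size, intermediate_size)
                router_logits,      # (num_tokens, num_experts)
                top_k,
                renormalize=False,
                inplace=False,
                activation="gelu",
                is_act_and_mul=False)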
@@ -1627,6 +1648,9 @@ def fused_moe(
    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
+   if not is_act_and_mul:
+       assert inplace is False, (
+           "is_act_and_mul=False is not supported with inplace=True")

    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
@@ -1647,6 +1671,7 @@ def fused_moe(
                         topk_ids,
                         inplace=inplace,
                         activation=activation,
+                        is_act_and_mul=is_act_and_mul,
                         use_fp8_w8a8=use_fp8_w8a8,
                         use_int8_w8a8=use_int8_w8a8,
                         use_int8_w8a16=use_int8_w8a16,
@@ -10,9 +10,12 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
                                                    get_act_fn)
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
@@ -26,6 +29,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsV0Only
 from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.model_executor.models.utils import WeightsMapper
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors


@@ -201,114 +206,101 @@ class BertWithRopeMLP(nn.Module):
         return hidden_states


-class NomicRouter(nn.Module):
-
-    def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int):
-        super().__init__()
-        self.moe_top_k = moe_top_k
-        self.layer = ReplicatedLinear(hidden_size, moe_num_experts, bias=False)
-
-    def forward(
-        self, x: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
-        weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax(
-            dim=-1, dtype=torch.float32)
-        top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
-        weights = weights.to(x.dtype)
-        top_weights = top_weights.to(x.dtype)
-        return weights, top_weights, top_experts  # type: ignore
-
-
-class NomicExpertMLP(nn.Module):
-
-    def __init__(self, hidden_size: int, ffn_hidden_size: int,
-                 moe_num_experts: int, ffn_act_fn: str):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.ffn_hidden_size = ffn_hidden_size
-        self.moe_num_experts = moe_num_experts
-
-        self.w1 = nn.Parameter(
-            torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
-        self.w2 = nn.Parameter(
-            torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
-        self.activation_fn = get_act_fn(ffn_act_fn)
-
-    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
-        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
-                                 self.hidden_size)[expert_idx]
-        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
-                                 self.hidden_size)[expert_idx]
-
-        x1 = x.matmul(expert_w1.t())
-        act_out = self.activation_fn(x1)
-        x2 = act_out.matmul(expert_w2)
-        return x2
-
-
-class NomicExperts(nn.Module):
-
-    def __init__(self, config, hidden_size: int, ffn_hidden_size: int,
-                 moe_num_experts: int):
-        super().__init__()
-        self.moe_num_experts = moe_num_experts
-
-        self.mlp = NomicExpertMLP(hidden_size=config.n_embd,
-                                  ffn_hidden_size=config.n_inner,
-                                  moe_num_experts=moe_num_experts,
-                                  ffn_act_fn=config.hidden_act)
-        self.bias = nn.Parameter(torch.zeros(config.n_embd))
-
-    def forward(self, x: torch.Tensor, weights: torch.Tensor,
-                top_weights: torch.Tensor,
-                top_experts: torch.LongTensor) -> torch.Tensor:
-        q_len, hidden_size = x.shape
-        x = x.view(-1, hidden_size)
-        out = torch.zeros_like(x)
-
-        expert_mask = nn.functional.one_hot(
-            top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
-        for expert_idx in range(0, self.moe_num_experts):
-            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
-            if token_idx.shape[0] == 0:
-                continue
-
-            token_list = token_idx.tolist()
-            topk_list = topk_idx.tolist()
-
-            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
-            expert_out = self.mlp(
-                expert_tokens, expert_idx) * top_weights[token_list, topk_list,
-                                                         None]
-
-            out.index_add_(0, token_idx, expert_out)
-
-        out = out.reshape(q_len, hidden_size)
-        return out + self.bias
-
-
-class NomicMoELayer(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-
-        self.router = NomicRouter(
-            config.n_embd,
-            moe_num_experts=config.num_experts,
-            moe_top_k=config.moe_top_k,
-        )
-
-        self.experts = NomicExperts(
-            config,
-            hidden_size=config.n_embd,
-            ffn_hidden_size=config.n_inner,
-            moe_num_experts=config.num_experts,
-        )
-
-    def forward(self, x: torch.Tensor):
-        weights, top_weights, top_experts = self.router(x)
-        out = self.experts(x, weights, top_weights, top_experts)
-        return out
+class NomicMoE(nn.Module):
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
+    ):
+        super().__init__()
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.total_intermediate_size = intermediate_size
+        self.intermediate_size = divide(intermediate_size, self.tp_size)
+        self.hidden_act = hidden_act
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.router = ReplicatedLinear(self.hidden_size,
+                                       self.num_total_experts,
+                                       bias=False)
+        self.w1 = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.intermediate_size,
+                        self.hidden_size,
+                        device=current_platform.device_type,
+                        dtype=self.params_dtype))
+        self.w2 = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        device=current_platform.device_type,
+                        dtype=self.params_dtype))
+        self.bias = nn.Parameter(torch.zeros(self.hidden_size))
+        set_weight_attrs(self.w1, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2, {
+            "weight_loader": self.weight_loader,
+        })
+
+    def weight_loader(
+        self,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+    ):
+        # NOTE: Nomic-MoE has fused experts weights with shape
+        # (num_experts * intermediate_size, hidden_size)
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1"):
+            loaded_weight = loaded_weight.reshape(
+                self.num_total_experts,
+                self.total_intermediate_size,
+                self.hidden_size,
+            )[:, shard]
+        if weight_name.endswith("w2"):
+            loaded_weight = loaded_weight.reshape(
+                self.num_total_experts,
+                self.total_intermediate_size,
+                self.hidden_size,
+            )[:, shard].transpose(1, 2)
+        param_data.copy_(loaded_weight)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.router(hidden_states)
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.w1,
+                                        self.w2,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=False,
+                                        inplace=False,
+                                        activation=self.hidden_act,
+                                        is_act_and_mul=False)
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_size) + self.bias


 class BertWithRopeBlock(nn.Module):
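To illustrate the weight_loader above (a standalone sketch with made-up sizes): the Nomic checkpoint stores each expert weight fused as (num_experts * intermediate_size, hidden_size); the loader unflattens it into one slab per expert, keeps this tensor-parallel rank's slice of the intermediate dimension, and transposes w2 into (hidden, intermediate) for the fused kernel.

import torch

num_experts, intermediate, hidden = 8, 3072, 768   # illustrative sizes
tp_size, tp_rank = 2, 0
shard_size = intermediate // tp_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

fused_w1 = torch.randn(num_experts * intermediate, hidden)      # checkpoint layout
w1_local = fused_w1.reshape(num_experts, intermediate, hidden)[:, shard]
print(w1_local.shape)    # torch.Size([8, 1536, 768])

fused_w2 = torch.randn(num_experts * intermediate, hidden)
w2_local = (fused_w2.reshape(num_experts, intermediate, hidden)[:, shard]
            .transpose(1, 2))
print(w2_local.shape)    # torch.Size([8, 768, 1536])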
@@ -332,7 +324,11 @@ class BertWithRopeBlock(nn.Module):
                               prefix=f"{prefix}.attention")

         if moe:
-            self.mlp = NomicMoELayer(config=config, )
+            self.mlp = NomicMoE(num_experts=config.num_experts,
+                                top_k=config.moe_top_k,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.intermediate_size,
+                                hidden_act=config.hidden_act)
         else:
             if config.hidden_act in ["silu", "geglu"]:
                 self.mlp = BertWithRopeGatedMLP(
@@ -463,6 +459,10 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
-                weight_loader(param, loaded_weight)
+                if name.endswith((".w1", ".w2")):
+                    # Nomic-MoE has fused experts weights
+                    weight_loader(param, loaded_weight, name)
+                else:
+                    weight_loader(param, loaded_weight)
                 loaded_params.add(name)
         return loaded_params
@@ -481,6 +481,10 @@ class NomicBertModel(BertWithRope):
             "mlp.fc12": "mlp.gate_proj",
             "mlp.fc2": "mlp.down_proj",
             "norm2": "mlp_ln",
+            # MoE mapping
+            "experts.mlp.": "",
+            "experts.": "",
+            "router.layer": "router",
         })
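Rough effect of the new substring mappings on checkpoint parameter names, assuming WeightsMapper rewrites each matching substring; the "encoder.layers.0" prefix below is illustrative, not taken from the diff.

mapping = {"experts.mlp.": "", "experts.": "", "router.layer": "router"}

def remap(name: str) -> str:
    # Apply each substring replacement in order, as a stand-in for WeightsMapper.
    for old, new in mapping.items():
        name = name.replace(old, new)
    return name

print(remap("encoder.layers.0.mlp.experts.mlp.w1"))        # encoder.layers.0.mlp.w1
print(remap("encoder.layers.0.mlp.experts.bias"))          # encoder.layers.0.mlp.bias
print(remap("encoder.layers.0.mlp.router.layer.weight"))   # encoder.layers.0.mlp.router.weight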