Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Model] use FusedMoE layer in Jamba (#6935)
commit 2ee8d3ba55 (parent daed30c4a9)
vllm/model_executor/models/jamba.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-"""Inference-only Jurassic model."""
+"""Inference-only Jamba model."""
 from dataclasses import dataclass
 from typing import Dict, Iterable, List, Optional, Tuple

@@ -15,10 +15,9 @@ from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+                              get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,

@@ -282,108 +281,50 @@ class JambaMLP(nn.Module):


 class JambaMoE(nn.Module):
-    """A tensor-parallel MoE implementation for Mixtral that shards each expert
-    across all ranks.
-
-    Each expert's weights are sharded across all ranks and a fused MoE
-    kernel is used for the forward pass, and finally we reduce the outputs
-    across ranks.
-    """

-    def __init__(
-        self,
-        config: JambaConfig,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 config: JambaConfig,
+                 num_experts: Optional[int] = None,
+                 top_k: Optional[int] = None,
+                 params_dtype: Optional[torch.dtype] = None,
+                 tp_size: Optional[int] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
+        self.num_total_experts = num_experts or config.num_experts
+        self.top_k = top_k or config.num_experts_per_tok
         self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size // self.tp_size
+        self.intermediate_size = config.intermediate_size

-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-
-        self.router = ReplicatedLinear(self.hidden_size,
-                                       self.num_total_experts,
-                                       bias=False,
-                                       params_dtype=self.params_dtype)
-
-        self.ws = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.hidden_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-        self.w2s = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.hidden_size,
-                self.intermediate_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-
-        set_weight_attrs(
-            self.ws,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2s,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-
-    def weight_loader(
-        self,
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        expert_id: int,
-    ):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("gate_proj.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("up_proj.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("down_proj.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if self.num_total_experts > 1:
+            self.router = ReplicatedLinear(self.hidden_size,
+                                           self.num_total_experts,
+                                           bias=False,
+                                           quant_config=None,
+                                           params_dtype=params_dtype)
+
+        self.experts = FusedMoE(self.num_total_experts,
+                                self.top_k,
+                                self.hidden_size,
+                                self.intermediate_size,
+                                tp_size=tp_size,
+                                params_dtype=params_dtype,
+                                reduce_results=True,
+                                renormalize=False,
+                                use_grouped_topk=False,
+                                quant_config=quant_config)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
-        router_logits, _ = self.router(hidden_states)
-
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.ws,
-            self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=
-            False,  # Mixtral normalize the expert probs to 1. We don't!
-            inplace=True,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        if self.num_total_experts > 1:
+            router_logits, _ = self.router(hidden_states)
+        else:
+            router_logits = torch.ones((hidden_states.shape[0], 1),
+                                       device=hidden_states.device,
+                                       dtype=hidden_states.dtype)
+        hidden_states = self.experts(hidden_states, router_logits)
+        return hidden_states.view(orig_shape)


 class JambaMambaDecoderLayer(nn.Module):

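For reference, the fused path introduced above computes the same result as an unfused per-expert loop. The sketch below is a minimal PyTorch reference, assuming the merged gate/up layout of the old ws/w2s parameters (SwiGLU experts) and renormalize=False, so the top-k routing weights are the raw softmax probabilities rather than being renormalized to sum to 1. The function and argument names are illustrative only, not part of vLLM.

import torch
import torch.nn.functional as F

def moe_reference(hidden_states: torch.Tensor,   # (num_tokens, hidden_size)
                  w13: torch.Tensor,             # (num_experts, 2 * intermediate, hidden_size)
                  w2: torch.Tensor,              # (num_experts, hidden_size, intermediate)
                  router_logits: torch.Tensor,   # (num_tokens, num_experts)
                  top_k: int) -> torch.Tensor:
    # Softmax over all experts, then keep the top-k. With renormalize=False the
    # selected probabilities are used as-is (they need not sum to 1 per token).
    probs = F.softmax(router_logits, dim=-1)
    topk_probs, topk_ids = probs.topk(top_k, dim=-1)

    out = torch.zeros_like(hidden_states)
    for e in range(w13.shape[0]):
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]
        gate, up = (x @ w13[e].t()).chunk(2, dim=-1)   # merged gate/up projection
        expert_out = (F.silu(gate) * up) @ w2[e].t()   # SwiGLU, then down projection
        out.index_add_(0, token_idx,
                       expert_out * topk_probs[token_idx, slot].unsqueeze(-1))
    return out

With a single expert, the new layer feeds all-ones logits to the FusedMoE layer instead of calling a router, which reduces this loop to a plain MLP applied to every token.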
@@ -917,15 +858,13 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             ("gate_up_proj", "up_proj", 1),
         ]

-        expert_params_mapping = [
-            # (param_name, weight_name, expert_id)
-            (
-                "ws" if weight_name in ["gate_proj", "up_proj"] else "w2s",
-                f"experts.{expert_id}.{weight_name}.weight",
-                expert_id,
-            ) for expert_id in range(self.config.num_experts)
-            for weight_name in ["down_proj", "up_proj", "gate_proj"]
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:

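FusedMoE.make_expert_params_mapping replaces the hand-written list with one (param_name, weight_name, expert_id, shard_id) tuple per expert and per projection, and also covers fp8 weight and activation scales when present. The sketch below shows roughly the shape of such a mapping; the fused parameter names (w13_weight, w2_weight) and the shard ids (w1/w3 for the merged gate/up halves, w2 for down_proj) are assumptions about FusedMoE internals rather than something shown in this diff.

# Hypothetical reconstruction of the mapping layout, for illustration only.
def expert_params_mapping_sketch(num_experts: int):
    # shard ids: "w1" = gate_proj, "w2" = down_proj, "w3" = up_proj
    ckpt_to_shard = [("gate_proj", "w1"), ("down_proj", "w2"), ("up_proj", "w3")]
    return [
        ("experts.w13_weight" if shard in ("w1", "w3") else "experts.w2_weight",
         f"experts.{expert_id}.{ckpt_name}.weight",
         expert_id,
         shard)
        for expert_id in range(num_experts)
        for ckpt_name, shard in ckpt_to_shard
    ]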
@@ -952,7 +891,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)

@@ -961,6 +901,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
                     weight_loader(param,
                                   loaded_weight,
                                   weight_name,
+                                  shard_id=shard_id,
                                   expert_id=expert_id)
                     break
                 else:
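The hunks above show only fragments of load_weights; the surrounding loop follows vLLM's usual for/else dispatch, where an else branch runs only if no mapping entry matched (the loop finished without break). Below is a condensed, hypothetical sketch of that control flow, not the exact Jamba code.

def load_weights_sketch(model, weights, stacked_params_mapping,
                        expert_params_mapping):
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            param = params_dict[name.replace(weight_name, param_name)]
            # stacked (e.g. gate_up_proj) parameters carry their own loader
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                param.weight_loader(param, loaded_weight, weight_name,
                                    shard_id=shard_id, expert_id=expert_id)
                break
            else:
                # ordinary parameter: plain copy from the checkpoint tensor
                params_dict[name].data.copy_(loaded_weight)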