[Model] use FusedMoE layer in Jamba (#6935)

This commit is contained in:
Avshalom Manevich 2024-07-31 22:00:24 +03:00 committed by GitHub
parent daed30c4a9
commit 2ee8d3ba55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,5 @@
# coding=utf-8
"""Inference-only Jurassic model."""
"""Inference-only Jamba model."""
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple
@ -15,10 +15,9 @@ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@ -282,108 +281,50 @@ class JambaMLP(nn.Module):
class JambaMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(
self,
config: JambaConfig,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self,
config: JambaConfig,
num_experts: Optional[int] = None,
top_k: Optional[int] = None,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.num_total_experts = num_experts or config.num_experts
self.top_k = top_k or config.num_experts_per_tok
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // self.tp_size
self.intermediate_size = config.intermediate_size
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if self.num_total_experts > 1:
self.router = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
quant_config=None,
params_dtype=params_dtype)
self.router = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype)
self.ws = nn.Parameter(
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
device="cuda",
dtype=self.params_dtype,
))
self.w2s = nn.Parameter(
torch.empty(
self.num_total_experts,
self.hidden_size,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype,
))
set_weight_attrs(
self.ws,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2s,
{
"weight_loader": self.weight_loader,
},
)
def weight_loader(
self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
expert_id: int,
):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("gate_proj.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("up_proj.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("down_proj.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
self.experts = FusedMoE(self.num_total_experts,
self.top_k,
self.hidden_size,
self.intermediate_size,
tp_size=tp_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=False,
use_grouped_topk=False,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.router(hidden_states)
final_hidden_states = fused_moe(
hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=
False, # Mixtral normalize the expert probs to 1. We don't!
inplace=True,
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
if self.num_total_experts > 1:
router_logits, _ = self.router(hidden_states)
else:
router_logits = torch.ones((hidden_states.shape[0], 1),
device=hidden_states.device,
dtype=hidden_states.dtype)
hidden_states = self.experts(hidden_states, router_logits)
return hidden_states.view(orig_shape)
class JambaMambaDecoderLayer(nn.Module):
@ -917,15 +858,13 @@ class JambaForCausalLM(nn.Module, HasInnerState):
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = [
# (param_name, weight_name, expert_id)
(
"ws" if weight_name in ["gate_proj", "up_proj"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight",
expert_id,
) for expert_id in range(self.config.num_experts)
for weight_name in ["down_proj", "up_proj", "gate_proj"]
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
@ -952,7 +891,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, expert_id in expert_params_mapping:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@ -961,6 +901,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader(param,
loaded_weight,
weight_name,
shard_id=shard_id,
expert_id=expert_id)
break
else: