mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-20 07:09:08 +08:00
[Misc] Update Fused MoE weight loading (#7334)
This commit is contained in:
parent
fb377d7e74
commit
d3bdfd3ab9
@ -24,15 +24,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None) -> torch.Tensor:
|
||||
def apply(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
router_logits: torch.Tensor, top_k: int, renormalize: bool,
|
||||
use_grouped_topk: bool) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -61,66 +55,78 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.forward(x, layer.w13_weight, layer.w2_weight,
|
||||
router_logits, top_k, renormalize,
|
||||
use_grouped_topk, num_expert_group, topk_group)
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None) -> torch.Tensor:
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
|
||||
return fused_moe(x,
|
||||
w1,
|
||||
w2,
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
inplace=True,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group)
|
||||
return self.forward(x=x,
|
||||
layer=layer,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group)
|
||||
|
||||
def forward_cuda(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None) -> torch.Tensor:
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_experts)
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group)
|
||||
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True)
|
||||
|
||||
def forward_cpu(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"The CPU backend currently does not support MoE.")
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
def forward_tpu(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None) -> torch.Tensor:
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
|
||||
assert not use_grouped_topk
|
||||
assert num_expert_group is None
|
||||
assert topk_group is None
|
||||
return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
|
||||
return fused_moe(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk=top_k,
|
||||
gating_output=router_logits,
|
||||
renormalize=renormalize)
|
||||
|
||||
|
||||
class FusedMoE(torch.nn.Module):
|
||||
@ -195,7 +201,126 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
def weight_loader(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor, weight_name: str,
|
||||
shard_id: int, expert_id: int):
|
||||
shard_id: str, expert_id: int) -> None:
|
||||
if shard_id not in ("w1", "w2", "w3"):
|
||||
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
|
||||
f"got {shard_id}.")
|
||||
|
||||
# Special case for fp8 scales.
|
||||
if getattr(param, "is_fp8_scale", False):
|
||||
self._load_fp8_scale(param.data, loaded_weight, weight_name,
|
||||
shard_id, expert_id)
|
||||
return
|
||||
|
||||
expert_data = param.data[expert_id]
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# If transposed, weight is saved as [input_dim, output_dim]
|
||||
# Otherwise, weight is saved as [output_dim, input_dim]
|
||||
# Default is not transposed/input dim is dim 1
|
||||
input_dim = getattr(param, "input_dim", 1)
|
||||
output_dim = getattr(param, "output_dim", 0)
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
if shard_id == "w2":
|
||||
shard_dim = input_dim
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
elif shard_id in ("w1", "w3"):
|
||||
shard_dim = output_dim
|
||||
shard_size = expert_data.shape[output_dim] // 2
|
||||
offset = shard_size * tp_rank
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)
|
||||
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
if shard_id == "w1":
|
||||
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
elif shard_id == "w3":
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
elif shard_id == "w2":
|
||||
expert_data.copy_(loaded_weight)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected shard_id w1,w2 or w3 but got {shard_id}")
|
||||
|
||||
@staticmethod
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None):
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, grouped_topk)
|
||||
|
||||
# DeekSeekv2 uses grouped_top_k
|
||||
if use_grouped_topk:
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
topk_weights, topk_ids = grouped_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group)
|
||||
else:
|
||||
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
assert self.quant_method is not None
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
|
||||
ckpt_up_proj_name: str,
|
||||
num_experts: int) -> List[Tuple[str, str, int, str]]:
|
||||
|
||||
return [
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
("experts.w13_" if weight_name
|
||||
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
|
||||
f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
|
||||
for expert_id in range(num_experts) for shard_id, weight_name in [
|
||||
("w1", ckpt_gate_proj_name),
|
||||
("w2", ckpt_down_proj_name),
|
||||
("w3", ckpt_up_proj_name),
|
||||
]
|
||||
]
|
||||
|
||||
def _load_fp8_scale(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor, weight_name: str,
|
||||
shard_id: str, expert_id: int) -> None:
|
||||
param_data = param.data
|
||||
|
||||
# Input scales can be loaded directly and should be equal.
|
||||
@ -210,92 +335,11 @@ class FusedMoE(torch.nn.Module):
|
||||
# Weight scales
|
||||
elif "weight_scale" in weight_name:
|
||||
# If we are in merged column case (gate_up_proj)
|
||||
# shard_id 0 == gate_proj / w1
|
||||
# shard_id 2 == up_proj / w3
|
||||
if shard_id == 0 or shard_id == 2:
|
||||
if shard_id in ("w1", "w3"):
|
||||
# We have to keep the weight scales of w1 and w3 because
|
||||
# we need to re-quantize w1/w3 weights after weight loading.
|
||||
idx = 0 if shard_id == 0 else 1
|
||||
idx = 0 if shard_id == "w1" else 1
|
||||
param_data[expert_id][idx] = loaded_weight
|
||||
# If we are in the row parallel case (down_proj)
|
||||
# shard_id 1 == down_proj / w2
|
||||
else:
|
||||
param_data[expert_id] = loaded_weight
|
||||
# Weights
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.intermediate_size_per_partition
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
|
||||
# w1, gate_proj case: Load into first shard of w13.
|
||||
if shard_id == 0:
|
||||
param_data[expert_id,
|
||||
0:shard_size, :] = loaded_weight[shard, :]
|
||||
# w3, up_proj case: Load into second shard of w13.
|
||||
elif shard_id == 2:
|
||||
param_data[expert_id, shard_size:2 *
|
||||
shard_size, :] = loaded_weight[shard, :]
|
||||
# w2, down_proj case: Load into only shard of w2.
|
||||
elif shard_id == 1:
|
||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Shard id must be in [0,1,2] but got {shard_id}")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
assert self.quant_method is not None
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
num_expert_group=self.num_expert_group,
|
||||
topk_group=self.topk_group)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
|
||||
ckpt_up_proj_name: str,
|
||||
num_experts: int) -> List[Tuple[str, str, int, int]]:
|
||||
|
||||
gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
|
||||
gate_down_up = [
|
||||
ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
|
||||
]
|
||||
|
||||
return [
|
||||
# These are the weight scales for the experts
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
("experts.w13_scale"
|
||||
if weight_name in gate_up else "experts.w2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
|
||||
shard_id) for expert_id in range(num_experts)
|
||||
for shard_id, weight_name in enumerate(gate_down_up)
|
||||
] + [
|
||||
# These are the weights for the experts
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
("experts.w13_weight"
|
||||
if weight_name in gate_up else "experts.w2_weight",
|
||||
f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
|
||||
for expert_id in range(num_experts)
|
||||
for shard_id, weight_name in enumerate(gate_down_up)
|
||||
] + [
|
||||
# These are the weight scales for the experts
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
("experts.a13_scale"
|
||||
if weight_name in gate_up else "experts.a2_scale",
|
||||
f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
|
||||
shard_id) for expert_id in range(num_experts)
|
||||
for shard_id, weight_name in enumerate(gate_down_up)
|
||||
]
|
||||
|
||||
@ -290,23 +290,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
w13_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
2,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_scale", w13_scale)
|
||||
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
2,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
w2_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scale", w2_scale)
|
||||
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# If loading fp8 checkpoint, pass the weight loaders.
|
||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||
# process_weights_after_loading()
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
set_weight_attrs(w13_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w13_weight_scale, {
|
||||
"is_fp8_scale": True,
|
||||
**extra_weight_attrs
|
||||
})
|
||||
set_weight_attrs(w2_weight_scale, {
|
||||
"is_fp8_scale": True,
|
||||
**extra_weight_attrs
|
||||
})
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
@ -315,20 +321,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"Found static activation scheme for checkpoint that "
|
||||
"was not serialized fp8.")
|
||||
|
||||
a13_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("a13_scale", a13_scale)
|
||||
set_weight_attrs(a13_scale, extra_weight_attrs)
|
||||
w13_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, {
|
||||
"is_fp8_scale": True,
|
||||
**extra_weight_attrs
|
||||
})
|
||||
|
||||
a2_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("a2_scale", a2_scale)
|
||||
set_weight_attrs(a2_scale, extra_weight_attrs)
|
||||
w2_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, {
|
||||
"is_fp8_scale": True,
|
||||
**extra_weight_attrs
|
||||
})
|
||||
else:
|
||||
layer.a13_scale = None
|
||||
layer.a2_scale = None
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
|
||||
@ -341,16 +353,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# Re-initialize w13_scale because we directly quantize
|
||||
# merged w13 weights and generate a single scaling factor.
|
||||
layer.w13_scale = torch.nn.Parameter(torch.ones(
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
layer.num_experts,
|
||||
dtype=torch.float32,
|
||||
device=w13_weight.device),
|
||||
requires_grad=False)
|
||||
requires_grad=False)
|
||||
for expert in range(layer.num_experts):
|
||||
w13_weight[expert, :, :], layer.w13_scale[
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
layer.w13_weight.data[expert, :, :])
|
||||
w2_weight[expert, :, :], layer.w2_scale[
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
layer.w2_weight.data[expert, :, :])
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight,
|
||||
@ -366,40 +378,41 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
# Fp8 moe kernels require a single activation scale.
|
||||
# We take the max of all the scales in case they differ.
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if layer.a13_scale is None or layer.a2_scale is None:
|
||||
if (layer.w13_input_scale is None
|
||||
or layer.w2_input_scale is None):
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None.")
|
||||
if (not all_close_1d(layer.a13_scale)
|
||||
or not all_close_1d(layer.a2_scale)):
|
||||
if (not all_close_1d(layer.w13_input_scale)
|
||||
or not all_close_1d(layer.w2_input_scale)):
|
||||
print_warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer. ")
|
||||
layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
|
||||
requires_grad=False)
|
||||
layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
|
||||
requires_grad=False)
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False)
|
||||
|
||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||
# We take the max then dequant and requant each expert.
|
||||
assert layer.w13_scale is not None
|
||||
assert layer.w13_weight_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_scale.max(dim=1).values
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
for expert_id in range(layer.num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start:start +
|
||||
shard_size, :],
|
||||
layer.w13_scale[expert_id][shard_id])
|
||||
layer.w13_weight_scale[expert_id][shard_id])
|
||||
layer.w13_weight[expert_id][
|
||||
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
||||
dq_weight, max_w13_scales[expert_id])
|
||||
start += shard_size
|
||||
|
||||
layer.w13_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
def apply(self,
|
||||
@ -407,27 +420,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None) -> torch.Tensor:
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None) -> torch.Tensor:
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
return fused_moe(x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
inplace=True,
|
||||
use_fp8=True,
|
||||
w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale,
|
||||
a1_scale=layer.a13_scale,
|
||||
a2_scale=layer.a2_scale,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group)
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group)
|
||||
|
||||
return fused_experts(x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_fp8=True,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
|
||||
@ -593,7 +593,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
|
||||
@ -930,7 +930,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
|
||||
@ -455,7 +455,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
|
||||
@ -492,7 +492,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user