mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-13 23:07:58 +08:00

[Misc] Update Fused MoE weight loading (#7334)

This commit is contained in:
parent fb377d7e74
commit d3bdfd3ab9
@@ -24,15 +24,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         raise NotImplementedError
 
     @abstractmethod
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool = True,
-              use_grouped_topk: bool = False,
-              num_expert_group: Optional[int] = None,
-              topk_group: Optional[int] = None) -> torch.Tensor:
+    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
+              router_logits: torch.Tensor, top_k: int, renormalize: bool,
+              use_grouped_topk: bool) -> torch.Tensor:
         raise NotImplementedError
@@ -61,66 +55,78 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        topk_group: Optional[int] = None,
-    ) -> torch.Tensor:
-        return self.forward(x, layer.w13_weight, layer.w2_weight,
-                            router_logits, top_k, renormalize,
-                            use_grouped_topk, num_expert_group, topk_group)
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        num_expert_group: Optional[int],
-        topk_group: Optional[int],
-    ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
-        return fused_moe(x,
-                         w1,
-                         w2,
-                         router_logits,
-                         top_k,
-                         renormalize=renormalize,
-                         inplace=True,
-                         use_grouped_topk=use_grouped_topk,
-                         num_expert_group=num_expert_group,
-                         topk_group=topk_group)
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool,
+              use_grouped_topk: bool,
+              topk_group: Optional[int] = None,
+              num_expert_group: Optional[int] = None) -> torch.Tensor:
+        return self.forward(x=x,
+                            layer=layer,
+                            router_logits=router_logits,
+                            top_k=top_k,
+                            renormalize=renormalize,
+                            use_grouped_topk=use_grouped_topk,
+                            topk_group=topk_group,
+                            num_expert_group=num_expert_group)
+
+    def forward_cuda(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     use_grouped_topk: bool,
+                     top_k: int,
+                     router_logits: torch.Tensor,
+                     renormalize: bool,
+                     topk_group: Optional[int] = None,
+                     num_expert_group: Optional[int] = None) -> torch.Tensor:
+
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            fused_experts)
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group)
+
+        return fused_experts(hidden_states=x,
+                             w1=layer.w13_weight,
+                             w2=layer.w2_weight,
+                             topk_weights=topk_weights,
+                             topk_ids=topk_ids,
+                             inplace=True)
 
     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError(
             "The CPU backend currently does not support MoE.")
 
-    def forward_tpu(
-        self,
-        x: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        num_expert_group: Optional[int],
-        topk_group: Optional[int],
-    ) -> torch.Tensor:
+    def forward_tpu(self,
+                    layer: torch.nn.Module,
+                    x: torch.Tensor,
+                    use_grouped_topk: bool,
+                    top_k: int,
+                    router_logits: torch.Tensor,
+                    renormalize: bool,
+                    topk_group: Optional[int] = None,
+                    num_expert_group: Optional[int] = None) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
         assert topk_group is None
-        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
+        return fused_moe(hidden_states=x,
+                         w1=layer.w13_weight,
+                         w2=layer.w2_weight,
+                         topk=top_k,
+                         gating_output=router_logits,
+                         renormalize=renormalize)
 
 
 class FusedMoE(torch.nn.Module):
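Note (not part of the diff): forward_cuda now separates routing from the expert
computation. FusedMoE.select_experts picks the top-k experts per token and
fused_experts then runs the expert MLPs with those choices. The snippet below is
a plain-PyTorch sketch of what the non-grouped fused_topk path computes (softmax
over the router logits, top-k, optional renormalization); the sizes are made up
for illustration.

import torch

# Hypothetical sizes, for illustration only.
num_tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(num_tokens, num_experts)

# Equivalent in spirit to fused_topk(..., renormalize=True).
probs = router_logits.softmax(dim=-1)
topk_weights, topk_ids = probs.topk(top_k, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

# topk_weights: [num_tokens, top_k] routing weights for each token.
# topk_ids:     [num_tokens, top_k] expert indices consumed by fused_experts.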
@@ -195,7 +201,126 @@ class FusedMoE(torch.nn.Module):
 
     def weight_loader(self, param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor, weight_name: str,
-                      shard_id: int, expert_id: int):
+                      shard_id: str, expert_id: int) -> None:
+        if shard_id not in ("w1", "w2", "w3"):
+            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
+                             f"got {shard_id}.")
+
+        # Special case for fp8 scales.
+        if getattr(param, "is_fp8_scale", False):
+            self._load_fp8_scale(param.data, loaded_weight, weight_name,
+                                 shard_id, expert_id)
+            return
+
+        expert_data = param.data[expert_id]
+        tp_rank = get_tensor_model_parallel_rank()
+
+        # If transposed, weight is saved as [input_dim, output_dim]
+        # Otherwise, weight is saved as [output_dim, input_dim]
+        # Default is not transposed/input dim is dim 1
+        input_dim = getattr(param, "input_dim", 1)
+        output_dim = getattr(param, "output_dim", 0)
+
+        # Index the loaded weight for tp sharding.
+        # down_proj: "RowParallel" so tp sharding on input_dim
+        if shard_id == "w2":
+            shard_dim = input_dim
+            shard_size = expert_data.shape[shard_dim]
+        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
+        elif shard_id in ("w1", "w3"):
+            shard_dim = output_dim
+            shard_size = expert_data.shape[output_dim] // 2
+        offset = shard_size * tp_rank
+        loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)
+
+        # Narrow parameter and load.
+        # w1, gate_proj: Load into first logical weight of w13.
+        if shard_id == "w1":
+            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
+            expert_data.copy_(loaded_weight)
+        # w3, up_proj: Load into second logical weight of w13.
+        elif shard_id == "w3":
+            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+            expert_data.copy_(loaded_weight)
+        # w2, down_proj: Load into only logical weight of w2.
+        elif shard_id == "w2":
+            expert_data.copy_(loaded_weight)
+        else:
+            raise ValueError(
+                f"Expected shard_id w1,w2 or w3 but got {shard_id}")
+
+    @staticmethod
+    def select_experts(hidden_states: torch.Tensor,
+                       router_logits: torch.Tensor,
+                       top_k: int,
+                       use_grouped_topk: bool,
+                       renormalize: bool,
+                       topk_group: Optional[int] = None,
+                       num_expert_group: Optional[int] = None):
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            fused_topk, grouped_topk)
+
+        # DeekSeekv2 uses grouped_top_k
+        if use_grouped_topk:
+            assert topk_group is not None
+            assert num_expert_group is not None
+            topk_weights, topk_ids = grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group)
+        else:
+            topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
+                                                gating_output=router_logits,
+                                                topk=top_k,
+                                                renormalize=renormalize)
+
+        return topk_weights, topk_ids
+
+    def forward(self, hidden_states: torch.Tensor,
+                router_logits: torch.Tensor):
+        assert self.quant_method is not None
+
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group)
+
+        if self.reduce_results and self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+
+        return final_hidden_states
+
+    @classmethod
+    def make_expert_params_mapping(
+            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
+            ckpt_up_proj_name: str,
+            num_experts: int) -> List[Tuple[str, str, int, str]]:
+
+        return [
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_" if weight_name
+             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
+             f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
+            for expert_id in range(num_experts) for shard_id, weight_name in [
+                ("w1", ckpt_gate_proj_name),
+                ("w2", ckpt_down_proj_name),
+                ("w3", ckpt_up_proj_name),
+            ]
+        ]
+
+    def _load_fp8_scale(self, param: torch.nn.Parameter,
+                        loaded_weight: torch.Tensor, weight_name: str,
+                        shard_id: str, expert_id: int) -> None:
         param_data = param.data
 
         # Input scales can be loaded directly and should be equal.
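Note (not part of the diff): the new weight_loader narrows the destination
parameter and copies the checkpoint shard into the resulting view. Below is a
toy, standalone illustration of that pattern with made-up sizes: w13 stacks
gate_proj (w1) and up_proj (w3) along the output dimension, and each shard is
copied into its half.

import torch

intermediate_size, hidden_size = 6, 4
# One expert's w13 parameter: [2 * intermediate_size, hidden_size].
w13 = torch.zeros(2 * intermediate_size, hidden_size)
w1_ckpt = torch.randn(intermediate_size, hidden_size)  # gate_proj shard
w3_ckpt = torch.randn(intermediate_size, hidden_size)  # up_proj shard

# shard_id == "w1": first logical half; shard_id == "w3": second half.
w13.narrow(0, 0, intermediate_size).copy_(w1_ckpt)
w13.narrow(0, intermediate_size, intermediate_size).copy_(w3_ckpt)
assert torch.equal(w13[:intermediate_size], w1_ckpt)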
@@ -210,92 +335,11 @@ class FusedMoE(torch.nn.Module):
         # Weight scales
         elif "weight_scale" in weight_name:
             # If we are in merged column case (gate_up_proj)
-            # shard_id 0 == gate_proj / w1
-            # shard_id 2 == up_proj / w3
-            if shard_id == 0 or shard_id == 2:
+            if shard_id in ("w1", "w3"):
                 # We have to keep the weight scales of w1 and w3 because
                 # we need to re-quantize w1/w3 weights after weight loading.
-                idx = 0 if shard_id == 0 else 1
+                idx = 0 if shard_id == "w1" else 1
                 param_data[expert_id][idx] = loaded_weight
             # If we are in the row parallel case (down_proj)
-            # shard_id 1 == down_proj / w2
             else:
                 param_data[expert_id] = loaded_weight
-        # Weights
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            shard_size = self.intermediate_size_per_partition
-            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-
-            # w1, gate_proj case: Load into first shard of w13.
-            if shard_id == 0:
-                param_data[expert_id,
-                           0:shard_size, :] = loaded_weight[shard, :]
-            # w3, up_proj case: Load into second shard of w13.
-            elif shard_id == 2:
-                param_data[expert_id, shard_size:2 *
-                           shard_size, :] = loaded_weight[shard, :]
-            # w2, down_proj case: Load into only shard of w2.
-            elif shard_id == 1:
-                param_data[expert_id, :, :] = loaded_weight[:, shard]
-            else:
-                raise ValueError(
-                    f"Shard id must be in [0,1,2] but got {shard_id}")
-
-    def forward(self, hidden_states: torch.Tensor,
-                router_logits: torch.Tensor):
-        assert self.quant_method is not None
-
-        # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
-            self,
-            x=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            renormalize=self.renormalize,
-            use_grouped_topk=self.use_grouped_topk,
-            num_expert_group=self.num_expert_group,
-            topk_group=self.topk_group)
-
-        if self.reduce_results and self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
-        return final_hidden_states
-
-    @classmethod
-    def make_expert_params_mapping(
-            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
-            ckpt_up_proj_name: str,
-            num_experts: int) -> List[Tuple[str, str, int, int]]:
-
-        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
-        gate_down_up = [
-            ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
-        ]
-
-        return [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_scale"
-             if weight_name in gate_up else "experts.w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
-             shard_id) for expert_id in range(num_experts)
-            for shard_id, weight_name in enumerate(gate_down_up)
-        ] + [
-            # These are the weights for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_weight"
-             if weight_name in gate_up else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(num_experts)
-            for shard_id, weight_name in enumerate(gate_down_up)
-        ] + [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.a13_scale"
-             if weight_name in gate_up else "experts.a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
-             shard_id) for expert_id in range(num_experts)
-            for shard_id, weight_name in enumerate(gate_down_up)
-        ]
@@ -290,23 +290,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         # WEIGHT_SCALES
         # Allocate 2 scales for w1 and w3 respectively.
         # They will be combined to a single scale after weight loading.
-        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                  2,
-                                                  dtype=torch.float32),
-                                       requires_grad=False)
-        layer.register_parameter("w13_scale", w13_scale)
+        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                         2,
+                                                         dtype=torch.float32),
+                                              requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
 
-        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                 dtype=torch.float32),
-                                      requires_grad=False)
-        layer.register_parameter("w2_scale", w2_scale)
+        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                        dtype=torch.float32),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
         # If loading fp8 checkpoint, pass the weight loaders.
         # If loading an fp16 checkpoint, do not (we will quantize in
         # process_weights_after_loading()
         if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_scale, extra_weight_attrs)
-            set_weight_attrs(w2_scale, extra_weight_attrs)
+            set_weight_attrs(w13_weight_scale, {
+                "is_fp8_scale": True,
+                **extra_weight_attrs
+            })
+            set_weight_attrs(w2_weight_scale, {
+                "is_fp8_scale": True,
+                **extra_weight_attrs
+            })
 
         # INPUT_SCALES
         if self.quant_config.activation_scheme == "static":
@@ -315,20 +321,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     "Found static activation scheme for checkpoint that "
                     "was not serialized fp8.")
 
-            a13_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                      dtype=torch.float32),
-                                           requires_grad=False)
-            layer.register_parameter("a13_scale", a13_scale)
-            set_weight_attrs(a13_scale, extra_weight_attrs)
+            w13_input_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            set_weight_attrs(w13_input_scale, {
+                "is_fp8_scale": True,
+                **extra_weight_attrs
+            })
 
-            a2_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            layer.register_parameter("a2_scale", a2_scale)
-            set_weight_attrs(a2_scale, extra_weight_attrs)
+            w2_input_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                requires_grad=False)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+            set_weight_attrs(w2_input_scale, {
+                "is_fp8_scale": True,
+                **extra_weight_attrs
+            })
         else:
-            layer.a13_scale = None
-            layer.a2_scale = None
+            layer.w13_input_scale = None
+            layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
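Note (not part of the diff): the scale parameters created here are tagged with
an "is_fp8_scale" attribute so that FusedMoE.weight_loader (earlier hunk) can
divert them to _load_fp8_scale instead of the sharded weight path. A minimal
sketch of the mechanism, assuming set_weight_attrs simply sets attributes on
the parameter:

import torch

scale = torch.nn.Parameter(torch.ones(8, 2), requires_grad=False)

# Roughly what set_weight_attrs(scale, {"is_fp8_scale": True, ...}) does.
setattr(scale, "is_fp8_scale", True)

# ...and the check the new weight_loader performs before delegating.
assert getattr(scale, "is_fp8_scale", False)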
@@ -341,16 +353,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
-            layer.w13_scale = torch.nn.Parameter(torch.ones(
+            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                 layer.num_experts,
                 dtype=torch.float32,
                 device=w13_weight.device),
                                                  requires_grad=False)
             for expert in range(layer.num_experts):
-                w13_weight[expert, :, :], layer.w13_scale[
+                w13_weight[expert, :, :], layer.w13_weight_scale[
                     expert] = ops.scaled_fp8_quant(
                         layer.w13_weight.data[expert, :, :])
-                w2_weight[expert, :, :], layer.w2_scale[
+                w2_weight[expert, :, :], layer.w2_weight_scale[
                     expert] = ops.scaled_fp8_quant(
                         layer.w2_weight.data[expert, :, :])
             layer.w13_weight = torch.nn.Parameter(w13_weight,
@@ -366,40 +378,41 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             # Fp8 moe kernels require a single activation scale.
             # We take the max of all the scales in case they differ.
             if self.quant_config.activation_scheme == "static":
-                if layer.a13_scale is None or layer.a2_scale is None:
+                if (layer.w13_input_scale is None
+                        or layer.w2_input_scale is None):
                     raise ValueError(
                         "QuantConfig has static quantization, but found "
                         "activation scales are None.")
-                if (not all_close_1d(layer.a13_scale)
-                        or not all_close_1d(layer.a2_scale)):
+                if (not all_close_1d(layer.w13_input_scale)
+                        or not all_close_1d(layer.w2_input_scale)):
                     print_warning_once(
                         "Found input_scales that are not equal for "
                         "fp8 MoE layer. Using the maximum across experts "
                         "for each layer. ")
-                layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
-                                                     requires_grad=False)
-                layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
-                                                    requires_grad=False)
+                layer.w13_input_scale = torch.nn.Parameter(
+                    layer.w13_input_scale.max(), requires_grad=False)
+                layer.w2_input_scale = torch.nn.Parameter(
+                    layer.w2_input_scale.max(), requires_grad=False)
 
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
-            assert layer.w13_scale is not None
+            assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
-            max_w13_scales = layer.w13_scale.max(dim=1).values
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
             for expert_id in range(layer.num_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
                         layer.w13_weight[expert_id][start:start +
                                                     shard_size, :],
-                        layer.w13_scale[expert_id][shard_id])
+                        layer.w13_weight_scale[expert_id][shard_id])
                     layer.w13_weight[expert_id][
                         start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                             dq_weight, max_w13_scales[expert_id])
                     start += shard_size
 
-            layer.w13_scale = torch.nn.Parameter(max_w13_scales,
-                                                 requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
             return
 
     def apply(self,
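Note (not part of the diff): the loop above collapses the two w13 weight scales
into one per expert by dequantizing each half with its own scale and
re-quantizing it at the per-expert maximum. A small numerical sketch of why the
represented values survive a larger shared scale (plain float tensors stand in
for the fp8 payload, 448 approximates the float8_e4m3 maximum, and fp8 rounding
is ignored here):

import torch

w = torch.randn(16, 16)
scale_w1 = w.abs().max() / 448.0              # original per-shard scale
q_w1 = (w / scale_w1).clamp(-448.0, 448.0)    # stand-in for the fp8 tensor

scale_max = scale_w1 * 1.5                    # shared max scale across w1/w3
dq = q_w1 * scale_w1                          # per_tensor_dequantize, in spirit
q_new = (dq / scale_max).clamp(-448.0, 448.0)

# The same values are recovered under the new, shared scale.
assert torch.allclose(q_new * scale_max, dq, atol=1e-5)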
@@ -407,27 +420,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True,
-              use_grouped_topk: bool = False,
-              num_expert_group: Optional[int] = None,
-              topk_group: Optional[int] = None) -> torch.Tensor:
+              renormalize: bool,
+              use_grouped_topk: bool,
+              topk_group: Optional[int] = None,
+              num_expert_group: Optional[int] = None) -> torch.Tensor:
 
-        from vllm.model_executor.layers.fused_moe import fused_moe
-        return fused_moe(x,
-                         layer.w13_weight,
-                         layer.w2_weight,
-                         router_logits,
-                         top_k,
-                         renormalize=renormalize,
-                         inplace=True,
-                         use_fp8=True,
-                         w1_scale=layer.w13_scale,
-                         w2_scale=layer.w2_scale,
-                         a1_scale=layer.a13_scale,
-                         a2_scale=layer.a2_scale,
-                         use_grouped_topk=use_grouped_topk,
-                         num_expert_group=num_expert_group,
-                         topk_group=topk_group)
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group)
+
+        return fused_experts(x,
+                             layer.w13_weight,
+                             layer.w2_weight,
+                             topk_weights=topk_weights,
+                             topk_ids=topk_ids,
+                             inplace=True,
+                             use_fp8=True,
+                             w1_scale=layer.w13_weight_scale,
+                             w2_scale=layer.w2_weight_scale,
+                             a1_scale=layer.w13_input_scale,
+                             a2_scale=layer.w2_input_scale)
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
@@ -593,7 +593,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 weight_loader = param.weight_loader
                 weight_loader(param,
                               loaded_weight,
-                              weight_name,
+                              name,
                               shard_id=shard_id,
                               expert_id=expert_id)
                 break
@@ -930,7 +930,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
                 weight_loader = param.weight_loader
                 weight_loader(param,
                               loaded_weight,
-                              weight_name,
+                              name,
                               shard_id=shard_id,
                               expert_id=expert_id)
                 break
@@ -455,7 +455,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                 weight_loader = param.weight_loader
                 weight_loader(param,
                               loaded_weight,
-                              weight_name,
+                              name,
                               shard_id=shard_id,
                               expert_id=expert_id)
                 break
@@ -492,7 +492,7 @@ class Qwen2MoeForCausalLM(nn.Module):
                 weight_loader = param.weight_loader
                 weight_loader(param,
                               loaded_weight,
-                              weight_name,
+                              name,
                               shard_id=shard_id,
                               expert_id=expert_id)
                 break
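Note (not part of the diff): the four model hunks above make the same change —
the FusedMoE weight loader is now handed the mapped parameter name (name)
rather than the raw checkpoint weight_name. Below is a hedged sketch of the
load_weights loop those call sites live in; params_dict, weights and the proj
names are placeholders rather than values taken from this commit.

from typing import Dict, Iterable, Tuple

import torch

from vllm.model_executor.layers.fused_moe import FusedMoE


def load_expert_weights(params_dict: Dict[str, torch.nn.Parameter],
                        weights: Iterable[Tuple[str, torch.Tensor]],
                        num_experts: int) -> None:
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=num_experts)

    for name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            # Map the checkpoint name onto the fused parameter name,
            # e.g. "...experts.3.gate_proj.weight" -> "...experts.w13_weight".
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = param.weight_loader
            # Pass the mapped name, exactly as the hunks above now do.
            weight_loader(param,
                          loaded_weight,
                          name,
                          shard_id=shard_id,
                          expert_id=expert_id)
            break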