mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 04:57:04 +08:00
[Bugfix][Quantization] Fix FP8 + EP (#13784)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
51010a1807
commit
1e15aaef56
@ -260,7 +260,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_experts: int,
|
num_experts: int, # Global number of experts
|
||||||
top_k: int,
|
top_k: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
@ -291,7 +291,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.ep_size = 1
|
self.ep_size = 1
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.num_experts = num_experts # Global number of experts
|
self.global_num_experts = num_experts
|
||||||
|
self.local_num_experts = self.global_num_experts // self.ep_size
|
||||||
assert intermediate_size % self.tp_size == 0
|
assert intermediate_size % self.tp_size == 0
|
||||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||||
self.reduce_results = reduce_results
|
self.reduce_results = reduce_results
|
||||||
@ -308,27 +309,29 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
# Create a tensor of size num_experts filled with -1
|
# Create a tensor of size global_num_experts filled with -1
|
||||||
self.expert_map = torch.full((self.num_experts, ),
|
self.expert_map = torch.full((self.global_num_experts, ),
|
||||||
-1,
|
-1,
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
# Create an expert map for the local experts
|
# Create an expert map for the local experts
|
||||||
local_num_experts = num_experts // self.ep_size
|
|
||||||
ep_rank = get_tensor_model_parallel_rank()
|
ep_rank = get_tensor_model_parallel_rank()
|
||||||
if ep_rank < (self.ep_size - 1):
|
if ep_rank < (self.ep_size - 1):
|
||||||
# Each non-last rank gets local_num_experts experts.
|
# Each non-last rank gets local_num_experts experts.
|
||||||
self.expert_map[ep_rank * local_num_experts:
|
self.expert_map[ep_rank * self.local_num_experts:
|
||||||
(ep_rank + 1) * local_num_experts] = \
|
(ep_rank + 1) * self.local_num_experts] = \
|
||||||
torch.arange(0, local_num_experts, dtype=torch.int32)
|
torch.arange(0, self.local_num_experts, dtype=torch.int32)
|
||||||
else:
|
else:
|
||||||
# All remaining experts are assigned to the last rank.
|
# All remaining experts are assigned to the last rank.
|
||||||
local_num_experts = num_experts - ep_rank * local_num_experts
|
self.local_num_experts = (self.global_num_experts -
|
||||||
self.expert_map[-local_num_experts:] = \
|
ep_rank * self.local_num_experts)
|
||||||
torch.arange(0, local_num_experts, dtype=torch.int32)
|
self.expert_map[-self.local_num_experts:] = \
|
||||||
|
torch.arange(0, self.local_num_experts, dtype=torch.int32)
|
||||||
|
|
||||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||||
raise ValueError("Only softmax scoring function is supported for "
|
raise ValueError("Only softmax scoring function is supported for "
|
||||||
"non-grouped topk.")
|
"non-grouped topk.")
|
||||||
|
|
||||||
|
# Note: get_quant_method will look at the layer's local_num_experts
|
||||||
|
# for heuristic purposes, so it must be initialized first.
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||||
UnquantizedFusedMoEMethod())
|
UnquantizedFusedMoEMethod())
|
||||||
@ -336,11 +339,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
local_num_experts = torch.sum(self.expert_map != -1) \
|
|
||||||
if self.expert_map is not None else num_experts
|
|
||||||
|
|
||||||
moe_quant_params = {
|
moe_quant_params = {
|
||||||
"num_experts": local_num_experts,
|
"num_experts": self.local_num_experts,
|
||||||
"hidden_size": hidden_size,
|
"hidden_size": hidden_size,
|
||||||
"intermediate_size_per_partition":
|
"intermediate_size_per_partition":
|
||||||
self.intermediate_size_per_partition,
|
self.intermediate_size_per_partition,
|
||||||
@ -647,7 +647,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
renormalize=self.renormalize,
|
renormalize=self.renormalize,
|
||||||
use_grouped_topk=self.use_grouped_topk,
|
use_grouped_topk=self.use_grouped_topk,
|
||||||
global_num_experts=self.num_experts,
|
global_num_experts=self.global_num_experts,
|
||||||
expert_map=self.expert_map,
|
expert_map=self.expert_map,
|
||||||
topk_group=self.topk_group,
|
topk_group=self.topk_group,
|
||||||
num_expert_group=self.num_expert_group,
|
num_expert_group=self.num_expert_group,
|
||||||
|
|||||||
@ -136,7 +136,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
|||||||
self.full_config).get_quant_method(layer, prefix)
|
self.full_config).get_quant_method(layer, prefix)
|
||||||
return AWQMarlinLinearMethod(self)
|
return AWQMarlinLinearMethod(self)
|
||||||
elif isinstance(layer, FusedMoE):
|
elif isinstance(layer, FusedMoE):
|
||||||
if layer.num_experts > 32:
|
if layer.local_num_experts > 32:
|
||||||
# For MoEs with many experts the moe_wna16 kernel is faster
|
# For MoEs with many experts the moe_wna16 kernel is faster
|
||||||
return MoeWNA16Config.from_config(
|
return MoeWNA16Config.from_config(
|
||||||
self.full_config).get_quant_method(layer, prefix)
|
self.full_config).get_quant_method(layer, prefix)
|
||||||
|
|||||||
@ -190,7 +190,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
assert layer.w13_weight_scale is not None
|
assert layer.w13_weight_scale is not None
|
||||||
shard_size = layer.intermediate_size_per_partition
|
shard_size = layer.intermediate_size_per_partition
|
||||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||||
for expert_id in range(layer.num_experts):
|
for expert_id in range(layer.local_num_experts):
|
||||||
start = 0
|
start = 0
|
||||||
for shard_id in range(2):
|
for shard_id in range(2):
|
||||||
dq_weight = per_tensor_dequantize(
|
dq_weight = per_tensor_dequantize(
|
||||||
|
|||||||
@ -573,11 +573,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
# Re-initialize w13_scale because we directly quantize
|
# Re-initialize w13_scale because we directly quantize
|
||||||
# merged w13 weights and generate a single scaling factor.
|
# merged w13 weights and generate a single scaling factor.
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||||
layer.num_experts,
|
layer.local_num_experts,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=w13_weight.device),
|
device=w13_weight.device),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
for expert in range(layer.num_experts):
|
for expert in range(layer.local_num_experts):
|
||||||
w13_weight[expert, :, :], layer.w13_weight_scale[
|
w13_weight[expert, :, :], layer.w13_weight_scale[
|
||||||
expert] = ops.scaled_fp8_quant(
|
expert] = ops.scaled_fp8_quant(
|
||||||
layer.w13_weight.data[expert, :, :])
|
layer.w13_weight.data[expert, :, :])
|
||||||
@ -644,7 +644,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
assert layer.w13_weight_scale is not None
|
assert layer.w13_weight_scale is not None
|
||||||
shard_size = layer.intermediate_size_per_partition
|
shard_size = layer.intermediate_size_per_partition
|
||||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||||
for expert_id in range(layer.num_experts):
|
for expert_id in range(layer.local_num_experts):
|
||||||
start = 0
|
start = 0
|
||||||
for shard_id in range(2):
|
for shard_id in range(2):
|
||||||
dq_weight = per_tensor_dequantize(
|
dq_weight = per_tensor_dequantize(
|
||||||
|
|||||||
@ -153,7 +153,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
def get_quant_method(self, layer: torch.nn.Module,
|
def get_quant_method(self, layer: torch.nn.Module,
|
||||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||||
if isinstance(layer, FusedMoE):
|
if isinstance(layer, FusedMoE):
|
||||||
if layer.num_experts > 32:
|
if layer.local_num_experts > 32:
|
||||||
# For MoEs with many experts the moe_wna16 kernel is faster
|
# For MoEs with many experts the moe_wna16 kernel is faster
|
||||||
return MoeWNA16Config.from_config(
|
return MoeWNA16Config.from_config(
|
||||||
self.full_config).get_quant_method(layer, prefix)
|
self.full_config).get_quant_method(layer, prefix)
|
||||||
|
|||||||
@ -174,7 +174,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
|||||||
assert layer.w13_weight_scale is not None
|
assert layer.w13_weight_scale is not None
|
||||||
shard_size = layer.intermediate_size_per_partition
|
shard_size = layer.intermediate_size_per_partition
|
||||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||||
for expert_id in range(layer.num_experts):
|
for expert_id in range(layer.local_num_experts):
|
||||||
start = 0
|
start = 0
|
||||||
for shard_id in range(2):
|
for shard_id in range(2):
|
||||||
dq_weight = per_tensor_dequantize(
|
dq_weight = per_tensor_dequantize(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user