mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-05 13:24:28 +08:00
[Quantization] Fp8 Channelwise Dynamic Per Token GroupedGEMM (#15587)
Signed-off-by: ElizaWszola <eliza@neuralmagic.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> Co-authored-by: ElizaWszola <eliza@neuralmagic.com> Co-authored-by: Lucas Wilkinson <wilkinson.lucas@gmail.com> Co-authored-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
parent
f4c98b4d4c
commit
43ed4143c4
@ -885,32 +885,6 @@ class FusedMoE(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
def _load_fp8_scale(self, param: torch.nn.Parameter,
|
|
||||||
loaded_weight: torch.Tensor, weight_name: str,
|
|
||||||
shard_id: str, expert_id: int) -> None:
|
|
||||||
param_data = param.data
|
|
||||||
|
|
||||||
# Input scales can be loaded directly and should be equal.
|
|
||||||
if "input_scale" in weight_name:
|
|
||||||
if param_data[expert_id] != 1 and (param_data[expert_id] -
|
|
||||||
loaded_weight).abs() > 1e-5:
|
|
||||||
raise ValueError(
|
|
||||||
"input_scales of w1 and w3 of a layer "
|
|
||||||
f"must be equal. But got {param_data[expert_id]} "
|
|
||||||
f"vs. {loaded_weight}")
|
|
||||||
param_data[expert_id] = loaded_weight
|
|
||||||
# Weight scales
|
|
||||||
elif "weight_scale" in weight_name:
|
|
||||||
# If we are in merged column case (gate_up_proj)
|
|
||||||
if shard_id in ("w1", "w3"):
|
|
||||||
# We have to keep the weight scales of w1 and w3 because
|
|
||||||
# we need to re-quantize w1/w3 weights after weight loading.
|
|
||||||
idx = 0 if shard_id == "w1" else 1
|
|
||||||
param_data[expert_id][idx] = loaded_weight
|
|
||||||
# If we are in the row parallel case (down_proj)
|
|
||||||
else:
|
|
||||||
param_data[expert_id] = loaded_weight
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
|
|
||||||
s = (
|
s = (
|
||||||
|
|||||||
@ -268,14 +268,23 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
|||||||
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||||
"input_activations")
|
"input_activations")
|
||||||
|
|
||||||
if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
|
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
|
||||||
and self.input_quant.strategy == QuantizationStrategy.TENSOR):
|
and self.input_quant.strategy
|
||||||
|
== QuantizationStrategy.TENSOR)
|
||||||
|
per_channel = (
|
||||||
|
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
|
||||||
|
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
||||||
|
if not (per_tensor or per_channel):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"For FP8 Fused MoE layers, only per-tensor scales "
|
"For FP8 Fused MoE layers, we require per tensor "
|
||||||
"for weights and activations are supported. Found "
|
"or channelwise, dynamic per token quantization. Found "
|
||||||
f"{self.weight_quant}, {self.input_quant}")
|
f"{self.weight_quant}, {self.input_quant}")
|
||||||
|
|
||||||
self.static_input_scales = not self.input_quant.dynamic
|
self.static_input_scales = not self.input_quant.dynamic
|
||||||
|
if self.static_input_scales and per_channel:
|
||||||
|
raise ValueError(
|
||||||
|
"For FP8 Fused MoE layer, we require either per tensor or "
|
||||||
|
"channelwise, dynamic per token quantization.")
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
hidden_size: int, intermediate_size_per_partition: int,
|
hidden_size: int, intermediate_size_per_partition: int,
|
||||||
@ -303,24 +312,40 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
|||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
# WEIGHT_SCALES
|
# WEIGHT_SCALES
|
||||||
# Allocate 2 scales for w1 and w3 respectively.
|
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
|
||||||
# They will be combined to a single scale after weight loading.
|
# Allocate 2 scales for w1 and w3 respectively.
|
||||||
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
# They are combined to a single scale after weight loading.
|
||||||
2,
|
w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||||
dtype=torch.float32),
|
num_experts, 2, dtype=torch.float32),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||||
|
num_experts, dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||||
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
|
||||||
dtype=torch.float32),
|
w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||||
requires_grad=False)
|
num_experts,
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
2 * intermediate_size_per_partition,
|
||||||
# Add the quantization method used (per tensor/grouped/channel)
|
1,
|
||||||
# to ensure the weight scales are loaded in properly
|
dtype=torch.float32),
|
||||||
extra_weight_attrs.update(
|
requires_grad=False)
|
||||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
num_experts, hidden_size, 1, dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
|
||||||
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
# INPUT_SCALES
|
# INPUT_SCALES
|
||||||
if self.static_input_scales:
|
if self.static_input_scales:
|
||||||
@ -362,6 +387,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
|||||||
# Fp8 moe kernels require a single activation scale.
|
# Fp8 moe kernels require a single activation scale.
|
||||||
# We take the max of all the scales in case they differ.
|
# We take the max of all the scales in case they differ.
|
||||||
if self.static_input_scales:
|
if self.static_input_scales:
|
||||||
|
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
|
||||||
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
|
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"QuantConfig has static quantization, but found "
|
"QuantConfig has static quantization, but found "
|
||||||
@ -377,24 +403,25 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
|||||||
layer.w2_input_scale = torch.nn.Parameter(
|
layer.w2_input_scale = torch.nn.Parameter(
|
||||||
layer.w2_input_scale.max(), requires_grad=False)
|
layer.w2_input_scale.max(), requires_grad=False)
|
||||||
|
|
||||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
|
||||||
# We take the max then dequant and requant each expert.
|
# for w13 per expert. Use max then dequant and requant each expert.
|
||||||
assert layer.w13_weight_scale is not None
|
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
|
||||||
shard_size = layer.intermediate_size_per_partition
|
assert layer.w13_weight_scale is not None
|
||||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
shard_size = layer.intermediate_size_per_partition
|
||||||
for expert_id in range(layer.local_num_experts):
|
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||||
start = 0
|
for expert_id in range(layer.local_num_experts):
|
||||||
for shard_id in range(2):
|
start = 0
|
||||||
dq_weight = per_tensor_dequantize(
|
for shard_id in range(2):
|
||||||
layer.w13_weight[expert_id][start:start + shard_size, :],
|
dq_weight = per_tensor_dequantize(
|
||||||
layer.w13_weight_scale[expert_id][shard_id])
|
layer.w13_weight[expert_id][start:start +
|
||||||
layer.w13_weight[expert_id][
|
shard_size, :],
|
||||||
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
layer.w13_weight_scale[expert_id][shard_id])
|
||||||
dq_weight, max_w13_scales[expert_id])
|
layer.w13_weight[expert_id][
|
||||||
start += shard_size
|
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
||||||
|
dq_weight, max_w13_scales[expert_id])
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
start += shard_size
|
||||||
requires_grad=False)
|
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user