[Quantization] Fp8 Channelwise Dynamic Per Token GroupedGEMM (#15587)

Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: Lucas Wilkinson <wilkinson.lucas@gmail.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
Robert Shaw 2025-03-27 02:47:25 -04:00 committed by GitHub
parent f4c98b4d4c
commit 43ed4143c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 66 additions and 65 deletions

View File

@ -885,32 +885,6 @@ class FusedMoE(torch.nn.Module):
] ]
] ]
def _load_fp8_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
param_data = param.data
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = ( s = (

View File

@ -268,14 +268,23 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
self.input_quant = self.quant_config.target_scheme_map["Linear"].get( self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations") "input_activations")
if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR): and self.input_quant.strategy
== QuantizationStrategy.TENSOR)
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not (per_tensor or per_channel):
raise ValueError( raise ValueError(
"For FP8 Fused MoE layers, only per-tensor scales " "For FP8 Fused MoE layers, we require per tensor "
"for weights and activations are supported. Found " "or channelwise, dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}") f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales and per_channel:
raise ValueError(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.")
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
@ -303,22 +312,38 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES # WEIGHT_SCALES
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# Allocate 2 scales for w1 and w3 respectively. # Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading. # They are combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, w13_weight_scale = torch.nn.Parameter(torch.ones(
2, num_experts, 2, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32), dtype=torch.float32),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, num_experts, hidden_size, 1, dtype=torch.float32),
dtype=torch.float32),
requires_grad=False) requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel) # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update( extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs)
@ -362,6 +387,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
# Fp8 moe kernels require a single activation scale. # Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ. # We take the max of all the scales in case they differ.
if self.static_input_scales: if self.static_input_scales:
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
if (layer.w13_input_scale is None or layer.w2_input_scale is None): if (layer.w13_input_scale is None or layer.w2_input_scale is None):
raise ValueError( raise ValueError(
"QuantConfig has static quantization, but found " "QuantConfig has static quantization, but found "
@ -377,8 +403,9 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False) layer.w2_input_scale.max(), requires_grad=False)
# Fp8 moe kernel needs single weight scale for w13 per expert. # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
# We take the max then dequant and requant each expert. # for w13 per expert. Use max then dequant and requant each expert.
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
assert layer.w13_weight_scale is not None assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values max_w13_scales = layer.w13_weight_scale.max(dim=1).values
@ -386,13 +413,13 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
start = 0 start = 0
for shard_id in range(2): for shard_id in range(2):
dq_weight = per_tensor_dequantize( dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start + shard_size, :], layer.w13_weight[expert_id][start:start +
shard_size, :],
layer.w13_weight_scale[expert_id][shard_id]) layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][ layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant( start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id]) dq_weight, max_w13_scales[expert_id])
start += shard_size start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False) requires_grad=False)