[Misc][Breaking] Change FP8 checkpoint format from act_scale -> input_scale (#5353)
commit c09dade2a2
parent 8ea5e44a43
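This change renames the FP8 activation-scale tensors from act_scale to input_scale in both the Fp8LinearMethod and the Mixtral MoE loading path, so previously serialized FP8 checkpoints no longer load as-is. Below is a minimal migration sketch, assuming a plain torch-loadable state dict; the helper name and file handling are illustrative and not part of this commit (many FP8 checkpoints ship as safetensors instead).

# Hypothetical migration helper (not part of this commit): rename the old
# FP8 activation-scale tensors to the new name expected after this change.
import torch

def rename_act_scales(in_path: str, out_path: str) -> None:
    """Rewrite every '<module>.act_scale' key as '<module>.input_scale'."""
    state_dict = torch.load(in_path, map_location="cpu")
    renamed = {
        key.replace(".act_scale", ".input_scale"): tensor
        for key, tensor in state_dict.items()
    }
    torch.save(renamed, out_path)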
@@ -171,10 +171,10 @@ class Fp8LinearMethod(LinearMethodBase):
             output_partition_sizes=output_partition_sizes,
             **extra_weight_attrs)
 
-        # ACTIVATION SCALE
+        # INPUT ACTIVATION SCALE
         if self.quant_config.activation_scheme == "static":
             self._create_scale_param(
-                scale_name="act_scale",
+                scale_name="input_scale",
                 layer=layer,
                 output_partition_sizes=output_partition_sizes,
                 **extra_weight_attrs)
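The hunk above only swaps the scale_name passed to _create_scale_param; that helper's body is not shown in this diff. As a rough, assumed sketch of what such a helper registers (one scale slot per logical weight fused into the layer, later filled by the weight loader from the checkpoint's input_scale entries):

# Illustrative only -- an assumed stand-in for the _create_scale_param helper,
# not the actual vLLM implementation.
import torch
from torch.nn import Module, Parameter

def create_scale_param(layer: Module, scale_name: str,
                       output_partition_sizes: list[int]) -> None:
    # One scale entry per logical weight fused into this layer (e.g. q/k/v),
    # initialized to a sentinel and overwritten by the weight loader.
    scale = Parameter(torch.full((len(output_partition_sizes), ), -1.0),
                      requires_grad=False)
    layer.register_parameter(scale_name, scale)

layer = Module()
create_scale_param(layer, "input_scale", output_partition_sizes=[4096, 4096])
print(layer.input_scale)  # shape (2,), one slot per logical weight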
@@ -207,7 +207,7 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.logical_widths = None
-            layer.act_scale = None
+            layer.input_scale = None
             return
 
         # If checkpoint is fp8, requantize the separately quantized logical
@@ -232,18 +232,18 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             layer.weight = Parameter(weight.t(), requires_grad=False)
 
-            # ACT_SCALE
+            # INPUT ACTIVATION SCALE
             # Dynamic: set to None (required input to ops.scaled_fp8_quant).
-            # Static: set to max of the act_scales (since they are equal).
+            # Static: set to max of the input_scales (since they are equal).
             if self.quant_config.activation_scheme == "dynamic":
-                layer.act_scale = None
+                layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
-                if not all_close_1d(layer.act_scale):
+                if not all_close_1d(layer.input_scale):
                     raise ValueError(
-                        "All the act_scales for the logical weights of a layer "
-                        f"must be equal. But got {layer.act_scale}")
-                layer.act_scale = Parameter(layer.act_scale.max(),
+                        "All the input_scales for the logical weights of a "
+                        f"layer must be equal. But got {layer.input_scale}")
+                layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
                 raise ValueError(
                     f"Unknown scheme {self.quant_config.activation_scheme}")
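For static activation quantization, the checkpoint carries one input_scale per logical weight merged into this layer (for example q/k/v), and the code above collapses them into a single scalar after checking that they agree. A standalone paraphrase, assuming only torch and approximating all_close_1d with torch.allclose:

# Rough standalone equivalent of the static-scheme consolidation above;
# a sketch, not the vLLM implementation.
import torch

def consolidate_input_scales(input_scales: torch.Tensor) -> torch.Tensor:
    """Collapse per-logical-weight input_scales into a single scalar."""
    assert input_scales.dim() == 1
    if not torch.allclose(input_scales, input_scales[0]):
        raise ValueError(
            "All the input_scales for the logical weights of a "
            f"layer must be equal. But got {input_scales}")
    return input_scales.max()

print(consolidate_input_scales(torch.tensor([0.02, 0.02, 0.02])))  # tensor(0.0200)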
@@ -254,11 +254,11 @@ class Fp8LinearMethod(LinearMethodBase):
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
         # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.act_scale is None and x_scale computed from x.
-        # If static, layer.act_scale is scalar and x_scale set to act_scale.
+        # If dynamic, layer.input_scale is None and x_scale computed from x.
+        # If static, layer.input_scale is scalar and x_scale is input_scale.
 
         if bias is None and self.cutlass_fp8_supported:
-            qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
+            qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
 
             # Fused GEMM_DQ
             output = ops.cutlass_scaled_mm_dq(
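ops.scaled_fp8_quant is a fused CUDA op; the comment change above describes its contract: given a None scale it derives one dynamically from x, and given the static scalar it reuses layer.input_scale. A pure-PyTorch paraphrase of that behaviour, offered only as a sketch (padding and the fused GEMM are omitted):

# Sketch of the dynamic/static quantization contract described above.
from typing import Optional, Tuple
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def scaled_fp8_quant_ref(x: torch.Tensor,
                         input_scale: Optional[torch.Tensor] = None
                         ) -> Tuple[torch.Tensor, torch.Tensor]:
    if input_scale is None:
        # Dynamic: compute the scale from the current activation tensor.
        input_scale = x.abs().max().to(torch.float32) / FP8_MAX
    # Static: the scalar layer.input_scale loaded from the checkpoint is reused.
    qx = (x / input_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return qx, input_scale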
@@ -271,7 +271,7 @@ class Fp8LinearMethod(LinearMethodBase):
 
         else:
             qinput, x_scale = ops.scaled_fp8_quant(x,
-                                                   layer.act_scale,
+                                                   layer.input_scale,
                                                    batch_dim_padding=17)
 
             # Fused GEMM_DQ -- note we padded the input above because
@@ -147,7 +147,7 @@ class MixtralMoE(nn.Module):
                 "weight_loader": self.weight_loader,
             })
 
-        # ACT_SCALE (for fp8)
+        # INPUT_SCALE (for fp8)
         if quant_config.activation_scheme == "static":
             if not quant_config.is_checkpoint_fp8_serialized:
                 raise ValueError(
@@ -182,11 +182,11 @@ class MixtralMoE(nn.Module):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
 
         # Loading scales
-        if "act_scale" in weight_name or "w2.weight_scale" in weight_name:
+        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
             if param_data[expert_id] != 1 and (param_data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
                 raise ValueError(
-                    "act_scales of w1 and w3 of a layer "
+                    "input_scales of w1 and w3 of a layer "
                     f"must be equal. But got {param_data[expert_id]} "
                     f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
@@ -225,9 +225,9 @@ class MixtralMoE(nn.Module):
             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
 
         else:
-            # If checkpoint is fp8 + static, cleanup act_scales.
-            # Since state_dict has an act_scale per expert but our kernels
-            # are passed one act_scale shared across all experts.
+            # If checkpoint is fp8 + static, cleanup input_scales.
+            # Since state_dict has an input_scale per expert but our kernels
+            # are passed one input_scale shared across all experts.
             if self.quant_config.activation_scheme == "static":
                 if self.a13_scale is None or self.a2_scale is None:
                     raise ValueError(
@@ -237,7 +237,7 @@ class MixtralMoE(nn.Module):
                 if (not all_close_1d(self.a13_scale)
                         or not all_close_1d(self.a2_scale)):
                     print_warning_once(
-                        "Found act_scales that are not equal for "
+                        "Found input_scales that are not equal for "
                         "fp8 MoE layer. Using the maximum across experts "
                         "for each layer. ")
 
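The warning above covers the case where the checkpoint stores one input_scale per expert while the fused MoE kernel is given a single shared scale. A small illustration of that consolidation, with made-up scale values:

# Sketch of the per-expert consolidation the warning above refers to.
import torch
from torch import nn

a13_scale = torch.tensor([0.017, 0.019, 0.018, 0.019])  # one per expert (example values)
a13_scale = nn.Parameter(a13_scale.max(), requires_grad=False)  # kernel sees 0.019 for all experts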
@@ -576,7 +576,7 @@ class MixtralForCausalLM(nn.Module):
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
             ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
         ]
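With the rename, the comprehension above makes the loader match checkpoint tensors ending in .input_scale rather than .act_scale. Expanded for a hypothetical 2-expert config:

# Expansion of the mapping comprehension above for a hypothetical 2-expert
# config, showing the new ".input_scale" checkpoint names.
num_local_experts = 2
expert_params_mapping = [
    ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
     f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
    for expert_id in range(num_local_experts)
    for weight_name in ["w1", "w2", "w3"]
]
# [('a13_scale', 'experts.0.w1.input_scale', 0),
#  ('a2_scale',  'experts.0.w2.input_scale', 0),
#  ('a13_scale', 'experts.0.w3.input_scale', 0),
#  ('a13_scale', 'experts.1.w1.input_scale', 1), ...]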