[Quantization] FP8 Weight Reloading for Quantized RL Rollout (#28480)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Kyle Sayers 2025-12-09 16:54:32 -05:00, committed by GitHub
parent 00e5cbb967
commit fccd532587
7 changed files with 206 additions and 86 deletions

View File

@@ -10,10 +10,14 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
Fp8KVCacheMethod,
Fp8LinearMethod,
Fp8MoEMethod,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
MODELS = [
@@ -261,3 +265,87 @@ def test_scaled_fp8_quant(dtype) -> None:
torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype
),
)
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
# FP8 weight reloading does not support online quantization
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True]) # skip False
@pytest.mark.parametrize("weight_block_size", [None, [1, 1]])
# any postprocessing applied to the weights, such as padding and repacking
# (excluding device sharding), must also be applied to the reloaded weights
#
# this is the case for marlin as well as per-tensor Fp8MoEMethod
@pytest.mark.parametrize("use_marlin", [False]) # skip True
def test_fp8_reloading(
method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
):
if is_checkpoint_fp8_serialized is False:
pytest.skip("FP8 weight reloading does not support online quantization")
if method_cls is Fp8MoEMethod and weight_block_size is None:
pytest.skip(
"FP8 Tensor weight reloading does not support fusing w13_weight_scale. "
"If this is your use case, consider using a restore function like #26327"
)
with torch.device("cuda:0"):
config = Fp8Config(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
weight_block_size=weight_block_size,
)
if method_cls is Fp8LinearMethod:
layer = torch.nn.Linear(1, 1)
method = method_cls(config)
method.create_weights(
layer=layer,
input_size_per_partition=1,
output_partition_sizes=[1],
input_size=1,
output_size=1,
params_dtype=torch.bfloat16,
weight_loader=default_weight_loader,
)
else:
layer = FusedMoE(
num_experts=1,
top_k=1,
hidden_size=1,
intermediate_size=1,
)
method = method_cls(config, layer)
method.create_weights(
layer=layer,
num_experts=1,
hidden_size=1,
intermediate_size_per_partition=1,
params_dtype=torch.bfloat16,
weight_loader=default_weight_loader,
)
method.use_marlin = use_marlin
# capture the weight format used during loading (before postprocessing)
original_metadata = [
(name, param.shape, getattr(param, "weight_loader", default_weight_loader))
for name, param in layer.named_parameters()
]
# test loading
for name, shape, _ in original_metadata:
param = getattr(layer, name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, torch.zeros(shape)) # cannot use empty
method.process_weights_after_loading(layer)
# test that reloading works after loading
# assuming that no reshaping occurred
for name, shape, original_weight_loader in original_metadata:
param = getattr(layer, name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
assert weight_loader is original_weight_loader
weight_loader(param, torch.zeros(shape)) # cannot use empty
method.process_weights_after_loading(layer)
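
For reference, the reload flow this test exercises is roughly what a quantized RL rollout would do between training steps: push new tensors through the preserved weight_loader attributes and then re-run postprocessing. A minimal sketch, assuming a layer and quantization method set up as in the test above (the reload_fp8_layer name and new_state dict are hypothetical, not part of this PR):

def reload_fp8_layer(layer, method, new_state):
    # weight_loader attributes survive process_weights_after_loading because
    # the FP8 methods now swap tensors via replace_parameter()
    for name, param in layer.named_parameters():
        loader = getattr(param, "weight_loader", default_weight_loader)
        loader(param, new_state[name])
    # re-run postprocessing so the kernels see the repacked/requantized weights
    method.process_weights_after_loading(layer)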

View File

@@ -94,7 +94,7 @@ from vllm.model_executor.parameter import (
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
@@ -548,46 +548,50 @@ class Fp8LinearMethod(LinearMethodBase):
assert not self.act_q_static
size_k_first = False
weight, weight_scale = process_fp8_weight_block_strategy(
weight, weight_scale_inv = process_fp8_weight_block_strategy(
layer.weight, layer.weight_scale_inv
)
# Delete the weight_scale_inv parameter to avoid confusion
# with the weight_scale parameter
del layer.weight_scale_inv
# Update layer with new values
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
# If checkpoint not serialized fp8, quantize the weights.
elif not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
weight = qweight.t()
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# shards in a fused module
else:
weight = layer.weight
weight_scale = layer.weight_scale
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
weight = qweight.t()
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
if not self.use_marlin:
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
weight,
weight_scale,
layer.logical_widths,
getattr(layer, "input_scale", None),
)
if self.act_q_static:
assert input_scale is not None
input_scale = input_scale.max()
weight = weight.t()
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# shards in a fused module
else:
weight = layer.weight
weight_scale = layer.weight_scale
# Update layer with new values.
layer.weight = Parameter(weight.data, requires_grad=False)
layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
layer.input_scale = (
Parameter(input_scale, requires_grad=False)
if input_scale is not None
else None
)
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
if not self.use_marlin:
weight, weight_scale, input_scale = (
process_fp8_weight_tensor_strategy(
weight,
weight_scale,
layer.logical_widths,
getattr(layer, "input_scale", None),
)
)
if self.act_q_static:
assert input_scale is not None
input_scale = input_scale.max()
weight = weight.t()
# Update layer with new values.
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale", weight_scale.data)
if input_scale is not None:
replace_parameter(layer, "input_scale", input_scale)
else:
layer.input_scale = None
if self.use_marlin:
prepare_fp8_layer_for_marlin(
@@ -614,7 +618,7 @@ class Fp8LinearMethod(LinearMethodBase):
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
)
@@ -643,10 +647,15 @@ class Fp8LinearMethod(LinearMethodBase):
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
if self.use_marlin:
if self.block_quant:
weight_scale = layer.weight_scale_inv
else:
weight_scale = layer.weight_scale
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale=weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
@@ -660,7 +669,7 @@ class Fp8LinearMethod(LinearMethodBase):
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
)
@@ -937,22 +946,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight_scale_inv = layer.w2_weight_scale_inv
# torch.compile() cannot use Parameter subclasses.
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = Parameter(
w13_weight_scale_inv, requires_grad=False
)
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(
w2_weight_scale_inv, requires_grad=False
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
@@ -990,13 +995,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer.w13_weight_scale = torch.nn.Parameter(
replace_parameter(
layer,
"w13_weight_scale",
torch.ones(
layer.local_num_experts,
dtype=torch.float32,
device=w13_weight.device,
),
requires_grad=False,
)
for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
@@ -1005,16 +1011,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
@@ -1035,12 +1042,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)
replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
if current_platform.is_fp8_fnuz():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
@@ -1054,22 +1057,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False
)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
replace_parameter(layer, "w13_input_scale", w13_input_scale)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)
replace_parameter(layer, "w2_input_scale", w2_input_scale)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
@@ -1093,12 +1088,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight, layer.w2_weight
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)
replace_parameter(layer, "w13_weight_scale", max_w13_scales)
if self.flashinfer_moe_backend is not None:
# NOTE: weights have to be swapped since the activation is
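
For context on the per-tensor branch above: torch._scaled_mm takes a single weight scale, so a fused module that was loaded with one scale per logical shard (e.g. fused q/k/v or gate/up projections) has its shards dequantized and then requantized against the largest shard scale. A rough standalone illustration of that step, assuming the fused FP8 weight is stacked along dim 0 (illustrative only; the actual work is done by process_fp8_weight_tensor_strategy and requantize_with_max_scale):

import torch

def requantize_with_single_scale(weight_fp8, shard_scales, logical_widths):
    # use the largest per-shard scale so no shard saturates after requantization
    max_scale = shard_scales.max()
    finfo = torch.finfo(torch.float8_e4m3fn)
    start = 0
    for idx, width in enumerate(logical_widths):
        end = start + width
        # dequantize this shard with its own scale, requantize with the max
        shard = weight_fp8[start:end].to(torch.float32) * shard_scales[idx]
        requant = (shard / max_scale).clamp(finfo.min, finfo.max)
        weight_fp8[start:end] = requant.to(torch.float8_e4m3fn)
        start = end
    return weight_fp8, max_scale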

View File

@@ -45,6 +45,13 @@ class BaseKVCacheMethod(QuantizeMethodBase):
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# skip if there are no weights to process (for example, weight reloading)
if not hasattr(layer, "q_scale"):
assert not hasattr(layer, "k_scale")
assert not hasattr(layer, "v_scale")
assert not hasattr(layer, "prob_scale")
return
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to

View File

@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
@@ -1404,12 +1405,12 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
if should_use_deepgemm:
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=layer.weight.data,
ws=layer.weight_scale.data,
ws=layer.weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False)
replace_parameter(layer, "weight", dg_weight)
replace_parameter(layer, "weight_scale_inv", dg_weight_scale)
def expert_weight_is_col_major(x: torch.Tensor) -> bool:

View File

@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_quant_input,
should_use_atomic_add_reduce,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
@@ -130,7 +131,7 @@ def prepare_fp8_layer_for_marlin(
size_n=part_size_n,
num_bits=8,
)
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
replace_parameter(layer, "weight", marlin_qweight)
# WEIGHT SCALES
# Permute scales
@@ -138,7 +139,6 @@ def prepare_fp8_layer_for_marlin(
scales = layer.weight_scale.to(layer.orig_dtype)
elif "weight_scale_inv" in dir(layer):
scales = layer.weight_scale_inv.to(layer.orig_dtype)
del layer.weight_scale_inv
group_size = -1 if weight_block_size is None else weight_block_size[1]
@@ -177,12 +177,15 @@ def prepare_fp8_layer_for_marlin(
)
if input_dtype != torch.float8_e4m3fn:
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
if hasattr(layer, "weight_scale"):
replace_parameter(layer, "weight_scale", marlin_scales)
elif hasattr(layer, "weight_scale_inv"):
replace_parameter(layer, "weight_scale_inv", marlin_scales)
if hasattr(layer, "bias") and layer.bias is not None:
assert layer.bias.shape == (part_size_n,)
bias = marlin_permute_bias(layer.bias)
layer.bias = torch.nn.Parameter(bias, requires_grad=False)
replace_parameter(layer, "bias", bias)
def prepare_moe_fp8_layer_for_marlin(

View File

@@ -118,8 +118,11 @@ def requantize_with_max_scale(
# from disk in this case. Skip requantization in this case, since
# we are already quantized with the single scale.
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
#
# Extra note: upon weight reloading, weight_scale.ndim == 0
unfused_module_in_checkpoint = (
weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
weight_scale.ndim != 0
and weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
)
# If unfused checkpoint, need to requantize with the single scale.

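The new ndim check matters because the first pass of process_weights_after_loading collapses the per-shard scales into a single 0-dim scalar; when the weights are reloaded and the same code runs again, indexing that scalar with weight_scale[-1] would fail. A small standalone illustration (not part of the diff):

import torch

fp8_min = torch.finfo(torch.float8_e4m3fn).min

# first load: one scale per logical shard of a fused module (e.g. fused q/k/v)
per_shard = torch.tensor([0.01, 0.02, 0.03])
assert per_shard.ndim != 0 and per_shard[-1] > fp8_min  # unfused checkpoint

# after the first process_weights_after_loading: a single 0-dim scale
fused = per_shard.max()
assert fused.ndim == 0  # fused[-1] would raise, so the ndim check short-circuits
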
View File

@@ -50,6 +50,31 @@ def set_weight_attrs(
setattr(weight, key, value)
def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
"""
Replace a parameter of a layer while maintaining the ability to reload the weight.
Called within implementations of the `process_weights_after_loading` method.
This function should not be called on weights that are tied/shared.
Args:
layer: Layer containing parameter to replace
param_name: Name of parameter to replace
new_data: Data for the replacement parameter
"""
# should not be used on a tied/shared param
if isinstance(new_data, torch.nn.Parameter):
new_data = new_data.data
new_param = torch.nn.Parameter(new_data, requires_grad=False)
old_param: torch.nn.Parameter | None = getattr(layer, param_name, None)
if old_param is not None and hasattr(old_param, "weight_loader"):
weight_loader = old_param.weight_loader
set_weight_attrs(new_param, {"weight_loader": weight_loader})
setattr(layer, param_name, new_param)
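
A short usage sketch for the helper (the Linear layer here is just a stand-in): a plain Parameter reassignment drops the weight_loader attribute that reloading depends on, while replace_parameter carries it over to the new tensor:

import torch
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import replace_parameter, set_weight_attrs

layer = torch.nn.Linear(4, 4, bias=False)
set_weight_attrs(layer.weight, {"weight_loader": default_weight_loader})

# bare reassignment (the old pattern) loses the loader needed for reloading
layer.weight = torch.nn.Parameter(layer.weight.data.t(), requires_grad=False)
assert not hasattr(layer.weight, "weight_loader")

# replace_parameter copies the loader from the old parameter onto the new one
set_weight_attrs(layer.weight, {"weight_loader": default_weight_loader})
replace_parameter(layer, "weight", layer.weight.data.t())
assert layer.weight.weight_loader is default_weight_loader
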
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
parent_map = getattr(model, "packed_modules_mapping", None)
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}