From 0d402d2600490bac17bc5d079e89b1136fe37eda Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 8 Dec 2025 15:15:10 -0500 Subject: [PATCH] online fp8 quant with streaming weight post-processing (#29196) Signed-off-by: vasiliy --- .../model_executor/layers/quantization/fp8.py | 67 ++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0e3e13f5945ea..419ddd91b64e0 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -465,6 +465,30 @@ class Fp8LinearMethod(LinearMethodBase): output_size_per_partition, input_size_per_partition, weight_loader ) else: + + def patched_weight_loader(param, loaded_weight, *args, **kwargs): + # load the current weight chunk + res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] + + # track how many elements we have updated + if not hasattr(layer, "_loaded_numel"): + layer._loaded_numel = 0 + layer._loaded_numel += loaded_weight.numel() + + # if we have loaded all of the elements, call + # process_weights_after_loading + target_loaded_numel = layer.weight.numel() + if layer._loaded_numel == target_loaded_numel: + self.process_weights_after_loading(layer) + + # Delete the bookkeeping + del layer._loaded_numel + # Prevent the usual `process_weights_after_loading` call from doing + # anything + layer._already_called_process_weights_after_loading = True + + return res + # For non-serialized checkpoints, use original dtype weight = ModelWeightParameter( data=torch.empty( @@ -474,7 +498,7 @@ class Fp8LinearMethod(LinearMethodBase): ), input_dim=1, output_dim=0, - weight_loader=weight_loader, + weight_loader=patched_weight_loader, ) layer.register_parameter("weight", weight) @@ -515,6 +539,9 @@ class Fp8LinearMethod(LinearMethodBase): layer.register_parameter("input_scale", None) def process_weights_after_loading(self, layer: 
Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + size_k_first = True input_scale = None # TODO(rob): refactor block quant into separate class. @@ -738,6 +765,41 @@ class Fp8MoEMethod(FusedMoEMethodBase): f"weight quantization block_k = {block_k}." ) + # if we are doing online quantization, patch the weight + # loader to call `process_weights_after_loading` in a streaming fashion + # as soon as the last weight chunk is loaded + if not self.quant_config.is_checkpoint_fp8_serialized: + weight_loader = extra_weight_attrs["weight_loader"] + # create a new holder to prevent modifying behavior of any other + # objects which might depend on the old one + new_extra_weight_attrs = extra_weight_attrs + + def patched_weight_loader(param, loaded_weight, *args, **kwargs): + # load the current weight chunk + res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] + + # add a counter to track how many elements we have updated + if not hasattr(layer, "_loaded_numel"): + layer._loaded_numel = 0 + layer._loaded_numel += loaded_weight.numel() + + # if we have loaded all of the elements, call + # process_weights_after_loading + target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() + if layer._loaded_numel == target_loaded_numel: + self.process_weights_after_loading(layer) + + # Delete the bookkeeping + del layer._loaded_numel + # Prevent the usual `process_weights_after_loading` call + # from doing anything + layer._already_called_process_weights_after_loading = True + + return res + + new_extra_weight_attrs["weight_loader"] = patched_weight_loader + extra_weight_attrs = new_extra_weight_attrs + # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( @@ -839,6 +901,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.rocm_aiter_moe_enabled = False def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + 
return + # Lazy import to avoid importing triton too early. self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()