online fp8 quant with streaming weight post-processing (#29196)

Signed-off-by: vasiliy <vasiliy@fb.com>
This commit is contained in:
Vasiliy Kuznetsov 2025-12-08 15:15:10 -05:00 committed by GitHub
parent d1b5e7afbf
commit 0d402d2600
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -465,6 +465,30 @@ class Fp8LinearMethod(LinearMethodBase):
output_size_per_partition, input_size_per_partition, weight_loader
)
else:
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# load the current weight chunk
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
# track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
layer._loaded_numel += loaded_weight.numel()
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)
# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer._already_called_process_weights_after_loading = True
return res
# For non-serialized checkpoints, use original dtype
weight = ModelWeightParameter(
data=torch.empty(
@ -474,7 +498,7 @@ class Fp8LinearMethod(LinearMethodBase):
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
weight_loader=patched_weight_loader,
)
layer.register_parameter("weight", weight)
@ -515,6 +539,9 @@ class Fp8LinearMethod(LinearMethodBase):
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
size_k_first = True
input_scale = None
# TODO(rob): refactor block quant into separate class.
@ -738,6 +765,41 @@ class Fp8MoEMethod(FusedMoEMethodBase):
f"weight quantization block_k = {block_k}."
)
# if we are doing online quantization, patch the weight
# loaded to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded
if not self.quant_config.is_checkpoint_fp8_serialized:
weight_loader = extra_weight_attrs["weight_loader"]
# create a new holder to prevent modifying behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs = extra_weight_attrs
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# load the current weight chunk
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
# add a counter to track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
layer._loaded_numel += loaded_weight.numel()
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)
# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer._already_called_process_weights_after_loading = True
return res
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
extra_weight_attrs = new_extra_weight_attrs
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
@ -839,6 +901,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.rocm_aiter_moe_enabled = False
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# Lazy import to avoid importing triton too early.
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()