online fp8 quant with streaming weight post-processing (#29196)

Signed-off-by: vasiliy <vasiliy@fb.com>

commit 0d402d2600 (parent d1b5e7afbf)
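
Before the diff, here is a minimal self-contained sketch of the streaming pattern this commit introduces: the checkpoint weight loader is wrapped so that it counts loaded elements and runs post-processing the moment the final chunk lands, instead of waiting for a separate pass after the whole model is loaded. Everything named here (FakeLayer, base_loader, finalize, make_streaming_loader) is illustrative, not a vLLM API.

import torch


class FakeLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(4, 4))


def finalize(layer: torch.nn.Module) -> None:
    # stand-in for process_weights_after_loading (e.g. online fp8 quantization)
    print("all chunks loaded; post-processing now")


def base_loader(param, loaded_weight, row):
    # stand-in for the original weight_loader: copy one row chunk into place
    param.data[row].copy_(loaded_weight)


def make_streaming_loader(layer, weight_loader):
    def patched_weight_loader(param, loaded_weight, *args, **kwargs):
        res = weight_loader(param, loaded_weight, *args, **kwargs)
        # count elements loaded so far, across all chunks
        layer._loaded_numel = getattr(layer, "_loaded_numel", 0) + loaded_weight.numel()
        if layer._loaded_numel == layer.weight.numel():
            finalize(layer)          # post-process as soon as loading completes
            del layer._loaded_numel  # drop the bookkeeping
            # make any later framework-level finalize call a no-op
            layer._already_called_process_weights_after_loading = True
        return res

    return patched_weight_loader


layer = FakeLayer()
loader = make_streaming_loader(layer, base_loader)
for row in range(4):
    loader(layer.weight, torch.randn(4), row)  # finalize fires on the last row

The same bookkeeping appears twice in the diff below: once in Fp8LinearMethod (target: layer.weight) and once in Fp8MoEMethod (target: w13_weight plus w2_weight).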
@@ -465,6 +465,30 @@ class Fp8LinearMethod(LinearMethodBase):
                 output_size_per_partition, input_size_per_partition, weight_loader
             )
         else:
+
+            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
+                # load the current weight chunk
+                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
+
+                # track how many elements we have updated
+                if not hasattr(layer, "_loaded_numel"):
+                    layer._loaded_numel = 0
+                layer._loaded_numel += loaded_weight.numel()
+
+                # if we have loaded all of the elements, call
+                # process_weights_after_loading
+                target_loaded_numel = layer.weight.numel()
+                if layer._loaded_numel == target_loaded_numel:
+                    self.process_weights_after_loading(layer)
+
+                    # Delete the bookkeeping
+                    del layer._loaded_numel
+                    # Prevent the usual `process_weights_after_loading` call from doing
+                    # anything
+                    layer._already_called_process_weights_after_loading = True
+
+                return res
+
             # For non-serialized checkpoints, use original dtype
             weight = ModelWeightParameter(
                 data=torch.empty(
@@ -474,7 +498,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 ),
                 input_dim=1,
                 output_dim=0,
-                weight_loader=weight_loader,
+                weight_loader=patched_weight_loader,
             )
             layer.register_parameter("weight", weight)
 
@@ -515,6 +539,9 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
         size_k_first = True
         input_scale = None
         # TODO(rob): refactor block quant into separate class.
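
The remaining hunks repeat the same pattern for the fused-MoE method. The one substantive difference is the completion target: a single patched loader fills both expert weight tensors, so post-processing must wait for their combined element count. A small hedged sketch (FakeMoE and the shapes are illustrative) of why the counter is robust to chunks for the two tensors arriving in any order:

import torch


class FakeMoE(torch.nn.Module):
    # stand-in for an MoE layer with two expert weight tensors
    def __init__(self) -> None:
        super().__init__()
        self.w13_weight = torch.nn.Parameter(torch.empty(2, 4))
        self.w2_weight = torch.nn.Parameter(torch.empty(2, 4))


layer = FakeMoE()
target = layer.w13_weight.numel() + layer.w2_weight.numel()  # 8 + 8 = 16

# chunks for the two parameters may interleave in any order; the element
# counter is order-agnostic and reaches the target exactly once, at the end
loaded = 0
for param, row in [(layer.w13_weight, 0), (layer.w2_weight, 0),
                   (layer.w13_weight, 1), (layer.w2_weight, 1)]:
    chunk = torch.randn(4)
    param.data[row].copy_(chunk)
    loaded += chunk.numel()
    if loaded == target:
        print("both MoE weights fully loaded; quantize now")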
@@ -738,6 +765,41 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 f"weight quantization block_k = {block_k}."
             )
 
+        # if we are doing online quantization, patch the weight
+        # loader to call `process_weights_after_loading` in a streaming fashion
+        # as soon as the last weight chunk is loaded
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            weight_loader = extra_weight_attrs["weight_loader"]
+            # create a new holder to prevent modifying behavior of any other
+            # objects which might depend on the old one
+            new_extra_weight_attrs = dict(extra_weight_attrs)
+
+            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
+                # load the current weight chunk
+                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
+
+                # add a counter to track how many elements we have updated
+                if not hasattr(layer, "_loaded_numel"):
+                    layer._loaded_numel = 0
+                layer._loaded_numel += loaded_weight.numel()
+
+                # if we have loaded all of the elements, call
+                # process_weights_after_loading
+                target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
+                if layer._loaded_numel == target_loaded_numel:
+                    self.process_weights_after_loading(layer)
+
+                    # Delete the bookkeeping
+                    del layer._loaded_numel
+                    # Prevent the usual `process_weights_after_loading` call
+                    # from doing anything
+                    layer._already_called_process_weights_after_loading = True
+
+                return res
+
+            new_extra_weight_attrs["weight_loader"] = patched_weight_loader
+            extra_weight_attrs = new_extra_weight_attrs
+
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
@@ -839,6 +901,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.rocm_aiter_moe_enabled = False
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
         # Lazy import to avoid importing triton too early.
 
         self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
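
Finally, the early-return guard added to both process_weights_after_loading overrides is what keeps post-processing from running twice: the patched loader sets a flag right after finalizing, and the model loader's usual end-of-load call then becomes a no-op. A hedged sketch of that control flow, with the caller reduced to a plain loop (the real call site in vLLM's model loader is not shown in this diff):

def process_weights_after_loading(layer) -> None:
    if getattr(layer, "_already_called_process_weights_after_loading", False):
        return  # already handled while the last weight chunk was loading
    ...  # one-time post-processing (e.g. fp8 quantization) would go here


# hypothetical end-of-load sweep: layers finalized during streaming are skipped
# for module in model.modules():
#     process_weights_after_loading(module)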