mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 02:07:04 +08:00
online fp8 quant with streaming weight post-processing (#29196)
Signed-off-by: vasiliy <vasiliy@fb.com>
This commit is contained in:
parent
d1b5e7afbf
commit
0d402d2600
@ -465,6 +465,30 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
output_size_per_partition, input_size_per_partition, weight_loader
|
output_size_per_partition, input_size_per_partition, weight_loader
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
    """Load one weight chunk and post-process the layer once it is complete.

    Wraps the ``weight_loader`` taken from the enclosing scope. Every call
    forwards its arguments to that loader unchanged, then adds the number of
    elements in ``loaded_weight`` to a running total stored on ``layer``.
    When the total equals ``layer.weight.numel()``,
    ``self.process_weights_after_loading(layer)`` is invoked right away.

    Returns:
        Whatever the wrapped ``weight_loader`` returned for this chunk.
    """
    # Delegate the actual copy of this chunk to the wrapped loader.
    result = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

    # Running count of loaded elements; it lives on the layer so that it
    # persists across successive calls for different chunks.
    seen_numel = getattr(layer, "_loaded_numel", 0) + loaded_weight.numel()
    layer._loaded_numel = seen_numel

    # Still waiting for more chunks: nothing else to do for now.
    if seen_numel != layer.weight.numel():
        return result

    # Every element has arrived, so post-process immediately.
    self.process_weights_after_loading(layer)

    # The counter is no longer needed once post-processing has run.
    del layer._loaded_numel
    # Flag the layer so that a later `process_weights_after_loading` call
    # can detect that the work was already done and skip it.
    layer._already_called_process_weights_after_loading = True

    return result
|
||||||
|
|
||||||
# For non-serialized checkpoints, use original dtype
|
# For non-serialized checkpoints, use original dtype
|
||||||
weight = ModelWeightParameter(
|
weight = ModelWeightParameter(
|
||||||
data=torch.empty(
|
data=torch.empty(
|
||||||
@ -474,7 +498,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
),
|
),
|
||||||
input_dim=1,
|
input_dim=1,
|
||||||
output_dim=0,
|
output_dim=0,
|
||||||
weight_loader=weight_loader,
|
weight_loader=patched_weight_loader,
|
||||||
)
|
)
|
||||||
layer.register_parameter("weight", weight)
|
layer.register_parameter("weight", weight)
|
||||||
|
|
||||||
@ -515,6 +539,9 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
layer.register_parameter("input_scale", None)
|
layer.register_parameter("input_scale", None)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
|
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||||
|
return
|
||||||
|
|
||||||
size_k_first = True
|
size_k_first = True
|
||||||
input_scale = None
|
input_scale = None
|
||||||
# TODO(rob): refactor block quant into separate class.
|
# TODO(rob): refactor block quant into separate class.
|
||||||
@ -738,6 +765,41 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
f"weight quantization block_k = {block_k}."
|
f"weight quantization block_k = {block_k}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# if we are doing online quantization, patch the weight
|
||||||
|
# loader to call `process_weights_after_loading` in a streaming fashion
|
||||||
|
# as soon as the last weight chunk is loaded
|
||||||
|
if not self.quant_config.is_checkpoint_fp8_serialized:
    weight_loader = extra_weight_attrs["weight_loader"]
    # Create a new holder to prevent modifying behavior of any other
    # objects which might depend on the old one. A shallow copy is
    # required here: plain assignment would only alias the original
    # mapping, so installing the patched loader below would also change
    # it for every other holder of `extra_weight_attrs`.
    new_extra_weight_attrs = dict(extra_weight_attrs)

    def patched_weight_loader(param, loaded_weight, *args, **kwargs):
        """Load one weight chunk and post-process once all have arrived.

        Forwards all arguments to the original ``weight_loader`` and
        accumulates ``loaded_weight.numel()`` on ``layer``. Once the
        total equals the combined element count of ``layer.w13_weight``
        and ``layer.w2_weight``, ``process_weights_after_loading`` is
        called immediately.

        Returns:
            The value returned by the original ``weight_loader``.
        """
        # load the current weight chunk
        res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

        # add a counter to track how many elements we have updated
        if not hasattr(layer, "_loaded_numel"):
            layer._loaded_numel = 0
        layer._loaded_numel += loaded_weight.numel()

        # if we have loaded all of the elements, call
        # process_weights_after_loading
        target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
        if layer._loaded_numel == target_loaded_numel:
            self.process_weights_after_loading(layer)

            # Delete the bookkeeping
            del layer._loaded_numel
            # Prevent the usual `process_weights_after_loading` call
            # from doing anything
            layer._already_called_process_weights_after_loading = True

        return res

    new_extra_weight_attrs["weight_loader"] = patched_weight_loader
    extra_weight_attrs = new_extra_weight_attrs
|
||||||
|
|
||||||
# WEIGHTS
|
# WEIGHTS
|
||||||
w13_weight = torch.nn.Parameter(
|
w13_weight = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
@ -839,6 +901,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self.rocm_aiter_moe_enabled = False
|
self.rocm_aiter_moe_enabled = False
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
|
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||||
|
return
|
||||||
|
|
||||||
# Lazy import to avoid importing triton too early.
|
# Lazy import to avoid importing triton too early.
|
||||||
|
|
||||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user