fix fp8 online quantization streaming with tp > 1 (#30900)

Signed-off-by: vasiliy <vasiliy@fb.com>
This commit is contained in:
Vasiliy Kuznetsov 2025-12-18 11:45:15 -05:00 committed by GitHub
parent 9a5e96523b
commit f4ee2c3d90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils._python_dispatch import TorchDispatchMode
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@ -363,6 +364,26 @@ class Fp8Config(QuantizationConfig):
return None
class CopyNumelCounter(TorchDispatchMode):
    """
    Tracks total number of elements modified with `copy_`. Useful for keeping
    track of weight loading where underlying weights can be arbitrarily
    transformed (such as with `narrow`) before calling copy.

    Use as a context manager: every op dispatched while the mode is active
    is executed unchanged, and each `aten.copy_` call adds the element count
    of its first positional argument to `copied_numel`.
    """

    def __init__(self) -> None:
        super().__init__()
        # Running total of elements written through `copy_` so far; it is
        # never reset, so reusing one instance keeps accumulating.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Run the intercepted op first so the count only grows for copies
        # that actually completed without raising.
        out = func(*args, **kwargs)
        if func == torch.ops.aten.copy_.default:
            # args[0] is the tensor being written into; counting its numel
            # (rather than the source's) measures what was actually updated,
            # even when the destination is a narrowed view.
            self.copied_numel += args[0].numel()
        return out
class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
@ -469,13 +490,15 @@ class Fp8LinearMethod(LinearMethodBase):
else:
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# load the current weight chunk
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
# track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
layer._loaded_numel += loaded_weight.numel()
# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel
# if we have loaded all of the elements, call
# process_weights_after_loading
@ -1348,13 +1371,15 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
new_extra_weight_attrs = extra_weight_attrs
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# load the current weight chunk
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
# add a counter to track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
layer._loaded_numel += loaded_weight.numel()
# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel
# if we have loaded all of the elements, call
# process_weights_after_loading