mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 19:05:41 +08:00
fix fp8 online quantization streaming with tp > 1 (#30900)
Signed-off-by: vasiliy <vasiliy@fb.com>
This commit is contained in:
parent
9a5e96523b
commit
f4ee2c3d90
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
@ -363,6 +364,26 @@ class Fp8Config(QuantizationConfig):
|
||||
return None
|
||||
|
||||
|
||||
class CopyNumelCounter(TorchDispatchMode):
|
||||
"""
|
||||
Tracks total number of elements modified with `copy_`. Useful for keeping
|
||||
track of weight loading where underlying weights can be arbitrarily
|
||||
transformed (such as with `narrow`) before calling copy.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.copied_numel = 0
|
||||
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
out = func(*args, **kwargs)
|
||||
if func == torch.ops.aten.copy_.default:
|
||||
self.copied_numel += args[0].numel()
|
||||
return out
|
||||
|
||||
|
||||
class Fp8LinearMethod(LinearMethodBase):
|
||||
"""Linear method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
@ -469,13 +490,15 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
else:
|
||||
|
||||
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
|
||||
# load the current weight chunk
|
||||
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
|
||||
|
||||
# track how many elements we have updated
|
||||
if not hasattr(layer, "_loaded_numel"):
|
||||
layer._loaded_numel = 0
|
||||
layer._loaded_numel += loaded_weight.numel()
|
||||
|
||||
# load the current weight chunk
|
||||
copy_numel_counter = CopyNumelCounter()
|
||||
with copy_numel_counter:
|
||||
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
|
||||
layer._loaded_numel += copy_numel_counter.copied_numel
|
||||
|
||||
# if we have loaded all of the elements, call
|
||||
# process_weights_after_loading
|
||||
@ -1348,13 +1371,15 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
new_extra_weight_attrs = extra_weight_attrs
|
||||
|
||||
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
|
||||
# load the current weight chunk
|
||||
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
|
||||
|
||||
# add a counter to track how many elements we have updated
|
||||
if not hasattr(layer, "_loaded_numel"):
|
||||
layer._loaded_numel = 0
|
||||
layer._loaded_numel += loaded_weight.numel()
|
||||
|
||||
# load the current weight chunk
|
||||
copy_numel_counter = CopyNumelCounter()
|
||||
with copy_numel_counter:
|
||||
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
|
||||
layer._loaded_numel += copy_numel_counter.copied_numel
|
||||
|
||||
# if we have loaded all of the elements, call
|
||||
# process_weights_after_loading
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user