diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 800340ed6043c..ec3fc5ace17d8 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional import torch from torch.nn import Module from torch.nn.parameter import Parameter +from torch.utils._python_dispatch import TorchDispatchMode import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -363,6 +364,26 @@ class Fp8Config(QuantizationConfig): return None +class CopyNumelCounter(TorchDispatchMode): + """ + Tracks total number of elements modified with `copy_`. Useful for keeping + track of weight loading where underlying weights can be arbitrarily + transformed (such as with `narrow`) before calling copy. + """ + + def __init__(self): + super().__init__() + self.copied_numel = 0 + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + out = func(*args, **kwargs) + if func == torch.ops.aten.copy_.default: + self.copied_numel += args[0].numel() + return out + + class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. Supports loading FP8 checkpoints with static weight scale and @@ -469,13 +490,15 @@ class Fp8LinearMethod(LinearMethodBase): else: def patched_weight_loader(param, loaded_weight, *args, **kwargs): - # load the current weight chunk - res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] - # track how many elements we have updated if not hasattr(layer, "_loaded_numel"): layer._loaded_numel = 0 - layer._loaded_numel += loaded_weight.numel() + + # load the current weight chunk + copy_numel_counter = CopyNumelCounter() + with copy_numel_counter: + res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] + layer._loaded_numel += copy_numel_counter.copied_numel # if we have loaded all of the elements, call # process_weights_after_loading @@ -1348,13 +1371,15 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): new_extra_weight_attrs = extra_weight_attrs def patched_weight_loader(param, loaded_weight, *args, **kwargs): - # load the current weight chunk - res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] - # add a counter to track how many elements we have updated if not hasattr(layer, "_loaded_numel"): layer._loaded_numel = 0 - layer._loaded_numel += loaded_weight.numel() + + # load the current weight chunk + copy_numel_counter = CopyNumelCounter() + with copy_numel_counter: + res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] + layer._loaded_numel += copy_numel_counter.copied_numel # if we have loaded all of the elements, call # process_weights_after_loading