[Bugfix] Fix Crashing When Loading Modules With Batchnorm Stats (#15813)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
parent 3aa2b6a637
commit c2e7507ad4
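Context for the fix, as a minimal sketch in plain PyTorch (nothing vLLM-specific is assumed here): batchnorm layers keep their running statistics as registered buffers rather than parameters, so a loader that only walks named_parameters() never sees the keys that state_dict()-based checkpoints contain.

# Minimal sketch: why checkpoints with batchnorm stats crashed a
# parameter-only loader. Standard PyTorch APIs only.
import torch.nn as nn

bn = nn.BatchNorm1d(2)
params = dict(bn.named_parameters(recurse=False))
print(sorted(params))            # ['bias', 'weight'] -- no running stats
print(sorted(bn.state_dict()))   # also has 'num_batches_tracked',
                                 # 'running_mean', 'running_var'

# A checkpoint built from state_dict() therefore carries keys that a
# parameter-only lookup cannot resolve:
for name in bn.state_dict():
    if name not in params:
        print(f"{name!r} is a buffer, not a parameter")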
79 tests/models/test_utils.py Normal file
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0

import torch

from vllm.model_executor.models.utils import AutoWeightsLoader


class ModuleWithBatchNorm(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(2)

    def forward(self, x):
        return self.bn(x)


class ModuleWithNestedBatchNorm(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.nested_mod = ModuleWithBatchNorm()

    def forward(self, x):
        return self.nested_mod(x)


def test_module_with_batchnorm_can_load():
    """Ensure the auto weight loader can load batchnorm stats."""
    mod = ModuleWithBatchNorm()
    # Run some data through the module with batchnorm
    mod(torch.Tensor([[1, 2], [3, 4]]))

    # Try to load the weights into a new instance
    def weight_generator():
        yield from mod.state_dict().items()

    new_mod = ModuleWithBatchNorm()

    assert not torch.all(new_mod.bn.running_mean == mod.bn.running_mean)
    assert not torch.all(new_mod.bn.running_var == mod.bn.running_var)
    assert new_mod.bn.num_batches_tracked.item() == 0

    loader = AutoWeightsLoader(new_mod)
    loader.load_weights(weight_generator())

    # Ensure the stats are updated
    assert torch.all(new_mod.bn.running_mean == mod.bn.running_mean)
    assert torch.all(new_mod.bn.running_var == mod.bn.running_var)
    assert new_mod.bn.num_batches_tracked.item() == 1


def test_module_with_child_containing_batchnorm_can_autoload():
    """Ensure the auto weight loader can load nested modules' batchnorm stats."""
    mod = ModuleWithNestedBatchNorm()
    # Run some data through the module with batchnorm
    mod(torch.Tensor([[1, 2], [3, 4]]))

    # Try to load the weights into a new instance
    def weight_generator():
        yield from mod.state_dict().items()

    new_mod = ModuleWithNestedBatchNorm()

    assert not torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
    assert not torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0

    loader = AutoWeightsLoader(new_mod)
    loader.load_weights(weight_generator())

    # Ensure the stats are updated
    assert torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
    assert torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
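Aside on the final assertion in each test, as a plain-PyTorch sketch (not part of the commit): a freshly built BatchNorm1d is in training mode, so the single forward pass updates the running stats exactly once, which is why num_batches_tracked is expected to equal 1 after loading.

# Sketch: one forward pass in training mode bumps num_batches_tracked to 1
# and moves the running stats off their defaults (zeros / ones).
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(2)                  # modules are in training mode by default
assert bn.num_batches_tracked.item() == 0
bn(torch.Tensor([[1, 2], [3, 4]]))
assert bn.num_batches_tracked.item() == 1
assert not torch.all(bn.running_mean == torch.zeros(2))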
vllm/model_executor/models/utils.py
@@ -158,6 +158,26 @@ class AutoWeightsLoader:
            yield weight_qualname

    def _add_loadable_non_param_tensors(self, module: nn.Module,
                                        child_params: Dict[str, torch.Tensor]):
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.
        """
        if isinstance(module, (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
        )):
            module_state_dict = module.state_dict()
            for stat_name in ("running_mean", "running_var",
                              "num_batches_tracked"):
                child_params[stat_name] = module_state_dict[stat_name]

    def _load_module(
        self,
        base_prefix: str,
@@ -186,6 +206,10 @@ class AutoWeightsLoader:
        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)
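How the two hunks fit together, as a hedged standalone sketch (the pattern mirrors the diff, but the helper name and everything else below are illustrative plain PyTorch, not vLLM's actual loader): merging the batchnorm buffers into the same per-module tensor dict as the real parameters, before weights are grouped by prefix, lets the existing name-matching path load them with no special cases downstream.

# Standalone sketch of the pattern: fold non-param buffers into the same
# dict as the parameters so one name-keyed loop can load everything.
# 'loadable_tensors' is an illustrative name, not vLLM's API.
from typing import Dict

import torch
import torch.nn as nn


def loadable_tensors(module: nn.Module) -> Dict[str, torch.Tensor]:
    tensors = dict(module.named_parameters(recurse=False))
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        sd = module.state_dict()
        for name in ("running_mean", "running_var", "num_batches_tracked"):
            tensors[name] = sd[name]  # buffers, absent from named_parameters()
    return tensors


bn = nn.BatchNorm1d(2)
checkpoint = nn.BatchNorm1d(2).state_dict()   # stand-in for real weights
tensors = loadable_tensors(bn)
for name, value in checkpoint.items():
    tensors[name].data.copy_(value)           # no KeyError on running stats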