[Bugfix] Fix Crashing When Loading Modules With Batchnorm Stats (#15813)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Authored by Alex Brooks on 2025-03-31 07:23:53 -06:00, committed by GitHub
parent 3aa2b6a637
commit c2e7507ad4
2 changed files with 103 additions and 0 deletions
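
Why this crashed: BatchNorm layers keep their running statistics (running_mean, running_var, num_batches_tracked) as buffers rather than parameters, so a checkpoint contains keys that the loader's parameter-only view of a module does not know about. A minimal sketch of the mismatch, using only plain PyTorch (no vLLM specifics; the snippet is illustrative, not from the repo):

import torch.nn as nn

bn = nn.BatchNorm1d(2)

# Learnable parameters only: ['bias', 'weight']
print(sorted(name for name, _ in bn.named_parameters()))

# The state dict (what a checkpoint stores) also carries the buffers:
# ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']
print(sorted(bn.state_dict().keys()))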

@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
import torch

from vllm.model_executor.models.utils import AutoWeightsLoader


class ModuleWithBatchNorm(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(2)

    def forward(self, x):
        return self.bn(x)


class ModuleWithNestedBatchNorm(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.nested_mod = ModuleWithBatchNorm()

    def forward(self, x):
        return self.nested_mod(x)


def test_module_with_batchnorm_can_load():
    """Ensure the auto weight loader can load batchnorm stats."""
    mod = ModuleWithBatchNorm()

    # Run some data through the module with batchnorm
    mod(torch.Tensor([[1, 2], [3, 4]]))

    # Try to load the weights to a new instance
    def weight_generator():
        yield from mod.state_dict().items()

    new_mod = ModuleWithBatchNorm()

    assert not torch.all(new_mod.bn.running_mean == mod.bn.running_mean)
    assert not torch.all(new_mod.bn.running_var == mod.bn.running_var)
    assert new_mod.bn.num_batches_tracked.item() == 0

    loader = AutoWeightsLoader(new_mod)
    loader.load_weights(weight_generator())

    # Ensure the stats are updated
    assert torch.all(new_mod.bn.running_mean == mod.bn.running_mean)
    assert torch.all(new_mod.bn.running_var == mod.bn.running_var)
    assert new_mod.bn.num_batches_tracked.item() == 1


def test_module_with_child_containing_batchnorm_can_autoload():
    """Ensure the auto weight loader can load nested modules' batchnorm stats."""
    mod = ModuleWithNestedBatchNorm()

    # Run some data through the module with batchnorm
    mod(torch.Tensor([[1, 2], [3, 4]]))

    # Try to load the weights to a new instance
    def weight_generator():
        yield from mod.state_dict().items()

    new_mod = ModuleWithNestedBatchNorm()

    assert not torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
    assert not torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0

    loader = AutoWeightsLoader(new_mod)
    loader.load_weights(weight_generator())

    # Ensure the stats are updated
    assert torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
    assert torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1

@@ -158,6 +158,26 @@ class AutoWeightsLoader:
            yield weight_qualname

    def _add_loadable_non_param_tensors(self, module: nn.Module,
                                        child_params: Dict[str, torch.Tensor]):
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.
        """
        if isinstance(module, (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
        )):
            module_state_dict = module.state_dict()
            for stat_name in ("running_mean", "running_var",
                              "num_batches_tracked"):
                child_params[stat_name] = module_state_dict[stat_name]

    def _load_module(
        self,
        base_prefix: str,
@@ -186,6 +206,10 @@ class AutoWeightsLoader:
        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)
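
For reference, a hedged usage sketch of the fixed path, mirroring the new test (ModuleWithNestedBatchNorm is the helper class defined in the test file above; this is illustrative, not a verbatim snippet from the repo):

import torch

from vllm.model_executor.models.utils import AutoWeightsLoader

src = ModuleWithNestedBatchNorm()
src(torch.randn(4, 2))  # populate running_mean / running_var
dst = ModuleWithNestedBatchNorm()

# Previously this crashed because running_mean/var and num_batches_tracked
# are buffers, not parameters; with this commit they load cleanly.
AutoWeightsLoader(dst).load_weights(src.state_dict().items())

assert torch.equal(dst.nested_mod.bn.running_mean,
                   src.nested_mod.bn.running_mean)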