vllm/tests/models/test_utils.py
Harry Mellor d6953beb91
Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-10-05 07:06:22 -07:00

158 lines
5.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.models.utils import AutoWeightsLoader
pytestmark = pytest.mark.cpu_test
class ModuleWithBatchNorm(torch.nn.Module):
def __init__(self):
super().__init__()
self.bn = torch.nn.BatchNorm1d(2)
def forward(self, x):
return self.bn(x)
class ModuleWithNestedBatchNorm(torch.nn.Module):
def __init__(self):
super().__init__()
self.nested_mod = ModuleWithBatchNorm()
def forward(self, x):
return self.nested_mod(x)
def test_module_with_batchnorm_can_load():
"""Ensure the auto weight loader can load batchnorm stats."""
mod = ModuleWithBatchNorm()
# Run some data through the module with batchnorm
mod(torch.Tensor([[1, 2], [3, 4]]))
# Try to load the weights to a new instance
def weight_generator():
yield from mod.state_dict().items()
new_mod = ModuleWithBatchNorm()
assert not torch.all(new_mod.bn.running_mean == mod.bn.running_mean)
assert not torch.all(new_mod.bn.running_var == mod.bn.running_var)
assert new_mod.bn.num_batches_tracked.item() == 0
loader = AutoWeightsLoader(new_mod)
loader.load_weights(weight_generator())
# Ensure the stats are updated
assert torch.all(new_mod.bn.running_mean == mod.bn.running_mean)
assert torch.all(new_mod.bn.running_var == mod.bn.running_var)
assert new_mod.bn.num_batches_tracked.item() == 1
def test_module_with_child_containing_batchnorm_can_autoload():
"""Ensure the auto weight loader can load nested modules batchnorm stats."""
mod = ModuleWithNestedBatchNorm()
# Run some data through the module with batchnorm
mod(torch.Tensor([[1, 2], [3, 4]]))
# Try to load the weights to a new instance
def weight_generator():
yield from mod.state_dict().items()
new_mod = ModuleWithNestedBatchNorm()
assert not torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert not torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var
)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
loader = AutoWeightsLoader(new_mod)
loader.load_weights(weight_generator())
# Ensure the stats are updated
assert torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
def test_module_skip_prefix():
"""Ensure the auto weight loader can skip prefix."""
mod = ModuleWithNestedBatchNorm()
# Run some data through the module with batchnorm
mod(torch.Tensor([[1, 2], [3, 4]]))
# Try to load the weights to a new instance
def weight_generator():
# weights needed to be filtered out
redundant_weights = {
"prefix.bn.weight": torch.Tensor([1, 2]),
"prefix.bn.bias": torch.Tensor([3, 4]),
}
yield from (mod.state_dict() | redundant_weights).items()
new_mod = ModuleWithNestedBatchNorm()
assert not torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert not torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var
)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."])
loader.load_weights(weight_generator())
# Ensure the stats are updated
assert torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
def test_module_skip_substr():
"""Ensure the auto weight loader can skip prefix."""
mod = ModuleWithNestedBatchNorm()
# Run some data through the module with batchnorm
mod(torch.Tensor([[1, 2], [3, 4]]))
# Try to load the weights to a new instance
def weight_generator():
# weights needed to be filtered out
redundant_weights = {
"nested_mod.0.substr.weight": torch.Tensor([1, 2]),
"nested_mod.0.substr.bias": torch.Tensor([3, 4]),
"nested_mod.substr.weight": torch.Tensor([1, 2]),
"nested_mod.substr.bias": torch.Tensor([3, 4]),
}
yield from (mod.state_dict() | redundant_weights).items()
new_mod = ModuleWithNestedBatchNorm()
assert not torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert not torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var
)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
loader.load_weights(weight_generator())
# Ensure the stats are updated
assert torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1