[Misc] Allow AutoWeightsLoader to skip loading weights with specific substr in name (#18358)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py 2025-05-20 11:20:12 +08:00 committed by GitHub
parent d565e0976f
commit f07a673eb2
18 changed files with 116 additions and 109 deletions
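
In short: AutoWeightsLoader gains a skip_substrs argument alongside the existing skip_prefixes, and the unused rotary-embedding buffers are now skipped by default, which lets the per-model skip lists below be deleted. A minimal usage sketch, assuming the import path the model files below use (the weight names are illustrative only):

    from vllm.model_executor.models.utils import AutoWeightsLoader  # assumed path

    def load_weights(self, weights):
        # Drop any checkpoint tensor whose name contains "lora"; the unused
        # rotary buffers (rotary_emb.inv_freq / cos_cached / sin_cached) are
        # now skipped by default, with no per-model skip list needed.
        loader = AutoWeightsLoader(self, skip_substrs=["lora"])
        return loader.load_weights(weights)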


@@ -77,3 +77,73 @@ def test_module_with_child_containing_batchnorm_can_autoload():
    assert torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1


def test_module_skip_prefix():
    """Ensure the auto weights loader can skip weights by prefix."""
    mod = ModuleWithNestedBatchNorm()
    # Run some data through the module with batchnorm
    mod(torch.Tensor([[1, 2], [3, 4]]))

    # Try to load the weights to a new instance
    def weight_generator():
        # Weights that need to be filtered out by the loader
        redundant_weights = {
            "prefix.bn.weight": torch.Tensor([1, 2]),
            "prefix.bn.bias": torch.Tensor([3, 4]),
        }
        yield from (mod.state_dict() | redundant_weights).items()

    new_mod = ModuleWithNestedBatchNorm()

    assert not torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
    assert not torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0

    loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."])
    loader.load_weights(weight_generator())

    # Ensure the stats are updated
    assert torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
    assert torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1


def test_module_skip_substr():
    """Ensure the auto weights loader can skip weights by substring match."""
    mod = ModuleWithNestedBatchNorm()
    # Run some data through the module with batchnorm
    mod(torch.Tensor([[1, 2], [3, 4]]))

    # Try to load the weights to a new instance
    def weight_generator():
        # Weights that need to be filtered out by the loader
        redundant_weights = {
            "nested_mod.0.substr.weight": torch.Tensor([1, 2]),
            "nested_mod.0.substr.bias": torch.Tensor([3, 4]),
            "nested_mod.substr.weight": torch.Tensor([1, 2]),
            "nested_mod.substr.bias": torch.Tensor([3, 4]),
        }
        yield from (mod.state_dict() | redundant_weights).items()

    new_mod = ModuleWithNestedBatchNorm()

    assert not torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
    assert not torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0

    loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
    loader.load_weights(weight_generator())

    # Ensure the stats are updated
    assert torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
    assert torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
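
The distinction the two new tests exercise: skip_prefixes matches only at the start of a checkpoint name, while skip_substrs matches anywhere in it. A standalone sketch of that logic, mirroring the _can_skip change further down (the helper name here is hypothetical):

    def can_skip(name: str, skip_prefixes: list[str], skip_substrs: list[str]) -> bool:
        # Prefixes anchor at the start of the name; substrings can hit anywhere.
        return (any(name.startswith(p) for p in skip_prefixes)
                or any(s in name for s in skip_substrs))

    assert can_skip("prefix.bn.weight", ["prefix."], [])
    assert not can_skip("nested_mod.prefix.bn.weight", ["prefix."], [])  # not at the start
    assert can_skip("nested_mod.0.substr.weight", [], ["substr."])       # anywhere in the name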


@@ -478,18 +478,14 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
skip_prefixes = [
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached",
]
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings:
skip_prefixes.append("lm_head.weight")
skip_prefixes = (["lm_head."]
if self.config.tie_word_embeddings else None)
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(
self,
skip_prefixes=skip_prefixes,
)
return loader.load_weights(weights)
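
This hunk (and the Grok1, OLMo, OLMo2, and Starcoder2 hunks below, which follow the same pattern) keeps only the tied-embedding skip and leaves the rotary buffers to the loader's new defaults. A condensed sketch of the net result, mirroring the new Granite code above (annotations trimmed for brevity):

    def load_weights(self, weights):
        # With tie_word_embeddings, lm_head duplicates the input embedding,
        # so its checkpoint entry can be skipped entirely.
        skip_prefixes = (["lm_head."]
                         if self.config.tie_word_embeddings else None)
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)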


@@ -550,10 +550,12 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
skip_prefixes = ["rotary_emb.inv_freq"]
# Skip lm_head when tie_word_embeddings is True
if self.config.tie_word_embeddings:
skip_prefixes.append("lm_head")
skip_prefixes = (["lm_head"]
if self.config.tie_word_embeddings else None)
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(
self,
skip_prefixes=skip_prefixes,
)
return loader.load_weights(weights)


@@ -482,5 +482,5 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


@@ -447,8 +447,5 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


@@ -502,14 +502,5 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached"
]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


@@ -382,19 +382,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached",
"lm_head.weight"
] if self.config.tie_word_embeddings else [
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached"
]),
skip_prefixes=(["lm_head.weight"]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)


@@ -403,19 +403,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=([
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached",
"lm_head.weight"
] if self.config.tie_word_embeddings else [
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached"
]),
skip_prefixes=(["lm_head.weight"]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)


@@ -442,8 +442,5 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=["rotary_emb.inv_freq"],
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


@@ -344,14 +344,5 @@ class OrionForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached"
]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


@@ -1228,9 +1228,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None:
weights = ((name, data) for name, data in weights
if "lora" not in name)
loader = AutoWeightsLoader(self)
loader = AutoWeightsLoader(self, skip_substrs=["lora"])
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
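
The Phi-4-MM change above replaces a hand-rolled generator filter with the loader-level skip; for names containing "lora" the two behave identically. A tiny sketch of that equivalence (the weight names are made up for illustration):

    weights = [
        ("model.layers.0.qkv_proj.lora_A.weight", ...),  # dropped either way
        ("model.layers.0.qkv_proj.weight", ...),         # kept either way
    ]
    manual = [(name, w) for name, w in weights if "lora" not in name]
    assert [name for name, _ in manual] == ["model.layers.0.qkv_proj.weight"]
    # AutoWeightsLoader(self, skip_substrs=["lora"]) applies the same test via
    # _can_skip() before the weights ever reach _load_module().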


@@ -660,8 +660,5 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


@@ -535,8 +535,5 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


@@ -530,8 +530,5 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


@@ -500,14 +500,5 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached"
]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


@@ -338,13 +338,5 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
skip_prefixes=[
"rotary_emb.inv_freq", "rotary_emb.cos_cached",
"rotary_emb.sin_cached"
],
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


@@ -349,8 +349,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
self,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
skip_prefixes=([
"rotary_emb.inv_freq", "lm_head.weight"
] if self.config.tie_word_embeddings else ["rotary_emb.inv_freq"]),
skip_prefixes=(["lm_head.weight"]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)


@@ -80,18 +80,30 @@ class AutoWeightsLoader:
    environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
    """

    # Models trained with early versions of ColossalAI may include these
    # tensors in the checkpoint. Skip them.
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: Optional[list[str]] = None,
        skip_substrs: Optional[list[str]] = None,
        ignore_unexpected_prefixes: Optional[list[str]] = None,
    ) -> None:
        super().__init__()

        self.module = module
        self.skip_prefixes = skip_prefixes or []
        self.skip_substrs = skip_substrs or []
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
        # Always skip these unused buffers, in addition to any caller-provided
        # substrings
        self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS

    def _groupby_prefix(
        self,
@@ -119,7 +131,8 @@ class AutoWeightsLoader:
        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        return any(qualname.startswith(p) for p in self.skip_prefixes)
        return (any(qualname.startswith(p) for p in self.skip_prefixes)
                or any(substr in qualname for substr in self.skip_substrs))

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        return any(
@@ -257,6 +270,9 @@ class AutoWeightsLoader:
    ) -> set[str]:
        if mapper is not None:
            weights = mapper.apply(weights)

        # Filter out weights whose names match a skip prefix or substring
        weights = ((name, weight) for name, weight in weights
                   if not self._can_skip(name))

        autoloaded_weights = set(self._load_module("", self.module, weights))
        return autoloaded_weights
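
Because ROTARY_EMBEDS_UNUSED_WEIGHTS is appended to skip_substrs unconditionally, a loader constructed with no skip arguments now drops those buffers on its own, which is what allows the per-model skip_prefixes lists above to be removed. A quick illustration against the private helper shown above (module and weight names are illustrative; model is any nn.Module):

    loader = AutoWeightsLoader(model)  # no skip_prefixes / skip_substrs passed
    assert loader._can_skip("model.layers.0.self_attn.rotary_emb.inv_freq")
    assert loader._can_skip("rotary_emb.cos_cached")
    assert not loader._can_skip("model.layers.0.self_attn.qkv_proj.weight")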