mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:05:02 +08:00
[Misc] Allow AutoWeightsLoader to skip loading weights with specific substr in name (#18358)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
d565e0976f
commit
f07a673eb2
@ -77,3 +77,73 @@ def test_module_with_child_containing_batchnorm_can_autoload():
|
||||
assert torch.all(
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
|
||||
|
||||
|
||||
def test_module_skip_prefix():
|
||||
"""Ensure the auto weight loader can skip prefix."""
|
||||
mod = ModuleWithNestedBatchNorm()
|
||||
# Run some data through the module with batchnorm
|
||||
mod(torch.Tensor([[1, 2], [3, 4]]))
|
||||
|
||||
# Try to load the weights to a new instance
|
||||
def weight_generator():
|
||||
# weights needed to be filtered out
|
||||
redundant_weights = {
|
||||
"prefix.bn.weight": torch.Tensor([1, 2]),
|
||||
"prefix.bn.bias": torch.Tensor([3, 4]),
|
||||
}
|
||||
yield from (mod.state_dict() | redundant_weights).items()
|
||||
|
||||
new_mod = ModuleWithNestedBatchNorm()
|
||||
|
||||
assert not torch.all(
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
|
||||
assert not torch.all(
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
|
||||
|
||||
loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."])
|
||||
loader.load_weights(weight_generator())
|
||||
|
||||
# Ensure the stats are updated
|
||||
assert torch.all(
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
|
||||
assert torch.all(
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
|
||||
|
||||
|
||||
def test_module_skip_substr():
|
||||
"""Ensure the auto weight loader can skip prefix."""
|
||||
mod = ModuleWithNestedBatchNorm()
|
||||
# Run some data through the module with batchnorm
|
||||
mod(torch.Tensor([[1, 2], [3, 4]]))
|
||||
|
||||
# Try to load the weights to a new instance
|
||||
def weight_generator():
|
||||
# weights needed to be filtered out
|
||||
redundant_weights = {
|
||||
"nested_mod.0.substr.weight": torch.Tensor([1, 2]),
|
||||
"nested_mod.0.substr.bias": torch.Tensor([3, 4]),
|
||||
"nested_mod.substr.weight": torch.Tensor([1, 2]),
|
||||
"nested_mod.substr.bias": torch.Tensor([3, 4]),
|
||||
}
|
||||
yield from (mod.state_dict() | redundant_weights).items()
|
||||
|
||||
new_mod = ModuleWithNestedBatchNorm()
|
||||
|
||||
assert not torch.all(
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
|
||||
assert not torch.all(
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
|
||||
|
||||
loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
|
||||
loader.load_weights(weight_generator())
|
||||
|
||||
# Ensure the stats are updated
|
||||
assert torch.all(
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
|
||||
assert torch.all(
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
|
||||
|
||||
@ -478,18 +478,14 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
skip_prefixes = [
|
||||
"rotary_emb.inv_freq",
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
"rotary_emb.cos_cached",
|
||||
"rotary_emb.sin_cached",
|
||||
]
|
||||
# With tie_word_embeddings, we can skip lm_head.weight
|
||||
# The weight might appear unnecessarily in the files if the model is
|
||||
# processed with quantization, LoRA, fine-tuning, etc.
|
||||
if self.config.tie_word_embeddings:
|
||||
skip_prefixes.append("lm_head.weight")
|
||||
skip_prefixes = (["lm_head."]
|
||||
if self.config.tie_word_embeddings else None)
|
||||
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=skip_prefixes,
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -550,10 +550,12 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
skip_prefixes = ["rotary_emb.inv_freq"]
|
||||
# Skip lm_head when tie_word_embeddings is True
|
||||
if self.config.tie_word_embeddings:
|
||||
skip_prefixes.append("lm_head")
|
||||
skip_prefixes = (["lm_head"]
|
||||
if self.config.tie_word_embeddings else None)
|
||||
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=skip_prefixes,
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -482,5 +482,5 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -447,8 +447,5 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["rotary_emb.inv_freq"]),
|
||||
)
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -502,14 +502,5 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=([
|
||||
"rotary_emb.inv_freq",
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
"rotary_emb.cos_cached",
|
||||
"rotary_emb.sin_cached"
|
||||
]),
|
||||
)
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -382,19 +382,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=([
|
||||
"rotary_emb.inv_freq",
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
"rotary_emb.cos_cached",
|
||||
"rotary_emb.sin_cached",
|
||||
"lm_head.weight"
|
||||
] if self.config.tie_word_embeddings else [
|
||||
"rotary_emb.inv_freq",
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
"rotary_emb.cos_cached",
|
||||
"rotary_emb.sin_cached"
|
||||
]),
|
||||
skip_prefixes=(["lm_head.weight"]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -403,19 +403,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=([
|
||||
"rotary_emb.inv_freq",
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
"rotary_emb.cos_cached",
|
||||
"rotary_emb.sin_cached",
|
||||
"lm_head.weight"
|
||||
] if self.config.tie_word_embeddings else [
|
||||
"rotary_emb.inv_freq",
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
"rotary_emb.cos_cached",
|
||||
"rotary_emb.sin_cached"
|
||||
]),
|
||||
skip_prefixes=(["lm_head.weight"]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -442,8 +442,5 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=["rotary_emb.inv_freq"],
|
||||
)
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -344,14 +344,5 @@ class OrionForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=([
|
||||
"rotary_emb.inv_freq",
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
"rotary_emb.cos_cached",
|
||||
"rotary_emb.sin_cached"
|
||||
]),
|
||||
)
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -1228,9 +1228,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> None:
|
||||
weights = ((name, data) for name, data in weights
|
||||
if "lora" not in name)
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader = AutoWeightsLoader(self, skip_substrs=["lora"])
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
|
||||
@ -660,8 +660,5 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["rotary_emb.inv_freq"]),
|
||||
)
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -535,8 +535,5 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["rotary_emb.inv_freq"]),
|
||||
)
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -530,8 +530,5 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["rotary_emb.inv_freq"]),
|
||||
)
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -500,14 +500,5 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=([
|
||||
"rotary_emb.inv_freq",
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
"rotary_emb.cos_cached",
|
||||
"rotary_emb.sin_cached"
|
||||
]),
|
||||
)
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -338,13 +338,5 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
skip_prefixes=[
|
||||
"rotary_emb.inv_freq", "rotary_emb.cos_cached",
|
||||
"rotary_emb.sin_cached"
|
||||
],
|
||||
)
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -349,8 +349,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
skip_prefixes=([
|
||||
"rotary_emb.inv_freq", "lm_head.weight"
|
||||
] if self.config.tie_word_embeddings else ["rotary_emb.inv_freq"]),
|
||||
skip_prefixes=(["lm_head.weight"]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -80,18 +80,30 @@ class AutoWeightsLoader:
|
||||
environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
|
||||
"""
|
||||
|
||||
# Models trained using early version ColossalAI
|
||||
# may include these tensors in checkpoint. Skip them.
|
||||
ROTARY_EMBEDS_UNUSED_WEIGHTS = [
|
||||
"rotary_emb.inv_freq",
|
||||
"rotary_emb.cos_cached",
|
||||
"rotary_emb.sin_cached",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: nn.Module,
|
||||
*,
|
||||
skip_prefixes: Optional[list[str]] = None,
|
||||
skip_substrs: Optional[list[str]] = None,
|
||||
ignore_unexpected_prefixes: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.module = module
|
||||
self.skip_prefixes = skip_prefixes or []
|
||||
self.skip_substrs = skip_substrs or []
|
||||
self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
|
||||
# update default skip_substrs
|
||||
self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
|
||||
|
||||
def _groupby_prefix(
|
||||
self,
|
||||
@ -119,7 +131,8 @@ class AutoWeightsLoader:
|
||||
return ".".join((prefix, rest))
|
||||
|
||||
def _can_skip(self, qualname: str) -> bool:
|
||||
return any(qualname.startswith(p) for p in self.skip_prefixes)
|
||||
return (any(qualname.startswith(p) for p in self.skip_prefixes)
|
||||
or any(substr in qualname for substr in self.skip_substrs))
|
||||
|
||||
def _can_ignore_unexpected(self, qualname: str) -> bool:
|
||||
return any(
|
||||
@ -257,6 +270,9 @@ class AutoWeightsLoader:
|
||||
) -> set[str]:
|
||||
if mapper is not None:
|
||||
weights = mapper.apply(weights)
|
||||
# filter out weights with first-prefix/substr to skip in name
|
||||
weights = ((name, weight) for name, weight in weights
|
||||
if not self._can_skip(name))
|
||||
|
||||
autoloaded_weights = set(self._load_module("", self.module, weights))
|
||||
return autoloaded_weights
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user