[Misc] Allow AutoWeightsLoader to skip loading weights with specific substr in name (#18358)
Signed-off-by: Isotr0py <2037008807@qq.com>
commit f07a673eb2
parent d565e0976f
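In short, this change adds a skip_substrs option to AutoWeightsLoader (alongside the existing skip_prefixes), filters out matching weight names before loading, and skips the unused rotary_emb.{inv_freq,cos_cached,sin_cached} buffers by default, which lets the per-model skip lists below be removed. A minimal usage sketch follows; ExampleModel, its weight names, and the import path are assumptions for illustration, while the constructor arguments are the ones introduced or used in this diff.

# Sketch only: ExampleModel and the import path are assumed for illustration;
# skip_prefixes/skip_substrs are the arguments shown in the diff below.
from collections.abc import Iterable

import torch
from torch import nn

from vllm.model_executor.models.utils import AutoWeightsLoader  # assumed path


class ExampleModel(nn.Module):

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            # Skip whole subtrees whose qualified name starts with a prefix.
            skip_prefixes=["lm_head."],
            # Skip any weight whose qualified name contains a substring;
            # the rotary_emb.{inv_freq,cos_cached,sin_cached} buffers are
            # now skipped by default.
            skip_substrs=["lora"],
        )
        return loader.load_weights(weights)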
@@ -77,3 +77,73 @@ def test_module_with_child_containing_batchnorm_can_autoload():
     assert torch.all(
         new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
     assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
+
+
+def test_module_skip_prefix():
+    """Ensure the auto weight loader can skip weights with given prefixes."""
+    mod = ModuleWithNestedBatchNorm()
+    # Run some data through the module with batchnorm
+    mod(torch.Tensor([[1, 2], [3, 4]]))
+
+    # Try to load the weights to a new instance
+    def weight_generator():
+        # Weights that need to be filtered out
+        redundant_weights = {
+            "prefix.bn.weight": torch.Tensor([1, 2]),
+            "prefix.bn.bias": torch.Tensor([3, 4]),
+        }
+        yield from (mod.state_dict() | redundant_weights).items()
+
+    new_mod = ModuleWithNestedBatchNorm()
+
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
+
+    loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."])
+    loader.load_weights(weight_generator())
+
+    # Ensure the stats are updated
+    assert torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
+
+
+def test_module_skip_substr():
+    """Ensure the auto weight loader can skip weights with given substrings."""
+    mod = ModuleWithNestedBatchNorm()
+    # Run some data through the module with batchnorm
+    mod(torch.Tensor([[1, 2], [3, 4]]))
+
+    # Try to load the weights to a new instance
+    def weight_generator():
+        # Weights that need to be filtered out
+        redundant_weights = {
+            "nested_mod.0.substr.weight": torch.Tensor([1, 2]),
+            "nested_mod.0.substr.bias": torch.Tensor([3, 4]),
+            "nested_mod.substr.weight": torch.Tensor([1, 2]),
+            "nested_mod.substr.bias": torch.Tensor([3, 4]),
+        }
+        yield from (mod.state_dict() | redundant_weights).items()
+
+    new_mod = ModuleWithNestedBatchNorm()
+
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
+
+    loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
+    loader.load_weights(weight_generator())
+
+    # Ensure the stats are updated
+    assert torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
@@ -478,18 +478,14 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        skip_prefixes = [
-            "rotary_emb.inv_freq",
-            # Models trained using ColossalAI may include these tensors in
-            # the checkpoint. Skip them.
-            "rotary_emb.cos_cached",
-            "rotary_emb.sin_cached",
-        ]
         # With tie_word_embeddings, we can skip lm_head.weight
         # The weight might appear unnecessarily in the files if the model is
         # processed with quantization, LoRA, fine-tuning, etc.
-        if self.config.tie_word_embeddings:
-            skip_prefixes.append("lm_head.weight")
+        skip_prefixes = (["lm_head."]
+                         if self.config.tie_word_embeddings else None)
 
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=skip_prefixes,
+        )
         return loader.load_weights(weights)
@@ -550,10 +550,12 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        skip_prefixes = ["rotary_emb.inv_freq"]
         # Skip lm_head when tie_word_embeddings is True
-        if self.config.tie_word_embeddings:
-            skip_prefixes.append("lm_head")
+        skip_prefixes = (["lm_head"]
+                         if self.config.tie_word_embeddings else None)
 
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=skip_prefixes,
+        )
         return loader.load_weights(weights)
@@ -482,5 +482,5 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
@@ -447,8 +447,5 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
@@ -502,14 +502,5 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
@@ -382,19 +382,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached",
-                "lm_head.weight"
-            ] if self.config.tie_word_embeddings else [
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
@@ -403,19 +403,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached",
-                "lm_head.weight"
-            ] if self.config.tie_word_embeddings else [
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
@@ -442,8 +442,5 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=["rotary_emb.inv_freq"],
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
@@ -344,14 +344,5 @@ class OrionForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
@@ -1228,9 +1228,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> None:
-        weights = ((name, data) for name, data in weights
-                   if "lora" not in name)
-        loader = AutoWeightsLoader(self)
+        loader = AutoWeightsLoader(self, skip_substrs=["lora"])
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
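The Phi-4-MM hunk above swaps an inline generator filter for the new skip_substrs option. A small self-contained sketch of why the two are equivalent (the weight names are made up for illustration):

import torch

# Hypothetical weight stream; the names are illustrative only.
weights = [
    ("model.layers.0.qkv_proj.weight", torch.zeros(1)),
    ("model.layers.0.qkv_proj.lora_A.weight", torch.zeros(1)),
]

# Old approach: drop LoRA weights inline before handing them to the loader.
old_style = [(name, w) for name, w in weights if "lora" not in name]

# New approach: AutoWeightsLoader(self, skip_substrs=["lora"]) applies the
# same substring test inside load_weights() via _can_skip().
new_style = [(name, w) for name, w in weights
             if not any(s in name for s in ["lora"])]

assert [n for n, _ in old_style] == [n for n, _ in new_style]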
@@ -660,8 +660,5 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
@@ -535,8 +535,5 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
@@ -530,8 +530,5 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
@@ -500,14 +500,5 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
@@ -338,13 +338,5 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            # Models trained using ColossalAI may include these tensors in
-            # the checkpoint. Skip them.
-            skip_prefixes=[
-                "rotary_emb.inv_freq", "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ],
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
@@ -349,8 +349,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
             self,
             # Models trained using ColossalAI may include these tensors in
             # the checkpoint. Skip them.
-            skip_prefixes=([
-                "rotary_emb.inv_freq", "lm_head.weight"
-            ] if self.config.tie_word_embeddings else ["rotary_emb.inv_freq"]),
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
@@ -80,18 +80,30 @@ class AutoWeightsLoader:
     environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
     """
 
+    # Models trained using early versions of ColossalAI may include these
+    # tensors in the checkpoint. Skip them.
+    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
+        "rotary_emb.inv_freq",
+        "rotary_emb.cos_cached",
+        "rotary_emb.sin_cached",
+    ]
+
     def __init__(
         self,
         module: nn.Module,
        *,
         skip_prefixes: Optional[list[str]] = None,
+        skip_substrs: Optional[list[str]] = None,
         ignore_unexpected_prefixes: Optional[list[str]] = None,
     ) -> None:
         super().__init__()
 
         self.module = module
         self.skip_prefixes = skip_prefixes or []
+        self.skip_substrs = skip_substrs or []
         self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
+        # Add the default substrings to skip
+        self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
 
     def _groupby_prefix(
        self,
@@ -119,7 +131,8 @@ class AutoWeightsLoader:
         return ".".join((prefix, rest))
 
     def _can_skip(self, qualname: str) -> bool:
-        return any(qualname.startswith(p) for p in self.skip_prefixes)
+        return (any(qualname.startswith(p) for p in self.skip_prefixes)
+                or any(substr in qualname for substr in self.skip_substrs))
 
     def _can_ignore_unexpected(self, qualname: str) -> bool:
         return any(
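To make the matching rule concrete, here is a small standalone sketch of the check the updated _can_skip performs (simplified names, not the vLLM class itself): prefixes must match from the start of the qualified weight name, while substrings may appear anywhere in it.

def can_skip(qualname: str,
             skip_prefixes: list[str],
             skip_substrs: list[str]) -> bool:
    # Mirrors the combined prefix/substring test added in this diff.
    return (any(qualname.startswith(p) for p in skip_prefixes)
            or any(s in qualname for s in skip_substrs))


assert can_skip("lm_head.weight", ["lm_head."], [])
assert can_skip("layers.0.rotary_emb.inv_freq", [], ["rotary_emb.inv_freq"])
# Not skipped: "rotary_emb.inv_freq" is not a prefix of the full name.
assert not can_skip("layers.0.rotary_emb.inv_freq", ["rotary_emb.inv_freq"], [])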
@@ -257,6 +270,9 @@ class AutoWeightsLoader:
     ) -> set[str]:
         if mapper is not None:
             weights = mapper.apply(weights)
+        # Filter out weights whose name matches a skip prefix or substring
+        weights = ((name, weight) for name, weight in weights
+                   if not self._can_skip(name))
 
         autoloaded_weights = set(self._load_module("", self.module, weights))
         return autoloaded_weights
|||||||
Loading…
x
Reference in New Issue
Block a user