Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Synced 2025-12-10 07:24:56 +08:00
Enable RMSNorm substitution for Transformers backend (#26353)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent 1317028aa8
commit b960441812
@@ -43,7 +43,7 @@ from vllm.config.utils import getattr_iter
 from vllm.distributed import get_pp_group, get_tp_group
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     ReplicatedLinear,
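Why the import now pulls in GemmaRMSNorm alongside RMSNorm: the two layers apply their learned scale differently, so one cannot stand in for the other. A minimal sketch of the distinction in plain PyTorch (illustrative only, not vLLM's fused implementation):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Standard RMSNorm: normalize, then scale directly by the weight,
    # which is conventionally initialized to ones.
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * weight

def gemma_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma-style RMSNorm: scale by (1 + weight), so the weight is
    # conventionally initialized to zeros.
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * (1.0 + weight)
```

That zeros-vs-ones initialization difference is exactly what the detection logic in the next hunk keys on.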
@@ -194,15 +194,29 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
     - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
       and Transformers doesn't appear to have the same concept.
     """
-    kwargs = {
-        "hidden_size": hidden_size,
-        "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
-        "has_weight": getattr(rms_norm, "with_scale", True),
-    }
-    if (weight := getattr(rms_norm, "weight", None)) is not None:
-        # If weight is a Parameter, get its data tensor
-        weight = getattr(weight, "data", weight)
-        kwargs["dtype"] = weight.dtype
+    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
+    kwargs = {"hidden_size": hidden_size, "eps": eps}
+    # Update hidden size if weight is available
+    weight_meta = getattr(rms_norm, "weight", None)
+    if weight_meta is not None:
+        kwargs["hidden_size"] = weight_meta.size(0)
+    # Check if weight is all zeros, which indicates GemmaRMSNorm
+    # We must create a new instance because rms_norm is on meta
+    try:
+        with torch.device("cpu"):
+            weight_test = getattr(rms_norm.__class__(1), "weight", None)
+    except Exception:
+        logger.warning(
+            "Failed to determine if RMSNorm weight is centered on zero or one. "
+            "Defaulting to one."
+        )
+        weight_test = None
+    if weight_test is not None and torch.all(weight_test == 0):
+        return GemmaRMSNorm(**kwargs)
+    # Otherwise assume it's a regular RMSNorm
+    kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
+    if weight_meta is not None:
+        kwargs["dtype"] = weight_meta.dtype
+    else:
+        # No weight, fall back to weightless RMSNorm
+        kwargs["has_weight"] = False
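Two details of the new logic are worth spelling out. First, `getattr_iter` tries each attribute name in turn because Transformers norms variously call their epsilon `eps` or `variance_epsilon`. Second, the throwaway `rms_norm.__class__(1)` instance is needed because the module being replaced lives on the meta device, where parameters have shapes and dtypes but no values; only a freshly constructed copy on a real device reveals the initial weight. A self-contained sketch of the heuristic, using stand-in classes rather than real Transformers modules (it assumes, as above, that Gemma-style norms initialize their weight to zeros):

```python
import torch
from torch import nn

class LlamaStyleRMSNorm(nn.Module):
    # Stand-in for a regular Transformers RMSNorm: weight starts at ones.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

class GemmaStyleRMSNorm(nn.Module):
    # Stand-in for Gemma's RMSNorm: weight starts at zeros (scale is 1 + weight).
    def __init__(self, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))

def looks_like_gemma(rms_norm: nn.Module) -> bool:
    # Build a fresh size-1 instance on CPU so its initial weight has real
    # values, then key on the zeros-vs-ones initialization convention.
    try:
        with torch.device("cpu"):
            weight = getattr(rms_norm.__class__(1), "weight", None)
    except Exception:
        return False  # can't tell; fall back to the regular RMSNorm
    return weight is not None and bool(torch.all(weight == 0))

with torch.device("meta"):  # mimic vLLM's lazy (meta-device) model construction
    norm = GemmaStyleRMSNorm(4096)
print(looks_like_gemma(norm))  # True
```

Note the heuristic assumes the norm's constructor accepts a bare hidden size as its first argument; classes with other signatures land in the `except` branch and default to the regular RMSNorm.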
@@ -645,11 +659,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                 new_module = replace_linear_class(
                     child_module, style, self.quant_config, prefix=qual_name
                 )
-            # TODO(hmellor): Enable RMSNorm replacement once we have a way
-            # to choose RMSNorm vs GemmaRMSNorm
-            # elif child_module.__class__.__name__.endswith("RMSNorm"):
-            #     new_module = replace_rms_norm_class(
-            #         child_module, self.config.hidden_size)
+            elif child_module.__class__.__name__.endswith("RMSNorm"):
+                new_module = replace_rms_norm_class(
+                    child_module, self.text_config.hidden_size
+                )
             else:
                 _recursive_replace(child_module, prefix=qual_name)
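The last hunk enables the previously commented-out branch. Matching on the class-name suffix "RMSNorm" (rather than on a concrete type) is what lets one code path cover every model family's private *RMSNorm class, and the hidden size is now read from self.text_config rather than self.config, presumably so multimodal models resolve the text backbone's width. A simplified, self-contained sketch of this kind of recursive, name-based swap (stand-in code, not the vLLM implementation):

```python
from typing import Callable
from torch import nn

def recursive_replace(
    module: nn.Module,
    replace_fn: Callable[[nn.Module], nn.Module],
    suffix: str = "RMSNorm",
) -> None:
    # Walk the module tree; swap any child whose class name ends with
    # `suffix`, otherwise recurse into that child.
    for child_name, child_module in module.named_children():
        if child_module.__class__.__name__.endswith(suffix):
            setattr(module, child_name, replace_fn(child_module))
        else:
            recursive_replace(child_module, replace_fn, suffix)
```

In the actual diff, replace_fn corresponds to replace_rms_norm_class, which returns either RMSNorm or GemmaRMSNorm depending on the detection above.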