Enable RMSNorm substitution for Transformers backend (#26353)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-10-09 08:28:51 +01:00 committed by GitHub
parent 1317028aa8
commit b960441812
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -43,7 +43,7 @@ from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
@ -194,15 +194,29 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
- `var_hidden_size` is only ever used for Intern vision encoder in vLLM
and Transformers doesn't appear to have the same concept.
"""
kwargs = {
"hidden_size": hidden_size,
"eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
"has_weight": getattr(rms_norm, "with_scale", True),
}
if (weight := getattr(rms_norm, "weight", None)) is not None:
# If weight is a Parameter, get its data tensor
weight = getattr(weight, "data", weight)
kwargs["dtype"] = weight.dtype
eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
kwargs = {"hidden_size": hidden_size, "eps": eps}
# Update hidden size if weight is available
weight_meta = getattr(rms_norm, "weight", None)
if weight_meta is not None:
kwargs["hidden_size"] = weight_meta.size(0)
# Check if weight is all zeros, which indicates GemmaRMSNorm
# We must create a new instance because rms_norm is on meta
try:
with torch.device("cpu"):
weight_test = getattr(rms_norm.__class__(1), "weight", None)
except Exception:
logger.warning(
"Failed to determine if RMSNorm weight is centered on zero or one. "
"Defaulting to one."
)
weight_test = None
if weight_test is not None and torch.all(weight_test == 0):
return GemmaRMSNorm(**kwargs)
# Otherwise assume it's a regular RMSNorm
kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
if weight_meta is not None:
kwargs["dtype"] = weight_meta.dtype
else:
# No weight, fall back to weightless RMSNorm
kwargs["has_weight"] = False
@ -645,11 +659,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
new_module = replace_linear_class(
child_module, style, self.quant_config, prefix=qual_name
)
# TODO(hmellor): Enable RMSNorm replacement once we have a way
# to choose RMSNorm vs GemmaRMSNorm
# elif child_module.__class__.__name__.endswith("RMSNorm"):
# new_module = replace_rms_norm_class(
# child_module, self.config.hidden_size)
elif child_module.__class__.__name__.endswith("RMSNorm"):
new_module = replace_rms_norm_class(
child_module, self.text_config.hidden_size
)
else:
_recursive_replace(child_module, prefix=qual_name)