diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 47e829861284..1cfe401b243c 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -43,7 +43,7 @@ from vllm.config.utils import getattr_iter from vllm.distributed import get_pp_group, get_tp_group from vllm.distributed.utils import get_pp_indices from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, ReplicatedLinear, @@ -194,15 +194,29 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: - `var_hidden_size` is only ever used for Intern vision encoder in vLLM and Transformers doesn't appear to have the same concept. """ - kwargs = { - "hidden_size": hidden_size, - "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6), - "has_weight": getattr(rms_norm, "with_scale", True), - } - if (weight := getattr(rms_norm, "weight", None)) is not None: - # If weight is a Parameter, get its data tensor - weight = getattr(weight, "data", weight) - kwargs["dtype"] = weight.dtype + eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6) + kwargs = {"hidden_size": hidden_size, "eps": eps} + # Update hidden size if weight is available + weight_meta = getattr(rms_norm, "weight", None) + if weight_meta is not None: + kwargs["hidden_size"] = weight_meta.size(0) + # Check if weight is all zeros, which indicates GemmaRMSNorm + # We must create a new instance because rms_norm is on meta + try: + with torch.device("cpu"): + weight_test = getattr(rms_norm.__class__(1), "weight", None) + except Exception: + logger.warning( + "Failed to determine if RMSNorm weight is centered on zero or one. " + "Defaulting to one." + ) + weight_test = None + if weight_test is not None and torch.all(weight_test == 0): + return GemmaRMSNorm(**kwargs) + # Otherwise assume it's a regular RMSNorm + kwargs["has_weight"] = getattr(rms_norm, "with_scale", True) + if weight_meta is not None: + kwargs["dtype"] = weight_meta.dtype else: # No weight, fall back to weightless RMSNorm kwargs["has_weight"] = False @@ -645,11 +659,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): new_module = replace_linear_class( child_module, style, self.quant_config, prefix=qual_name ) - # TODO(hmellor): Enable RMSNorm replacement once we have a way - # to choose RMSNorm vs GemmaRMSNorm - # elif child_module.__class__.__name__.endswith("RMSNorm"): - # new_module = replace_rms_norm_class( - # child_module, self.config.hidden_size) + elif child_module.__class__.__name__.endswith("RMSNorm"): + new_module = replace_rms_norm_class( + child_module, self.text_config.hidden_size + ) else: _recursive_replace(child_module, prefix=qual_name)