mirror of https://git.datalinker.icu/vllm-project/vllm.git
Enable RMSNorm substitution for Transformers backend (#26353)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit b960441812
parent 1317028aa8
@@ -43,7 +43,7 @@ from vllm.config.utils import getattr_iter
 from vllm.distributed import get_pp_group, get_tp_group
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     ReplicatedLinear,
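For context on why the import gains GemmaRMSNorm: the two classes differ only in how the learned weight is applied to the normalized activations. A minimal sketch of the two conventions, assuming the standard formulations rather than vLLM's actual kernels:

import torch


def rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Standard convention: scale the normalized activations by w directly.
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * w


def gemma_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma convention: w is stored zero-centered and applied as (1 + w),
    # which is why a freshly initialized Gemma-style weight is all zeros.
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * (1.0 + w)

The zero initialization is exactly what the detection logic in the next hunk relies on.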
@@ -194,15 +194,29 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
     - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
       and Transformers doesn't appear to have the same concept.
     """
-    kwargs = {
-        "hidden_size": hidden_size,
-        "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
-        "has_weight": getattr(rms_norm, "with_scale", True),
-    }
-    if (weight := getattr(rms_norm, "weight", None)) is not None:
-        # If weight is a Parameter, get its data tensor
-        weight = getattr(weight, "data", weight)
-        kwargs["dtype"] = weight.dtype
+    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
+    kwargs = {"hidden_size": hidden_size, "eps": eps}
+    # Update hidden size if weight is available
+    weight_meta = getattr(rms_norm, "weight", None)
+    if weight_meta is not None:
+        kwargs["hidden_size"] = weight_meta.size(0)
+    # Check if weight is all zeros, which indicates GemmaRMSNorm
+    # We must create a new instance because rms_norm is on meta
+    try:
+        with torch.device("cpu"):
+            weight_test = getattr(rms_norm.__class__(1), "weight", None)
+    except Exception:
+        logger.warning(
+            "Failed to determine if RMSNorm weight is centered on zero or one. "
+            "Defaulting to one."
+        )
+        weight_test = None
+    if weight_test is not None and torch.all(weight_test == 0):
+        return GemmaRMSNorm(**kwargs)
+    # Otherwise assume it's a regular RMSNorm
+    kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
+    if weight_meta is not None:
+        kwargs["dtype"] = weight_meta.dtype
     else:
         # No weight, fall back to weightless RMSNorm
         kwargs["has_weight"] = False
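The detection trick in this hunk is worth spelling out. The rms_norm module being replaced typically lives on the meta device during lazy model construction, so its weight values cannot be read directly. Instantiating a fresh copy of the same class on CPU with a dummy hidden size of 1 exposes the default initialization: Gemma-style norms initialize their weight to zeros (because they apply it as 1 + w), while standard norms initialize to ones. A self-contained sketch of the idea, using hypothetical stand-in classes rather than anything imported from Transformers:

import torch
from torch import nn


class LlamaStyleRMSNorm(nn.Module):
    # Stand-in for a standard Transformers RMSNorm: weight starts at ones.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


class GemmaStyleRMSNorm(nn.Module):
    # Stand-in for a Gemma-style RMSNorm: weight starts at zeros.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim))
        self.eps = eps


def is_gemma_style(norm_cls) -> bool:
    # Build a throwaway instance on CPU; the module being replaced may be
    # on the meta device, where tensor values cannot be inspected.
    with torch.device("cpu"):
        probe = getattr(norm_cls(1), "weight", None)
    return probe is not None and bool(torch.all(probe == 0))


print(is_gemma_style(LlamaStyleRMSNorm))   # False
print(is_gemma_style(GemmaStyleRMSNorm))   # True

The try/except in the diff guards against norm classes whose constructors take different arguments; on failure the code logs a warning and falls back to the standard RMSNorm convention.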
@@ -645,11 +659,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                 new_module = replace_linear_class(
                     child_module, style, self.quant_config, prefix=qual_name
                 )
-            # TODO(hmellor): Enable RMSNorm replacement once we have a way
-            # to choose RMSNorm vs GemmaRMSNorm
-            # elif child_module.__class__.__name__.endswith("RMSNorm"):
-            #     new_module = replace_rms_norm_class(
-            #         child_module, self.config.hidden_size)
+            elif child_module.__class__.__name__.endswith("RMSNorm"):
+                new_module = replace_rms_norm_class(
+                    child_module, self.text_config.hidden_size
+                )
             else:
                 _recursive_replace(child_module, prefix=qual_name)
 
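This final hunk activates the replacement inside vLLM's recursive module walk, resolving the old TODO. A minimal sketch of that pattern, assuming a hypothetical build_replacement factory in place of replace_rms_norm_class (an illustration, not vLLM's actual _recursive_replace):

from collections.abc import Callable

from torch import nn


def swap_rms_norms(
    module: nn.Module, build_replacement: Callable[[nn.Module], nn.Module]
) -> None:
    # Walk direct children; matching by class-name suffix catches
    # LlamaRMSNorm, Qwen2RMSNorm, GemmaRMSNorm, and so on.
    for name, child in module.named_children():
        if child.__class__.__name__.endswith("RMSNorm"):
            # setattr rebinds the submodule on its parent in place.
            setattr(module, name, build_replacement(child))
        else:
            swap_rms_norms(child, build_replacement)

Note that the recursion stops at a replaced norm, mirroring the diff's elif/else structure: children of a replaced module are not revisited.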