mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-24 13:56:33 +08:00)

FusedMoE support for the Transformers backend (#22650)

commit 6b12b2ee38 (parent bbeace233b)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@@ -17,12 +17,12 @@ These models are what we list in [supported-text-models][supported-text-models]

 ### Transformers

-vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".
+vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <5% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".

 Currently, the Transformers backend works for the following:

 - Modalities: embedding models, language models and vision-language models*
-- Architectures: encoder-only, decoder-only
+- Architectures: encoder-only, decoder-only, mixture-of-experts
 - Attention types: full attention and/or sliding attention

 _*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._
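Editor's note: as a usage illustration (not part of this commit), here is a minimal offline-inference sketch that opts one of the MoE models exercised by this PR's tests into the Transformers backend. It assumes a vLLM build containing this change plus transformers>=4.57.0.dev0; `model_impl="transformers"` is the documented opt-in, but check the exact spelling against your installed version.

from vllm import LLM, SamplingParams

llm = LLM(
    model="allenai/OLMoE-1B-7B-0924",  # MoE model used in this PR's tests
    model_impl="transformers",         # force the Transformers backend
)
out = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)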
@@ -31,6 +31,7 @@ If the Transformers model implementation follows all the steps in [writing a cus

- All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
- Any combination of the following vLLM parallelisation schemes:
    - Data parallel
    - Pipeline parallel
    - Tensor parallel
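Editor's note: an illustrative sketch (not from the commit) of combining the Transformers backend with the parallelisation schemes listed above. `tensor_parallel_size` and `pipeline_parallel_size` are standard vLLM engine arguments; the 2-GPU layout is only an assumption for the example.

from vllm import LLM

llm = LLM(
    model="allenai/OLMoE-1B-7B-0924",
    model_impl="transformers",
    tensor_parallel_size=2,    # shard each layer across 2 GPUs
    pipeline_parallel_size=1,  # could be raised to combine TP with PP
)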
@@ -661,6 +661,10 @@ _TRANSFORMERS_BACKEND_MODELS = {
     "TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"),  # noqa: E501
     "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
     "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
+    "TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"),  # noqa: E501
+    "TransformersMoEForMultimodalLM": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"),  # noqa: E501
+    "TransformersMoEEmbeddingModel": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"),  # noqa: E501
+    "TransformersMoEForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"),  # noqa: E501
 }

 _EXAMPLE_MODELS = {
@@ -66,6 +66,7 @@ def check_implementation(
     [
         ("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
         ("hmellor/Ilama-3.2-1B", "auto"),  # CUSTOM CODE
+        ("allenai/OLMoE-1B-7B-0924", "transformers"),  # MoE
     ])  # trust_remote_code=True by default
 def test_models(
     hf_runner: type[HfRunner],
@@ -74,6 +75,14 @@ def test_models(
     model: str,
     model_impl: str,
 ) -> None:
+    import transformers
+    from packaging.version import Version
+    installed = Version(transformers.__version__)
+    required = Version("4.57.0.dev0")
+    if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
+        pytest.skip("MoE models with the Transformers backend require "
+                    f"transformers>={required}, but got {installed}")
+
     check_implementation(hf_runner,
                          vllm_runner,
                          example_prompts,
@@ -430,17 +430,26 @@ def dummy_hf_overrides(

     update_dict = {
         "num_layers": num_layers,
-        "num_experts": num_experts,
-        "num_experts_per_tok": 2,
-        "num_local_experts": num_experts,
-        # Otherwise there will not be any expert layers
-        "first_k_dense_replace": 0,
-        # To avoid OOM on DeepSeek-V3
-        "n_routed_experts": num_experts,
         # For Gemma-3n
         "num_kv_shared_layers": 1,
     }

+    class DummyConfig:
+        hf_text_config = text_config
+
+    # Only set MoE related config when the model has MoE layers.
+    # Otherwise all models detected as MoE by _get_transformers_backend_cls.
+    if ModelConfig.get_num_experts(DummyConfig) > 0:
+        update_dict.update({
+            "num_experts": num_experts,
+            "num_experts_per_tok": 2,
+            "num_local_experts": num_experts,
+            # Otherwise there will not be any expert layers
+            "first_k_dense_replace": 0,
+            # To avoid OOM on DeepSeek-V3
+            "n_routed_experts": num_experts,
+        })
+
     # Update num_hidden_layers for non-Longcat architectures
     if model_arch != "LongcatFlashForCausalLM" \
             and model_arch != "LongCatFlashMTPModel":
@@ -20,7 +20,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
                                     MultiModalConfig)
 from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import RunnerType
-from vllm.config.utils import assert_hashable, config
+from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
@@ -667,6 +667,8 @@ class ModelConfig:
     def _get_transformers_backend_cls(self) -> str:
         """Determine which Transformers backend class will be used if
         `model_impl` is set to `transformers` or `auto`."""
+        prefix = "Transformers"
+        prefix += "MoE" if self.get_num_experts() > 1 else ""
         # Check if the architecture we're wrapping has defaults
         runner = None
         convert = None
@@ -685,15 +687,15 @@ class ModelConfig:
         # Resolve Transformers backend pooling classes
         if runner == "pooling":
             if convert == "embed":
-                return "TransformersEmbeddingModel"
+                return prefix + "EmbeddingModel"
             if convert == "classify":
-                return "TransformersForSequenceClassification"
+                return prefix + "ForSequenceClassification"
         # Resolve Transformers backend generate classes
         if self.hf_config != self.hf_text_config:
             # If 'hf_text_config' is the same as 'hf_config'. If not, it is
             # probably a composite config, i.e. multimodal
-            return "TransformersForMultimodalLM"
-        return "TransformersForCausalLM"
+            return prefix + "ForMultimodalLM"
+        return prefix + "ForCausalLM"

     def using_transformers_backend(self) -> bool:
         """Check if the model is using the Transformers backend class."""
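Editor's note: the two added `prefix` lines drive the class-name scheme used throughout this diff. A standalone sketch of that resolution logic (illustrative only, not the real `ModelConfig` method) shows how the MoE variants fall out of the same suffixes:

def transformers_backend_cls(is_moe: bool, runner: str, convert: str,
                             is_multimodal: bool) -> str:
    # Mirrors the prefix/suffix logic above: "MoE" is spliced into the name.
    prefix = "Transformers" + ("MoE" if is_moe else "")
    if runner == "pooling":
        if convert == "embed":
            return prefix + "EmbeddingModel"
        if convert == "classify":
            return prefix + "ForSequenceClassification"
    return prefix + ("ForMultimodalLM" if is_multimodal else "ForCausalLM")

assert transformers_backend_cls(True, "generate", "", False) \
    == "TransformersMoEForCausalLM"
assert transformers_backend_cls(False, "pooling", "embed", False) \
    == "TransformersEmbeddingModel"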
@@ -1025,17 +1027,7 @@ class ModelConfig:
             self.enforce_eager = True

     def _verify_with_expert_parallelism(self) -> None:
-        num_expert_names = [
-            "moe_num_experts",  # Dbrx
-            "num_experts",  # Jamba
-            "n_routed_experts",  # DeepSeek
-            "num_local_experts",  # Mixtral
-        ]
-        num_experts = 0
-        for name in num_expert_names:
-            num_experts = getattr(self.hf_text_config, name, 0)
-            if num_experts > 0:
-                break
+        num_experts = self.get_num_experts()
         if num_experts < 1:
             raise ValueError(
                 "Number of experts in the model must be greater than 0 "
@@ -1220,6 +1212,21 @@ class ModelConfig:
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
         return num_heads // parallel_config.tensor_parallel_size

+    def get_num_experts(self) -> int:
+        """Returns the number of experts in the model."""
+        num_expert_names = [
+            "num_experts",  # Jamba
+            "moe_num_experts",  # Dbrx
+            "n_routed_experts",  # DeepSeek
+            "num_local_experts",  # Mixtral
+        ]
+        num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
+        if isinstance(num_experts, list):
+            # Ernie VL's remote code uses list[int]...
+            # The values are always the same so we just take the first one.
+            return num_experts[0]
+        return num_experts
+
     def get_layers_start_end_indices(
             self, parallel_config: ParallelConfig) -> tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
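Editor's note: `get_num_experts` leans on `getattr_iter` from `vllm.config.utils`. Based on how it is used here (an assumption about its semantics, not the library's actual implementation), its behaviour amounts to "return the first of several attribute names that is present, else a default":

from typing import Any, Iterable

def getattr_iter_sketch(obj: Any, names: Iterable[str], default: Any) -> Any:
    # Illustrative re-implementation of the helper's apparent semantics.
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    return default

class _Cfg:
    n_routed_experts = 64  # DeepSeek-style config field

assert getattr_iter_sketch(
    _Cfg(), ["num_experts", "moe_num_experts", "n_routed_experts",
             "num_local_experts"], 0) == 64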
@@ -960,6 +960,7 @@ class FusedMoE(CustomOp):
         is_sequence_parallel=False,
         zero_expert_num: Optional[int] = 0,
         zero_expert_type: Optional[str] = None,
+        expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
     ):
         super().__init__()
         if params_dtype is None:
@@ -996,6 +997,9 @@
         self.zero_expert_num = zero_expert_num
         self.zero_expert_type = zero_expert_type

+        # Expert mapping used in self.load_weights
+        self.expert_mapping = expert_mapping
+
         # Round up hidden size if needed.
         hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
                                                 quant_config,
@@ -1617,6 +1621,33 @@

         return False if return_success else None

+    def load_weights(
+            self, weights: Iterable[tuple[str,
+                                          torch.Tensor]]) -> Iterable[str]:
+        if (expert_mapping := self.expert_mapping) is None:
+            raise ValueError("`self.expert_mapping` must be provided to "
+                             "load weights using `self.load_weights`.")
+        for expert_name, loaded_weight in weights:
+            qual_name = f"{self.layer_name}.{expert_name}"
+            for param_name, weight_name, expert_id, shard_id in expert_mapping:
+                if weight_name not in qual_name:
+                    continue
+                weight_name = qual_name.replace(weight_name, param_name)
+                param_name = weight_name.removeprefix(f"{self.layer_name}.")
+                param = getattr(self, param_name)
+                success = self.weight_loader(
+                    param=param,
+                    loaded_weight=loaded_weight,
+                    weight_name=weight_name,
+                    shard_id=shard_id,
+                    expert_id=expert_id,
+                    return_success=True,
+                )
+                if success:
+                    logger.debug("Loaded %s for expert %d into %s", param_name,
+                                 expert_id, self.layer_name)
+                    yield param_name
+
     def get_expert_weights(self) -> Iterable[torch.Tensor]:
         weights = list(self.named_parameters())
         assert all(weight.is_contiguous() for _, weight in weights)
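Editor's note: a toy walk-through (illustrative only) of how the new `FusedMoE.load_weights` consumes one `(param_name, weight_name, expert_id, shard_id)` entry. The concrete strings below are hypothetical and are not the actual output of `FusedMoE.make_expert_params_mapping`:

layer_name = "model.layers.0.mlp.experts"        # hypothetical FusedMoE prefix
expert_name = "0.gate_proj.weight"               # name arriving in `weights`
param_name, weight_name, expert_id, shard_id = ("w13_", "0.gate_proj.", 0, "w1")

qual_name = f"{layer_name}.{expert_name}"
if weight_name in qual_name:
    remapped = qual_name.replace(weight_name, param_name)
    local_param = remapped.removeprefix(f"{layer_name}.")
    print(local_param, expert_id, shard_id)      # -> w13_weight 0 w1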
@@ -307,10 +307,14 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
 }

 _TRANSFORMERS_BACKEND_MODELS = {
-    "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"),  # noqa: E501
-    "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"),  # noqa: E501
     "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
     "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
+    "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"),  # noqa: E501
+    "TransformersMoEForMultimodalLM": ("transformers_moe", "TransformersMoEForMultimodalLM"),  # noqa: E501
+    "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"),  # noqa: E501
+    "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"),  # noqa: E501
+    "TransformersMoEForSequenceClassification": ("transformers_pooling", "TransformersMoEForSequenceClassification"),  # noqa: E501
+    "TransformersMoEEmbeddingModel": ("transformers_pooling", "TransformersMoEEmbeddingModel"),  # noqa: E501
 }
 # yapf: enable

@@ -22,6 +22,8 @@ from typing import Literal, Optional, Union

 import regex as re
 import torch
+import transformers
+from packaging.version import Version
 from torch import nn
 from transformers import (AutoModel, BatchFeature, PretrainedConfig,
                           PreTrainedModel)
@@ -35,6 +37,7 @@ from vllm.config.utils import getattr_iter
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -121,10 +124,14 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
     return enable


+Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep",
+                "replicate"]
+
+
 def replace_linear_class(
     linear: nn.Linear,
-    style: Literal["colwise", "rowwise"],
-    quant_config: QuantizationConfig,
+    style: Style = "replicate",
+    quant_config: Optional[QuantizationConfig] = None,
     *,
     prefix: str = "",
 ) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
@@ -132,11 +139,11 @@ def replace_linear_class(
     Replace nn.Linear with one of vLLM's tensor parallel linear classes.

     Args:
-        linear (nn.Linear): `nn.Linear` to be replaced.
-        style (str): Tensor parallel style of the new linear, e.g. "colwise".
-        quant_config (QuantConfig): Quantization config for the new linear.
+        linear: `nn.Linear` to be replaced.
+        style: Tensor parallel style of the new linear, e.g. "colwise".
+        quant_config: Quantization config for the new linear.
     Returns:
-        Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
+        The new linear.
     """

     if not isinstance(style, str):
@@ -166,6 +173,31 @@
     )


+def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
+    """Replace a Transformers RMSNorm with vLLM's RMSNorm.
+
+    This method assumes:
+    - Weight is stored as `weight`.
+    - Epsilon is stored as `eps` or `variance_epsilon`.
+    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
+    - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
+      and Transformers doesn't appear to have the same concept.
+    """
+    kwargs = {
+        "hidden_size": hidden_size,
+        "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
+        "has_weight": getattr(rms_norm, "with_scale", True)
+    }
+    if (weight := getattr(rms_norm, "weight", None)) is not None:
+        # If weight is a Parameter, get its data tensor
+        weight = getattr(weight, "data", weight)
+        kwargs["dtype"] = weight.dtype
+    else:
+        # No weight, fall back to weightless RMSNorm
+        kwargs["has_weight"] = False
+    return RMSNorm(**kwargs)
+
+
 # Copied from `accelerate`
 @contextmanager
 def init_on_device_without_buffers(device: torch.device):
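Editor's note: a standalone sketch (plain torch, no vLLM import) of the attribute probing that `replace_rms_norm_class` relies on. The toy module below mimics how Transformers RMSNorm variants expose `weight` and `variance_epsilon`; it is not any real Transformers class:

import torch
import torch.nn as nn

class ToyHFRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

norm = ToyHFRMSNorm(64, eps=1e-5)
# Probe eps under either spelling, defaulting like the helper above does.
eps = next((getattr(norm, n) for n in ("eps", "variance_epsilon")
            if hasattr(norm, n)), 1e-6)
has_weight = getattr(norm, "with_scale", True)
print(eps, has_weight, norm.weight.dtype)  # 1e-05 True torch.float32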
@@ -463,9 +495,15 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         self.ignore_unexpected_suffixes: list[str] = []
         """Ignore unexpected weights whose qualname ends with these suffixes."""

-        # Skip loading extra bias for GPTQ models.
-        if self.quant_config and "gptq" in self.quant_config.get_name():
-            self.ignore_unexpected_suffixes.append(".bias")
+        if self.quant_config:
+            quant_method_name = self.quant_config.get_name()
+            # Check for unsupported quantization methods.
+            if quant_method_name == "mxfp4":
+                raise NotImplementedError("Transformers backend does not "
+                                          "support MXFP4 quantization yet.")
+            # Skip loading extra bias for GPTQ models.
+            if "gptq" in quant_method_name:
+                self.ignore_unexpected_suffixes.append(".bias")

         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
@@ -478,8 +516,12 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             trust_remote_code=self.model_config.trust_remote_code,
         )

+        # Remove layers not on this pipeline parallel rank
         self.pipeline_parallel()
-        self.tensor_parallel()
+        # Substitute remaining layers with vLLM's layers as needed
+        self.recursive_replace()
+        # Create attention instances for KV cache allocation
+        self.attention_instances = self.create_attention_instances()

         # Input embeddings
         if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
@@ -494,12 +536,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                 quant_config=self.quant_config,
             ))

-        # Attention layers
-        self.attention_instances = self.create_attention_instances()
-
         # Initialize any parameters that have not had their modules replaced
         self.init_parameters(self.model)

         # Pipeline parallel intermediate tensors
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states"], self.text_config.hidden_size))
@@ -558,56 +598,53 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             if not self.pp_group.is_last_rank:
                 setattr(self.model, name, PPMissingLayer())

-    def tensor_parallel(self):
-        """
-        Apply the model's tensor parallelization plan.
-        Currently only supports linear layers.
-        """
-        # Look for tp plans in all of the PreTrainedModels found in self.model
-        is_pretrained_model = lambda m: isinstance(m, PreTrainedModel)
-        supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
-        pretrained_models = filter(is_pretrained_model, self.model.modules())
-        models_with_tp_plan = filter(supports_tp_plan, pretrained_models)
+    def recursive_replace(self):
+        """Recursively replace modules in the model as needed.

-        if not any(models_with_tp_plan) and self.tp_size > 1:
+        Currently, this replaces:
+
+        - `nn.Linear` with vLLM's tensor parallel linear classes
+        - `*RMSNorm` with vLLM's `RMSNorm`
+        """
+        tp_plan = self.model.tp_plan
+
+        if not tp_plan and self.tp_size > 1:
             tip = get_feature_request_tip(self.model_config.model,
                                           self.model_config.trust_remote_code)
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel. {tip}")

-        def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None):
-            tp_plan = tp_plan or {}
+        # Prefix the patterns because we always start from `self.model`
+        tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}

-            # If the current module is a PreTrainedModel, set the tp_plan for
-            # all of its children
-            if isinstance(module, PreTrainedModel):
-                tp_plan = module.config.base_model_tp_plan or {}
-                tp_plan = {
-                    maybe_prefix(prefix, k): v
-                    for k, v in tp_plan.items()
-                }
-
-            # Some weight loaders expect linear layers to inherit from vLLM's
-            # LinearBase class, so we set a default style which causes any
-            # unspecified linear layers to be replaced with ReplicatedLinear
+        def _recursive_replace(module: nn.Module, prefix: str):
             for child_name, child_module in module.named_children():
+                new_module = child_module
                 qual_name = maybe_prefix(prefix, child_name)
                 if isinstance(child_module, nn.Linear):
                     generator = (p for p in tp_plan if re.match(p, qual_name))
                     pattern = next(generator, None)
+                    # Some weight loaders expect all linear layers to inherit
+                    # LinearBase, so we set a default style which causes any
+                    # unspecified layers to be replaced with ReplicatedLinear
                     style = tp_plan.get(pattern, "replicate")
                     new_module = replace_linear_class(child_module,
                                                       style,
                                                       self.quant_config,
                                                       prefix=qual_name)
+                # TODO(hmellor): Enable RMSNorm replacement once we have a way
+                # to choose RMSNorm vs GemmaRMSNorm
+                # elif child_module.__class__.__name__.endswith("RMSNorm"):
+                #     new_module = replace_rms_norm_class(
+                #         child_module, self.config.hidden_size)
+                else:
+                    _recursive_replace(child_module, prefix=qual_name)
+
+                if new_module is not child_module:
                     setattr(module, child_name, new_module)
                     log_replacement(qual_name, child_module, new_module)
-                else:
-                    _tensor_parallel(child_module,
-                                     prefix=qual_name,
-                                     tp_plan=tp_plan)

-        _tensor_parallel(self.model, prefix="model")
+        _recursive_replace(self.model, prefix="model")

     def create_attention_instances(
         self,
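Editor's note: an illustration (not from the commit) of how a pattern in the model's tensor-parallel plan is matched against a module's qualified name in `_recursive_replace` above. The plan keys here are hypothetical, not taken from any real Transformers config:

import re  # `regex` is used in the module above; stdlib `re` suffices here

tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.mlp.down_proj": "rowwise",
}

def style_for(qual_name: str) -> str:
    pattern = next((p for p in tp_plan if re.match(p, qual_name)), None)
    # Unmatched nn.Linear layers fall back to ReplicatedLinear ("replicate")
    return tp_plan.get(pattern, "replicate")

print(style_for("model.layers.0.self_attn.q_proj"))  # colwise
print(style_for("model.embed_tokens"))               # replicate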
@@ -657,15 +694,21 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
        self.model: PreTrainedModel = AutoModel.from_config(...)
        ```
        """
-        for name, param in module.named_parameters(recurse=False):
-            if param.device == torch.device("meta"):
-                new_param = nn.Parameter(
-                    torch.empty_like(param.data,
-                                     dtype=dtype or self.model_config.dtype,
-                                     device=self.device_config.device))
-                setattr(module, name, new_param)
-        for child in module.children():
-            self.init_parameters(child, dtype)
+
+        def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]):
+            for name, param in module.named_parameters(recurse=False):
+                if param.device == torch.device("meta"):
+                    new_param = nn.Parameter(
+                        torch.empty_like(
+                            param.data,
+                            dtype=dtype or self.model_config.dtype,
+                            device=self.device_config.device,
+                        ))
+                    setattr(module, name, new_param)
+            for child in module.children():
+                _init_parameters(child, dtype)
+
+        _init_parameters(module, dtype)

     def forward(
         self,
@@ -702,8 +745,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):

         return hidden_states

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(
+        self,
+        weights: Iterable[tuple[str, torch.Tensor]],
+    ) -> set[str]:
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=self.skip_prefixes,
@@ -713,6 +758,14 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         )
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

+    def check_version(self, min_version: str, feature: str):
+        installed = Version(transformers.__version__)
+        required = Version(min_version)
+        if installed < required:
+            raise ImportError(
+                f"Transformers backend requires transformers>={required} "
+                f"for {feature}, but got {installed}")
+

 @support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersForCausalLM(TransformersBase):
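Editor's note: the new `check_version` helper is what `TransformersMoEBase.__init__` and the pooling base call later in this diff to gate features on the installed transformers release. A self-contained sketch of the same gate (assuming the `packaging` package is installed; the threshold below is just an example value):

import transformers
from packaging.version import Version

def require_transformers(min_version: str, feature: str) -> None:
    installed = Version(transformers.__version__)
    if installed < Version(min_version):
        raise ImportError(f"Transformers backend requires "
                          f"transformers>={min_version} for {feature}, "
                          f"but got {installed}")

require_transformers("4.40.0", "this example")  # raises only on older installs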
vllm/model_executor/models/transformers_moe.py  (new file, 268 lines)
@@ -0,0 +1,268 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2024 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Wrapper around `transformers` MoE models."""
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config.utils import getattr_iter
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
+
+from .transformers import (TransformersBase, TransformersForCausalLM,
+                           TransformersForMultimodalLM,
+                           can_enable_torch_compile, log_replacement)
+from .utils import maybe_prefix
+
+
+@CustomOp.register("transformers_fused_moe")
+class TransformersFusedMoE(FusedMoE):
+    """Custom FusedMoE for the Transformers backend."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._top_k_index: torch.Tensor = None
+
+        def custom_routing_function(hidden_states, gating_output, topk,
+                                    renormalize):
+            """Return `top_k_weights` from `gating_output` and the
+            `top_k_index` we stored in the layer earlier."""
+            return gating_output, self._top_k_index
+
+        self.custom_routing_function = custom_routing_function
+
+    def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor,
+                top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+        """In Transformers `experts.forward` will have this signature.
+
+        We discard any extra kwargs because we cannot use them here."""
+        return torch.ops.vllm.transformers_moe_forward(hidden_states,
+                                                       top_k_index,
+                                                       top_k_weights,
+                                                       self.layer_name)
+
+
+def transformers_moe_forward(hidden_states: torch.Tensor,
+                             top_k_index: torch.Tensor,
+                             top_k_weights: torch.Tensor,
+                             layer_name: str) -> torch.Tensor:
+    """Store the `top_k_index` in the layer and call the actual forward."""
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self._top_k_index = top_k_index
+    # Clone hidden_states because it will be mutated in-place in FusedMoE
+    return self.forward_impl(hidden_states.clone(), top_k_weights)
+
+
+def transformers_moe_forward_fake(hidden_states: torch.Tensor,
+                                  top_k_index: torch.Tensor,
+                                  top_k_weights: torch.Tensor,
+                                  layer_name: str) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    op_name="transformers_moe_forward",
+    op_func=transformers_moe_forward,
+    mutates_args=["hidden_states"],
+    fake_impl=transformers_moe_forward_fake,
+    dispatch_key=current_platform.dispatch_key,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
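Editor's note: the custom-op registration above exists because Transformers has already chosen the experts by the time `experts.forward` runs, so the layer only hands vLLM the precomputed routing instead of recomputing a top-k from router logits. A toy sketch of that routing contract (shapes and values are illustrative, not vLLM's real call path):

import torch

num_tokens, top_k, num_experts = 4, 2, 8
precomputed_top_k_index = torch.randint(0, num_experts, (num_tokens, top_k))
precomputed_top_k_weights = torch.rand(num_tokens, top_k)

def custom_routing_function(hidden_states, gating_output, topk, renormalize):
    # `gating_output` carries the weights the wrapped HF module already
    # computed; the stored index is returned alongside it unchanged.
    return gating_output, precomputed_top_k_index

weights, index = custom_routing_function(None, precomputed_top_k_weights,
                                         top_k, renormalize=True)
assert index.shape == (num_tokens, top_k) and weights.shape == index.shape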
+class TransformersMoEBase(TransformersBase):
+
+    def __init__(self, *, vllm_config, prefix=""):
+        self.check_version("4.57.0.dev0", "MoE models support")
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+        if self.parallel_config.enable_expert_parallel:
+            raise NotImplementedError(
+                "Transformers backend does not support expert parallel yet.")
+        if self.parallel_config.enable_eplb:
+            raise NotImplementedError(
+                "Transformers backend does not support expert parallel load "
+                "balancing yet.")
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        """
+        Params for weights, fp8 weight scales, fp8 activation scales
+        (param_name, weight_name, expert_id, shard_id)
+        """
+        ckpt_names = [
+            # (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name)
+            ("gate_proj", "down_proj", "up_proj"),  # Most common MoE style
+            ("w1", "w2", "w3"),  # Granite, Mixtral, Phi MoE style
+            ("linear", "linear_1", "linear_v"),  # Grok1 style
+        ]
+        expert_mapping = []
+        for gate_proj, down_proj, up_proj in ckpt_names:
+            expert_mapping.extend(
+                FusedMoE.make_expert_params_mapping(
+                    ckpt_gate_proj_name=gate_proj,
+                    ckpt_down_proj_name=down_proj,
+                    ckpt_up_proj_name=up_proj,
+                    num_experts=self.model_config.get_num_experts(),
+                    num_redundant_experts=0,  # TODO: enable EPLB
+                ))
+        return expert_mapping
+
+    def recursive_replace(self):
+        """Initialize the MoE layers."""
+        text_config = self.text_config
+
+        # Positional arguments
+        num_experts = self.model_config.get_num_experts()
+        top_k = getattr_iter(text_config, ["num_experts_per_tok", "top_k"],
+                             None)
+        assert top_k is not None
+        hidden_size = text_config.hidden_size
+        intermediate_size = getattr_iter(
+            text_config, ["moe_intermediate_size", "intermediate_size"], None)
+        assert intermediate_size is not None
+
+        # If there are shared experts, the results are
+        # reduced after mlp.forward() not inside FusedMoE
+        num_experts_shared = getattr_iter(text_config, [
+            "num_experts_shared", "n_shared_experts", "moe_num_shared_experts"
+        ], 0)
+        reduce_results = num_experts_shared == 0
+
+        def add_all_reduce(mlp: nn.Module):
+            """Adds an all-reduce to the output of `mlp.forward()`."""
+
+            class MLPWithAllReduce(mlp.__class__):
+
+                def forward(self, *args, **kwargs):
+                    output = super().forward(*args, **kwargs)
+                    return self.experts.maybe_all_reduce_tensor_model_parallel(
+                        output)
+
+            mlp.__class__ = MLPWithAllReduce
+
+        # Unused kwargs since we use custom_routing_function:
+        # - `scoring_func` and `e_score_correction_bias` only used for grouped
+        #   topk routing inside vLLM and are non-trivial to infer
+        #   and hard code `use_grouped_topk=False`
+        # - `renormalize` passed anyway because it's easy to infer
+        # - `num_expert_group` and `topk_group` used for inferring expert
+        #   placement strategy in FusedMoE
+        # - `apply_router_weight_on_input` is already applied in Transformers
+        renormalize = getattr(text_config, "norm_topk_prob", top_k > 1)
+        num_expert_group = getattr(text_config, "n_group", None)
+        topk_group = getattr(text_config, "topk_group", None)
+
+        # MoE activation function
+        activation = "silu"
+        wrapped_arch = self.config.architectures[0].lower()
+        if "gptoss" in wrapped_arch:
+            activation = "swigluoai"
+        elif "grok1" in wrapped_arch:
+            activation = "gelu"
+
+        # Expert mapping for `AutoWeightsLoader`
+        expert_mapping = self.get_expert_mapping()
+
+        # Configs
+        parallel_config = self.parallel_config
+        eplb_config = parallel_config.eplb_config
+
+        # Expert parallel load balancing kwargs
+        enable_eplb = parallel_config.enable_eplb
+        num_redundant_experts = eplb_config.num_redundant_experts
+
+        # Recursively fuse MoE layers
+        def _recursive_replace(module: nn.Module, prefix: str):
+            for child_name, child_module in module.named_children():
+                qual_name = maybe_prefix(prefix, child_name)
+                if (child_name == "experts"
+                        and isinstance(child_module, nn.ModuleList)):
+                    # Alias for readability
+                    mlp = module
+                    experts = child_module
+                    # Do the experts have biases
+                    has_bias = False
+                    for experts_param_name, _ in experts.named_parameters():
+                        if "bias" in experts_param_name:
+                            has_bias = True
+                            break
+                    # Double check there are no shared experts
+                    nonlocal reduce_results
+                    if reduce_results:
+                        for mlp_param_name, _ in mlp.named_parameters():
+                            if "shared_expert" in mlp_param_name:
+                                reduce_results = False
+                                break
+                    # Replace experts module with FusedMoE
+                    fused_experts = TransformersFusedMoE(
+                        num_experts=num_experts,
+                        top_k=top_k,
+                        hidden_size=hidden_size,
+                        intermediate_size=intermediate_size,
+                        reduce_results=reduce_results,
+                        renormalize=renormalize,
+                        # Hard coded because topk happens in Transformers
+                        use_grouped_topk=False,
+                        num_expert_group=num_expert_group,
+                        topk_group=topk_group,
+                        quant_config=self.quant_config,
+                        prefix=qual_name,
+                        activation=activation,
+                        enable_eplb=enable_eplb,
+                        num_redundant_experts=num_redundant_experts,
+                        has_bias=has_bias,
+                        expert_mapping=expert_mapping,
+                    )
+                    mlp.experts = fused_experts
+                    log_replacement(qual_name, experts, fused_experts)
+                    # If results are not all-reduced in FusedMoE, ensure they
+                    # are all-reduced at the end of mlp.forward() if tensor
+                    # parallel or expert parallel is enabled
+                    if not reduce_results and (fused_experts.tp_size > 1
+                                               or fused_experts.ep_size > 1):
+                        add_all_reduce(mlp)
+                else:
+                    _recursive_replace(child_module, prefix=qual_name)
+
+        _recursive_replace(self.model, prefix="model")
+        # Continue with the replacement of layers in TransformersBase
+        super().recursive_replace()
+
+
+@support_torch_compile(enable_if=can_enable_torch_compile)
+class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM):
+    pass
+
+
+@support_torch_compile(
+    # set `positions` to last dim to support Qwen-mrope
+    dynamic_arg_dims={
+        "input_ids": 0,
+        "positions": -1,
+        "intermediate_tensors": 0,
+        "inputs_embeds": 0,
+    },
+    enable_if=can_enable_torch_compile)
+class TransformersMoEForMultimodalLM(TransformersMoEForCausalLM,
+                                     TransformersForMultimodalLM):
+    pass
@@ -20,7 +20,7 @@ from typing import Optional, Union
 import torch
 from transformers import AutoModelForSequenceClassification

-from vllm.attention import AttentionType
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
@@ -29,6 +29,7 @@ from vllm.sequence import IntermediateTensors

 from .interfaces_base import VllmModelForPooling
 from .transformers import TransformersBase, can_enable_torch_compile
+from .transformers_moe import TransformersMoEBase
 from .utils import WeightsMapper


@@ -79,7 +80,9 @@ class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
         self.padding_idx = self.text_config.pad_token_id

     def create_attention_instances(
-            self, attn_type: AttentionType = AttentionType.DECODER):
+        self,
+        attn_type: AttentionType = AttentionType.DECODER
+    ) -> dict[int, Attention]:
         # TODO(hmellor): Better way to detect encoder models
         # In encoder models, the attention layers will have `is_causal=False`
         is_encoder = lambda m: not getattr(m, "is_causal", True)
@@ -90,14 +93,7 @@ class TransformersPoolingBase(TransformersBase, VllmModelForPooling):

         # Check minimum transformers version for encoder models support
         if attn_type == AttentionType.ENCODER_ONLY:
-            import transformers
-            from packaging.version import Version
-            installed = Version(transformers.__version__)
-            required = Version("4.57.0.dev0")
-            if installed < required:
-                raise ValueError(
-                    "Encoder models with the Transformers backend require "
-                    f"transformers>={required}, but got {installed}")
+            self.check_version("4.57.0.dev0", "encoder models support")

         return super().create_attention_instances(attn_type)

@@ -198,3 +194,15 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
                 vllm_config.model_config),
             ),
         })
+
+
+@support_torch_compile(enable_if=can_enable_torch_compile)
+class TransformersMoEEmbeddingModel(TransformersMoEBase,
+                                    TransformersEmbeddingModel):
+    pass
+
+
+@support_torch_compile(enable_if=can_enable_torch_compile)
+class TransformersMoEForSequenceClassification(
+        TransformersMoEBase, TransformersForSequenceClassification):
+    pass
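Editor's note: a usage illustration (not part of this commit) of the new MoE pooling classes registered above. The task/runner argument spelling differs across vLLM releases, so treat the snippet as a sketch rather than a canonical invocation; the model name is one used in this PR's own test registry.

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",   # MoE model from the registry entries above
    model_impl="transformers",
    task="embed",                 # should resolve to TransformersMoEEmbeddingModel
)
print(len(llm.embed(["hello world"])[0].outputs.embedding))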