FusedMoE support for the Transformers backend (#22650)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Harry Mellor 2025-10-03 07:12:15 +01:00 committed by yewentao256
parent bbeace233b
commit 6b12b2ee38
10 changed files with 485 additions and 91 deletions

View File

@ -17,12 +17,12 @@ These models are what we list in [supported-text-models][supported-text-models]
### Transformers
vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".
vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <5% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".
Currently, the Transformers backend works for the following:
- Modalities: embedding models, language models and vision-language models*
- Architectures: encoder-only, decoder-only
- Architectures: encoder-only, decoder-only, mixture-of-experts
- Attention types: full attention and/or sliding attention
_*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._
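For example, the backend can be selected explicitly with `model_impl="transformers"` (a minimal sketch; the model name is illustrative, and `model_impl="auto"` falls back to the Transformers backend automatically when vLLM has no dedicated implementation):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="allenai/OLMoE-1B-7B-0924",  # MoE model served via the Transformers backend
    model_impl="transformers",
    tensor_parallel_size=1,
)
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```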
@ -31,6 +31,7 @@ If the Transformers model implementation follows all the steps in [writing a cus
- All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
- Any combination of the following vLLM parallelisation schemes:
- Data parallel
- Pipeline parallel
- Tensor parallel

View File

@ -661,6 +661,10 @@ _TRANSFORMERS_BACKEND_MODELS = {
"TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEForMultimodalLM": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEEmbeddingModel": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501
}
_EXAMPLE_MODELS = {

View File

@ -66,6 +66,7 @@ def check_implementation(
[
("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE
("allenai/OLMoE-1B-7B-0924", "transformers"), # MoE
]) # trust_remote_code=True by default
def test_models(
hf_runner: type[HfRunner],
@ -74,6 +75,14 @@ def test_models(
model: str,
model_impl: str,
) -> None:
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
pytest.skip("MoE models with the Transformers backend require "
f"transformers>={required}, but got {installed}")
check_implementation(hf_runner,
vllm_runner,
example_prompts,

View File

@ -430,17 +430,26 @@ def dummy_hf_overrides(
update_dict = {
"num_layers": num_layers,
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
# Otherwise there will not be any expert layers
"first_k_dense_replace": 0,
# To avoid OOM on DeepSeek-V3
"n_routed_experts": num_experts,
# For Gemma-3n
"num_kv_shared_layers": 1,
}
class DummyConfig:
hf_text_config = text_config
# Only set MoE related config when the model has MoE layers.
# Otherwise, all models would be detected as MoE by
# _get_transformers_backend_cls.
if ModelConfig.get_num_experts(DummyConfig) > 0:
update_dict.update({
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
# Otherwise there will not be any expert layers
"first_k_dense_replace": 0,
# To avoid OOM on DeepSeek-V3
"n_routed_experts": num_experts,
})
# Update num_hidden_layers for non-Longcat architectures
if model_arch != "LongcatFlashForCausalLM" \
and model_arch != "LongCatFlashMTPModel":

View File

@ -20,7 +20,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
MultiModalConfig)
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config
from vllm.config.utils import assert_hashable, config, getattr_iter
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
@ -667,6 +667,8 @@ class ModelConfig:
def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
prefix = "Transformers"
prefix += "MoE" if self.get_num_experts() > 1 else ""
# Check if the architecture we're wrapping has defaults
runner = None
convert = None
@ -685,15 +687,15 @@ class ModelConfig:
# Resolve Transformers backend pooling classes
if runner == "pooling":
if convert == "embed":
return "TransformersEmbeddingModel"
return prefix + "EmbeddingModel"
if convert == "classify":
return "TransformersForSequenceClassification"
return prefix + "ForSequenceClassification"
# Resolve Transformers backend generate classes
if self.hf_config != self.hf_text_config:
# 'hf_text_config' differs from 'hf_config', which usually means it is
# a composite config, i.e. multimodal
return "TransformersForMultimodalLM"
return "TransformersForCausalLM"
return prefix + "ForMultimodalLM"
return prefix + "ForCausalLM"
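For illustration, the resolution above amounts to composing an optional `MoE` prefix with a suffix picked from the runner/convert type and whether the config looks composite; a standalone sketch (hypothetical helper, mirroring the logic in this hunk):

```python
from typing import Optional

def backend_cls_name(num_experts: int, runner: Optional[str],
                     convert: Optional[str], is_multimodal: bool) -> str:
    # Mirrors _get_transformers_backend_cls: "MoE" prefix only for >1 expert
    prefix = "Transformers" + ("MoE" if num_experts > 1 else "")
    if runner == "pooling":
        if convert == "embed":
            return prefix + "EmbeddingModel"
        if convert == "classify":
            return prefix + "ForSequenceClassification"
    # hf_config != hf_text_config usually indicates a multimodal composite
    return prefix + ("ForMultimodalLM" if is_multimodal else "ForCausalLM")

assert backend_cls_name(64, None, None, False) == "TransformersMoEForCausalLM"
assert backend_cls_name(0, "pooling", "embed", False) == "TransformersEmbeddingModel"
```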
def using_transformers_backend(self) -> bool:
"""Check if the model is using the Transformers backend class."""
@ -1025,17 +1027,7 @@ class ModelConfig:
self.enforce_eager = True
def _verify_with_expert_parallelism(self) -> None:
num_expert_names = [
"moe_num_experts", # Dbrx
"num_experts", # Jamba
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = 0
for name in num_expert_names:
num_experts = getattr(self.hf_text_config, name, 0)
if num_experts > 0:
break
num_experts = self.get_num_experts()
if num_experts < 1:
raise ValueError(
"Number of experts in the model must be greater than 0 "
@ -1220,6 +1212,21 @@ class ModelConfig:
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size
def get_num_experts(self) -> int:
"""Returns the number of experts in the model."""
num_expert_names = [
"num_experts", # Jamba
"moe_num_experts", # Dbrx
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
if isinstance(num_experts, list):
# Ernie VL's remote code uses list[int]...
# The values are always the same so we just take the first one.
return num_experts[0]
return num_experts
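`getattr_iter` (imported above from `vllm.config.utils`) returns the value of the first attribute name that exists on the object, falling back to the default; a standalone sketch of the fallback behaviour used by `get_num_experts` (toy config objects and a re-implemented helper for illustration, not the vLLM definitions):

```python
from types import SimpleNamespace
from typing import Any, Iterable

def getattr_iter(obj: Any, names: Iterable[str], default: Any) -> Any:
    # First attribute that exists wins; otherwise return the default.
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    return default

names = ["num_experts", "moe_num_experts", "n_routed_experts",
         "num_local_experts"]

mixtral_like = SimpleNamespace(num_local_experts=8)
deepseek_like = SimpleNamespace(n_routed_experts=256)
dense_like = SimpleNamespace()  # no expert attribute at all

assert getattr_iter(mixtral_like, names, 0) == 8
assert getattr_iter(deepseek_like, names, 0) == 256
assert getattr_iter(dense_like, names, 0) == 0  # dense models report 0 experts
```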
def get_layers_start_end_indices(
self, parallel_config: ParallelConfig) -> tuple[int, int]:
from vllm.distributed.utils import get_pp_indices

View File

@ -960,6 +960,7 @@ class FusedMoE(CustomOp):
is_sequence_parallel=False,
zero_expert_num: Optional[int] = 0,
zero_expert_type: Optional[str] = None,
expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
):
super().__init__()
if params_dtype is None:
@ -996,6 +997,9 @@ class FusedMoE(CustomOp):
self.zero_expert_num = zero_expert_num
self.zero_expert_type = zero_expert_type
# Expert mapping used in self.load_weights
self.expert_mapping = expert_mapping
# Round up hidden size if needed.
hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
quant_config,
@ -1617,6 +1621,33 @@ class FusedMoE(CustomOp):
return False if return_success else None
def load_weights(
self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Iterable[str]:
if (expert_mapping := self.expert_mapping) is None:
raise ValueError("`self.expert_mapping` must be provided to "
"load weights using `self.load_weights`.")
for expert_name, loaded_weight in weights:
qual_name = f"{self.layer_name}.{expert_name}"
for param_name, weight_name, expert_id, shard_id in expert_mapping:
if weight_name not in qual_name:
continue
weight_name = qual_name.replace(weight_name, param_name)
param_name = weight_name.removeprefix(f"{self.layer_name}.")
param = getattr(self, param_name)
success = self.weight_loader(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
logger.debug("Loaded %s for expert %d into %s", param_name,
expert_id, self.layer_name)
yield param_name
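The `expert_mapping` entries consumed here are `(param_name, weight_name, expert_id, shard_id)` tuples produced by `FusedMoE.make_expert_params_mapping`; a toy illustration of the matching performed above (layer and weight names are made up for the example):

```python
layer_name = "model.layers.0.mlp.experts"
expert_mapping = [
    # (param_name, weight_name, expert_id, shard_id)
    ("w13_weight", "experts.0.gate_proj.weight", 0, "w1"),
    ("w2_weight",  "experts.0.down_proj.weight", 0, "w2"),
    ("w13_weight", "experts.0.up_proj.weight",   0, "w3"),
]

# A checkpoint weight arrives keyed relative to the experts module
checkpoint_name = "experts.0.up_proj.weight"
qual_name = f"{layer_name}.{checkpoint_name}"
for param_name, weight_name, expert_id, shard_id in expert_mapping:
    if weight_name in qual_name:
        # load_weights() would now call weight_loader() on getattr(self, param_name)
        print(f"load {qual_name} -> {param_name} (expert {expert_id}, shard {shard_id})")
```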
def get_expert_weights(self) -> Iterable[torch.Tensor]:
weights = list(self.named_parameters())
assert all(weight.is_contiguous() for _, weight in weights)

View File

@ -307,10 +307,14 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
}
_TRANSFORMERS_BACKEND_MODELS = {
"TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501
"TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
"TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501
"TransformersMoEForMultimodalLM": ("transformers_moe", "TransformersMoEForMultimodalLM"), # noqa: E501
"TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501
"TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501
"TransformersMoEForSequenceClassification": ("transformers_pooling", "TransformersMoEForSequenceClassification"), # noqa: E501
"TransformersMoEEmbeddingModel": ("transformers_pooling", "TransformersMoEEmbeddingModel"), # noqa: E501
}
# yapf: enable

View File

@ -22,6 +22,8 @@ from typing import Literal, Optional, Union
import regex as re
import torch
import transformers
from packaging.version import Version
from torch import nn
from transformers import (AutoModel, BatchFeature, PretrainedConfig,
PreTrainedModel)
@ -35,6 +37,7 @@ from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
@ -121,10 +124,14 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
return enable
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep",
"replicate"]
def replace_linear_class(
linear: nn.Linear,
style: Literal["colwise", "rowwise"],
quant_config: QuantizationConfig,
style: Style = "replicate",
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
@ -132,11 +139,11 @@ def replace_linear_class(
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
Args:
linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise".
quant_config (QuantConfig): Quantization config for the new linear.
linear: `nn.Linear` to be replaced.
style: Tensor parallel style of the new linear, e.g. "colwise".
quant_config: Quantization config for the new linear.
Returns:
Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
The new linear.
"""
if not isinstance(style, str):
@ -166,6 +173,31 @@ def replace_linear_class(
)
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
"""Replace a Transformers RMSNorm with vLLM's RMSNorm.
This method assumes:
- Weight is stored as `weight`.
- Epsilon is stored as `eps` or `variance_epsilon`.
- `with_scale` indicates whether the layer has a weight (Gemma3n only).
- `var_hidden_size` is only ever used for Intern vision encoder in vLLM
and Transformers doesn't appear to have the same concept.
"""
kwargs = {
"hidden_size": hidden_size,
"eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
"has_weight": getattr(rms_norm, "with_scale", True)
}
if (weight := getattr(rms_norm, "weight", None)) is not None:
# If weight is a Parameter, get its data tensor
weight = getattr(weight, "data", weight)
kwargs["dtype"] = weight.dtype
else:
# No weight, fall back to weightless RMSNorm
kwargs["has_weight"] = False
return RMSNorm(**kwargs)
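The shape of module this helper expects can be shown with a toy class that mimics a Transformers-style RMSNorm (illustrative only, not an actual Transformers class):

```python
import torch
from torch import nn

class HFStyleRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # found via `weight`
        self.variance_epsilon = eps  # found via ("eps", "variance_epsilon")

norm = HFStyleRMSNorm(64)
# replace_rms_norm_class(norm, hidden_size=64) would then build vLLM's
# RMSNorm(hidden_size=64, eps=1e-6, has_weight=True, dtype=torch.float32)
```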
# Copied from `accelerate`
@contextmanager
def init_on_device_without_buffers(device: torch.device):
@ -463,9 +495,15 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
self.ignore_unexpected_suffixes: list[str] = []
"""Ignore unexpected weights whose qualname ends with these suffixes."""
# Skip loading extra bias for GPTQ models.
if self.quant_config and "gptq" in self.quant_config.get_name():
self.ignore_unexpected_suffixes.append(".bias")
if self.quant_config:
quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods.
if quant_method_name == "mxfp4":
raise NotImplementedError("Transformers backend does not "
"support MXFP4 quantization yet.")
# Skip loading extra bias for GPTQ models.
if "gptq" in quant_method_name:
self.ignore_unexpected_suffixes.append(".bias")
# Set correct attn and init on "meta" to delay allocating GPU tensors
# TODO: @raushan, use the public `model.set_attn_implementation()`
@ -478,8 +516,12 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
trust_remote_code=self.model_config.trust_remote_code,
)
# Remove layers not on this pipeline parallel rank
self.pipeline_parallel()
self.tensor_parallel()
# Substitute remaining layers with vLLM's layers as needed
self.recursive_replace()
# Create attention instances for KV cache allocation
self.attention_instances = self.create_attention_instances()
# Input embeddings
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
@ -494,12 +536,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
quant_config=self.quant_config,
))
# Attention layers
self.attention_instances = self.create_attention_instances()
# Initialize any parameters that have not had their modules replaced
self.init_parameters(self.model)
# Pipeline parallel intermediate tensors
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states"], self.text_config.hidden_size))
@ -558,56 +598,53 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer())
def tensor_parallel(self):
"""
Apply the model's tensor parallelization plan.
Currently only supports linear layers.
"""
# Look for tp plans in all of the PreTrainedModels found in self.model
is_pretrained_model = lambda m: isinstance(m, PreTrainedModel)
supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
pretrained_models = filter(is_pretrained_model, self.model.modules())
models_with_tp_plan = filter(supports_tp_plan, pretrained_models)
def recursive_replace(self):
"""Recursively replace modules in the model as needed.
if not any(models_with_tp_plan) and self.tp_size > 1:
Currently, this replaces:
- `nn.Linear` with vLLM's tensor parallel linear classes
- `*RMSNorm` with vLLM's `RMSNorm`
"""
tp_plan = self.model.tp_plan
if not tp_plan and self.tp_size > 1:
tip = get_feature_request_tip(self.model_config.model,
self.model_config.trust_remote_code)
raise ValueError(
f"{type(self.model)} does not support tensor parallel. {tip}")
def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None):
tp_plan = tp_plan or {}
# Prefix the patterns because we always start from `self.model`
tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}
# If the current module is a PreTrainedModel, set the tp_plan for
# all of its children
if isinstance(module, PreTrainedModel):
tp_plan = module.config.base_model_tp_plan or {}
tp_plan = {
maybe_prefix(prefix, k): v
for k, v in tp_plan.items()
}
# Some weight loaders expect linear layers to inherit from vLLM's
# LinearBase class, so we set a default style which causes any
# unspecified linear layers to be replaced with ReplicatedLinear
def _recursive_replace(module: nn.Module, prefix: str):
for child_name, child_module in module.named_children():
new_module = child_module
qual_name = maybe_prefix(prefix, child_name)
if isinstance(child_module, nn.Linear):
generator = (p for p in tp_plan if re.match(p, qual_name))
pattern = next(generator, None)
# Some weight loaders expect all linear layers to inherit
# LinearBase, so we set a default style which causes any
# unspecified layers to be replaced with ReplicatedLinear
style = tp_plan.get(pattern, "replicate")
new_module = replace_linear_class(child_module,
style,
self.quant_config,
prefix=qual_name)
# TODO(hmellor): Enable RMSNorm replacement once we have a way
# to choose RMSNorm vs GemmaRMSNorm
# elif child_module.__class__.__name__.endswith("RMSNorm"):
# new_module = replace_rms_norm_class(
# child_module, self.config.hidden_size)
else:
_recursive_replace(child_module, prefix=qual_name)
if new_module is not child_module:
setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module)
else:
_tensor_parallel(child_module,
prefix=qual_name,
tp_plan=tp_plan)
_tensor_parallel(self.model, prefix="model")
_recursive_replace(self.model, prefix="model")
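The tensor parallel plan exposed by Transformers maps regex patterns over module names to styles; a standalone sketch of the matching above (the plan entries are illustrative), showing how unmatched linear layers default to "replicate" and therefore `ReplicatedLinear`:

```python
import re

# Illustrative Transformers-style tp_plan, prefixed with "model" as above
tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
tp_plan = {f"model.{k}": v for k, v in tp_plan.items()}

def style_for(qual_name: str) -> str:
    pattern = next((p for p in tp_plan if re.match(p, qual_name)), None)
    return tp_plan.get(pattern, "replicate")

print(style_for("model.layers.0.self_attn.q_proj"))  # colwise
print(style_for("model.layers.0.mlp.gate_proj"))     # replicate -> ReplicatedLinear
```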
def create_attention_instances(
self,
@ -657,15 +694,21 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
self.model: PreTrainedModel = AutoModel.from_config(...)
```
"""
for name, param in module.named_parameters(recurse=False):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(param.data,
dtype=dtype or self.model_config.dtype,
device=self.device_config.device))
setattr(module, name, new_param)
for child in module.children():
self.init_parameters(child, dtype)
def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]):
for name, param in module.named_parameters(recurse=False):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(
param.data,
dtype=dtype or self.model_config.dtype,
device=self.device_config.device,
))
setattr(module, name, new_param)
for child in module.children():
_init_parameters(child, dtype)
_init_parameters(module, dtype)
def forward(
self,
@ -702,8 +745,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=self.skip_prefixes,
@ -713,6 +758,14 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def check_version(self, min_version: str, feature: str):
installed = Version(transformers.__version__)
required = Version(min_version)
if installed < required:
raise ImportError(
f"Transformers backend requires transformers>={required} "
f"for {feature}, but got {installed}")
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):

View File

@ -0,0 +1,268 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` MoE models."""
from typing import Any
import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config.utils import getattr_iter
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .transformers import (TransformersBase, TransformersForCausalLM,
TransformersForMultimodalLM,
can_enable_torch_compile, log_replacement)
from .utils import maybe_prefix
@CustomOp.register("transformers_fused_moe")
class TransformersFusedMoE(FusedMoE):
"""Custom FusedMoE for the Transformers backend."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._top_k_index: torch.Tensor = None
def custom_routing_function(hidden_states, gating_output, topk,
renormalize):
"""Return `top_k_weights` from `gating_output` and the
`top_k_index` we stored in the layer earlier."""
return gating_output, self._top_k_index
self.custom_routing_function = custom_routing_function
def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor,
top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""In Transformers `experts.forward` will have this signature.
We discard any extra kwargs because we cannot use them here."""
return torch.ops.vllm.transformers_moe_forward(hidden_states,
top_k_index,
top_k_weights,
self.layer_name)
def transformers_moe_forward(hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
layer_name: str) -> torch.Tensor:
"""Store the `top_k_index` in the layer and call the actual forward."""
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._top_k_index = top_k_index
# Clone hidden_states because it will be mutated in-place in FusedMoE
return self.forward_impl(hidden_states.clone(), top_k_weights)
def transformers_moe_forward_fake(hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="transformers_moe_forward",
op_func=transformers_moe_forward,
mutates_args=["hidden_states"],
fake_impl=transformers_moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
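The routing hand-off above can be seen in isolation: Transformers' router computes the top-k weights and indices, `TransformersFusedMoE.forward` stashes the indices on the layer via the custom op, and the custom routing function simply echoes them back so `FusedMoE` performs no second top-k. A toy sketch with plain tensors (no real FusedMoE layer involved):

```python
import torch

class ToyLayer:
    def __init__(self):
        self._top_k_index = None

    def custom_routing_function(self, hidden_states, gating_output, topk,
                                renormalize):
        # gating_output already holds the top-k weights from Transformers
        return gating_output, self._top_k_index

layer = ToyLayer()
top_k_weights = torch.rand(4, 2)           # router weights from Transformers
top_k_index = torch.randint(0, 8, (4, 2))  # expert ids from Transformers

layer._top_k_index = top_k_index           # done by transformers_moe_forward()
w, idx = layer.custom_routing_function(None, top_k_weights, topk=2,
                                       renormalize=False)
assert torch.equal(w, top_k_weights) and torch.equal(idx, top_k_index)
```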
class TransformersMoEBase(TransformersBase):
def __init__(self, *, vllm_config, prefix=""):
self.check_version("4.57.0.dev0", "MoE models support")
super().__init__(vllm_config=vllm_config, prefix=prefix)
if self.parallel_config.enable_expert_parallel:
raise NotImplementedError(
"Transformers backend does not support expert parallel yet.")
if self.parallel_config.enable_eplb:
raise NotImplementedError(
"Transformers backend does not support expert parallel load "
"balancing yet.")
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
"""
Params for weights, fp8 weight scales, fp8 activation scales
(param_name, weight_name, expert_id, shard_id)
"""
ckpt_names = [
# (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name)
("gate_proj", "down_proj", "up_proj"), # Most common MoE style
("w1", "w2", "w3"), # Granite, Mixtral, Phi MoE style
("linear", "linear_1", "linear_v"), # Grok1 style
]
expert_mapping = []
for gate_proj, down_proj, up_proj in ckpt_names:
expert_mapping.extend(
FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name=gate_proj,
ckpt_down_proj_name=down_proj,
ckpt_up_proj_name=up_proj,
num_experts=self.model_config.get_num_experts(),
num_redundant_experts=0, # TODO: enable EPLB
))
return expert_mapping
def recursive_replace(self):
"""Initialize the MoE layers."""
text_config = self.text_config
# Positional arguments
num_experts = self.model_config.get_num_experts()
top_k = getattr_iter(text_config, ["num_experts_per_tok", "top_k"],
None)
assert top_k is not None
hidden_size = text_config.hidden_size
intermediate_size = getattr_iter(
text_config, ["moe_intermediate_size", "intermediate_size"], None)
assert intermediate_size is not None
# If there are shared experts, the results are
# reduced after mlp.forward() not inside FusedMoE
num_experts_shared = getattr_iter(text_config, [
"num_experts_shared", "n_shared_experts", "moe_num_shared_experts"
], 0)
reduce_results = num_experts_shared == 0
def add_all_reduce(mlp: nn.Module):
"""Adds an all-reduce to the output of `mlp.forward()`."""
class MLPWithAllReduce(mlp.__class__):
def forward(self, *args, **kwargs):
output = super().forward(*args, **kwargs)
return self.experts.maybe_all_reduce_tensor_model_parallel(
output)
mlp.__class__ = MLPWithAllReduce
# Unused kwargs since we use custom_routing_function:
# - `scoring_func` and `e_score_correction_bias` are only used for grouped
#   top-k routing inside vLLM and are non-trivial to infer, so we hard
#   code `use_grouped_topk=False`
# - `renormalize` is passed anyway because it's easy to infer
# - `num_expert_group` and `topk_group` are used for inferring the expert
#   placement strategy in FusedMoE
# - `apply_router_weight_on_input` is already applied in Transformers
renormalize = getattr(text_config, "norm_topk_prob", top_k > 1)
num_expert_group = getattr(text_config, "n_group", None)
topk_group = getattr(text_config, "topk_group", None)
# MoE activation function
activation = "silu"
wrapped_arch = self.config.architectures[0].lower()
if "gptoss" in wrapped_arch:
activation = "swigluoai"
elif "grok1" in wrapped_arch:
activation = "gelu"
# Expert mapping for `AutoWeightsLoader`
expert_mapping = self.get_expert_mapping()
# Configs
parallel_config = self.parallel_config
eplb_config = parallel_config.eplb_config
# Expert parallel load balancing kwargs
enable_eplb = parallel_config.enable_eplb
num_redundant_experts = eplb_config.num_redundant_experts
# Recursively fuse MoE layers
def _recursive_replace(module: nn.Module, prefix: str):
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
if (child_name == "experts"
and isinstance(child_module, nn.ModuleList)):
# Alias for readability
mlp = module
experts = child_module
# Do the experts have biases
has_bias = False
for experts_param_name, _ in experts.named_parameters():
if "bias" in experts_param_name:
has_bias = True
break
# Double check there are no shared experts
nonlocal reduce_results
if reduce_results:
for mlp_param_name, _ in mlp.named_parameters():
if "shared_expert" in mlp_param_name:
reduce_results = False
break
# Replace experts module with FusedMoE
fused_experts = TransformersFusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=reduce_results,
renormalize=renormalize,
# Hard coded because topk happens in Transformers
use_grouped_topk=False,
num_expert_group=num_expert_group,
topk_group=topk_group,
quant_config=self.quant_config,
prefix=qual_name,
activation=activation,
enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts,
has_bias=has_bias,
expert_mapping=expert_mapping,
)
mlp.experts = fused_experts
log_replacement(qual_name, experts, fused_experts)
# If results are not all-reduced in FusedMoE, ensure they
# are all-reduced at the end of mlp.forward() if tensor
# parallel or expert parallel is enabled
if not reduce_results and (fused_experts.tp_size > 1
or fused_experts.ep_size > 1):
add_all_reduce(mlp)
else:
_recursive_replace(child_module, prefix=qual_name)
_recursive_replace(self.model, prefix="model")
# Continue with the replacement of layers in TransformersBase
super().recursive_replace()
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM):
pass
@support_torch_compile(
# set `positions` to last dim to support Qwen-mrope
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
},
enable_if=can_enable_torch_compile)
class TransformersMoEForMultimodalLM(TransformersMoEForCausalLM,
TransformersForMultimodalLM):
pass

View File

@ -20,7 +20,7 @@ from typing import Optional, Union
import torch
from transformers import AutoModelForSequenceClassification
from vllm.attention import AttentionType
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
@ -29,6 +29,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces_base import VllmModelForPooling
from .transformers import TransformersBase, can_enable_torch_compile
from .transformers_moe import TransformersMoEBase
from .utils import WeightsMapper
@ -79,7 +80,9 @@ class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
self.padding_idx = self.text_config.pad_token_id
def create_attention_instances(
self, attn_type: AttentionType = AttentionType.DECODER):
self,
attn_type: AttentionType = AttentionType.DECODER
) -> dict[int, Attention]:
# TODO(hmellor): Better way to detect encoder models
# In encoder models, the attention layers will have `is_causal=False`
is_encoder = lambda m: not getattr(m, "is_causal", True)
@ -90,14 +93,7 @@ class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
# Check minimum transformers version for encoder models support
if attn_type == AttentionType.ENCODER_ONLY:
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if installed < required:
raise ValueError(
"Encoder models with the Transformers backend require "
f"transformers>={required}, but got {installed}")
self.check_version("4.57.0.dev0", "encoder models support")
return super().create_attention_instances(attn_type)
@ -198,3 +194,15 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
vllm_config.model_config),
),
})
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(TransformersMoEBase,
TransformersEmbeddingModel):
pass
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification(
TransformersMoEBase, TransformersForSequenceClassification):
pass