From 0b1bdac6af33b890a4d68321df05e71a1ba43dc4 Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Wed, 13 Aug 2025 19:12:00 +0800
Subject: [PATCH] [Platform] Custom ops support for FusedMoe (#22509)

Signed-off-by: wangxiyuan
---
 vllm/model_executor/layers/fused_moe/layer.py |  3 ++-
 vllm/model_executor/layers/linear.py          | 12 ++++++------
 .../layers/vocab_parallel_embedding.py        |  4 +++-
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 8ef0a805d86c..ddc02168e5c4 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -682,7 +682,8 @@ def determine_expert_map(
     return (local_num_experts, expert_map)
 
 
-class FusedMoE(torch.nn.Module):
+@CustomOp.register("fused_moe")
+class FusedMoE(CustomOp):
     """FusedMoE layer for MoE models.
 
     This layer contains both MergedColumnParallel weights (gate_up_proj /
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index bb81a663d454..75391c51f775 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -16,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
@@ -226,7 +227,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
         return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
 
 
-class LinearBase(torch.nn.Module):
+class LinearBase(CustomOp):
     """Base linear layer.
 
     Args:
@@ -269,12 +270,8 @@ class LinearBase(torch.nn.Module):
                                                           prefix=prefix)
         self.return_bias = return_bias
 
-    def forward(
-        self, x: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        raise NotImplementedError
-
 
+@CustomOp.register("replicated_linear")
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.
 
@@ -443,6 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
             param[shard_offset:shard_offset + shard_size] = loaded_weight
 
 
+@CustomOp.register("column_parallel_linear")
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
 
@@ -1229,6 +1227,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         param_data.copy_(loaded_weight)
 
 
+@CustomOp.register("row_parallel_linear")
 class RowParallelLinear(LinearBase):
     """Linear layer with row parallelism.
 
@@ -1405,6 +1404,7 @@ class RowParallelLinear(LinearBase):
         return s
 
 
+@CustomOp.register("qkv_cross_parallel_linear")
 class QKVCrossParallelLinear(LinearBase):
     """Linear layers for efficient cross-attention's QKV transformation.
 
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index a5f262c832bf..9f223998e554 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
 from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
@@ -159,7 +160,8 @@ def get_masked_input_and_mask(
     return input_, ~vocab_mask
 
 
-class VocabParallelEmbedding(torch.nn.Module):
+@CustomOp.register("vocab_parallel_embedding")
+class VocabParallelEmbedding(CustomOp):
     """Embedding parallelized in the vocabulary dimension.
 
     Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
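
Every hunk in this patch follows the same pattern: a layer class stops
inheriting from torch.nn.Module directly, inherits from CustomOp instead, and
is recorded under a stable name via an @CustomOp.register(...) decorator.
Below is a minimal, self-contained sketch of that register-and-dispatch idea.
It is a mock for illustration only, not vLLM's actual CustomOp class: the
registry attribute and the forward_native/forward_oot method names mirror
vLLM's convention but are assumptions here.

# Illustrative mock of the register-and-dispatch pattern, NOT vLLM's
# actual CustomOp API; names and signatures are assumptions.
import torch
import torch.nn as nn


class CustomOp(nn.Module):
    # Maps a stable op name to the class implementing it, so a platform
    # plugin can look ops up by name instead of patching call sites.
    op_registry: dict[str, type["CustomOp"]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(op_cls: type["CustomOp"]) -> type["CustomOp"]:
            cls.op_registry[name] = op_cls
            return op_cls
        return decorator

    def forward(self, *args, **kwargs):
        # A real dispatcher would also select forward_cuda/forward_cpu/etc.
        # per platform; this mock only checks for an out-of-tree override.
        if type(self).forward_oot is not CustomOp.forward_oot:
            return self.forward_oot(*args, **kwargs)
        return self.forward_native(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_oot(self, *args, **kwargs):
        raise NotImplementedError


@CustomOp.register("toy_fused_moe")  # hypothetical name, for illustration
class ToyFusedMoE(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2  # stand-in for the portable MoE computation


if __name__ == "__main__":
    op = CustomOp.op_registry["toy_fused_moe"]()
    print(op(torch.ones(2)))  # tensor([2., 2.]) via forward_native

Registering layers under stable names like this is what lets an out-of-tree
platform substitute its own kernels for FusedMoE, the parallel linear layers,
and VocabParallelEmbedding. Dropping LinearBase's abstract forward fits the
same pattern: the base class no longer needs its own stub once it inherits
CustomOp's dispatch machinery.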