[Platform] Custom ops support for FusedMoe (#22509)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan <wangxiyuan1007@gmail.com>
Date: 2025-08-13 19:12:00 +08:00 (committed by GitHub)
Commit: 0b1bdac6af (parent d94e3026de)
3 changed files with 11 additions and 8 deletions

--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py

@@ -682,7 +682,8 @@ def determine_expert_map(
     return (local_num_experts, expert_map)
 
 
-class FusedMoE(torch.nn.Module):
+@CustomOp.register("fused_moe")
+class FusedMoE(CustomOp):
     """FusedMoE layer for MoE models.
 
     This layer contains both MergedColumnParallel weights (gate_up_proj /
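
For context: CustomOp keeps a name-to-class registry and dispatches forward() to a platform-specific method chosen at construction time, which is what lets an out-of-tree platform supply its own FusedMoE kernel. A minimal sketch of the pattern this hunk adopts (the toy op below is hypothetical; only CustomOp.register and the forward_* method names come from vllm/model_executor/custom_op.py):

import torch
from vllm.model_executor.custom_op import CustomOp

# Hypothetical toy op showing the pattern applied to FusedMoE above:
# register under a stable name, provide a pure-PyTorch forward_native,
# and let per-platform methods override the dispatch target.
@CustomOp.register("toy_silu_and_mul")
class ToySiluAndMul(CustomOp):

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Reference path, used when no platform override applies.
        d = x.shape[-1] // 2
        return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
        # An out-of-tree platform would call its own fused kernel here;
        # this sketch just falls back to the reference implementation.
        return self.forward_native(x)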

--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py

@@ -16,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
@@ -226,7 +227,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
         return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
 
 
-class LinearBase(torch.nn.Module):
+class LinearBase(CustomOp):
     """Base linear layer.
 
     Args:
@@ -269,12 +270,8 @@ class LinearBase(torch.nn.Module):
                          prefix=prefix)
         self.return_bias = return_bias
 
-    def forward(
-        self, x: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        raise NotImplementedError
-
 
+@CustomOp.register("replicated_linear")
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.
@@ -443,6 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
             param[shard_offset:shard_offset + shard_size] = loaded_weight
 
 
+@CustomOp.register("column_parallel_linear")
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
@@ -1229,6 +1227,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         param_data.copy_(loaded_weight)
 
 
+@CustomOp.register("row_parallel_linear")
 class RowParallelLinear(LinearBase):
     """Linear layer with row parallelism.
@@ -1405,6 +1404,7 @@ class RowParallelLinear(LinearBase):
         return s
 
 
+@CustomOp.register("qkv_cross_parallel_linear")
 class QKVCrossParallelLinear(LinearBase):
     """Linear layers for efficient cross-attention's QKV transformation.

--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py

@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
 from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
@@ -159,7 +160,8 @@ def get_masked_input_and_mask(
     return input_, ~vocab_mask
 
 
-class VocabParallelEmbedding(torch.nn.Module):
+@CustomOp.register("vocab_parallel_embedding")
+class VocabParallelEmbedding(CustomOp):
     """Embedding parallelized in the vocabulary dimension.
 
     Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
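
With VocabParallelEmbedding registered the same way, every layer touched by this commit becomes a named entry in the CustomOp registry, so the standard custom_ops toggles in the compilation config reach them too. A hedged usage sketch (the model is an arbitrary MoE placeholder; whether a platform actually ships an optimized path for a given op remains platform-specific):

from vllm import LLM

# Hedged sketch: the names registered above ("fused_moe",
# "vocab_parallel_embedding", ...) can now be force-enabled or disabled
# like any other CustomOp via the compilation config.
llm = LLM(
    model="Qwen/Qwen1.5-MoE-A2.7B",  # placeholder MoE model
    compilation_config={"custom_ops": ["+fused_moe"]},
)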