[Platform] Custom ops support for FusedMoe (#22509)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-08-13 19:12:00 +08:00 (committed by GitHub)
Parent: d94e3026de
Commit: 0b1bdac6af
3 changed files with 11 additions and 8 deletions
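
This change switches FusedMoE, LinearBase, and VocabParallelEmbedding from plain torch.nn.Module to vLLM's CustomOp base class and registers each layer under a stable op name. Registration is what lets out-of-tree platforms substitute their own forward implementations. Below is a minimal, self-contained sketch of the dispatch pattern; it mirrors the shape of vllm.model_executor.custom_op.CustomOp but is a simplified stand-in, not the real implementation.

import torch
import torch.nn as nn


class CustomOp(nn.Module):
    # Registry mapping op names (e.g. "fused_moe") to their classes.
    op_registry: dict = {}

    @classmethod
    def register(cls, name: str):
        # Class decorator: record the op under a stable name so that a
        # platform or config can find it and toggle/replace it later.
        def decorator(op_cls):
            assert name not in cls.op_registry, f"duplicate op: {name}"
            cls.op_registry[name] = op_cls
            return op_cls
        return decorator

    def forward(self, *args, **kwargs):
        # Route to a per-platform implementation. The real class caches
        # this choice at construction time; the sketch decides per call.
        if torch.cuda.is_available():
            return self.forward_cuda(*args, **kwargs)
        return self.forward_oot(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        # Pure-PyTorch reference path; subclasses must provide it.
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_oot(self, *args, **kwargs):
        # Hook for out-of-tree platforms (the point of this PR): a
        # platform plugin can override this to call vendor kernels.
        return self.forward_native(*args, **kwargs)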

vllm/model_executor/layers/fused_moe/layer.py

@@ -682,7 +682,8 @@ def determine_expert_map(
     return (local_num_experts, expert_map)


-class FusedMoE(torch.nn.Module):
+@CustomOp.register("fused_moe")
+class FusedMoE(CustomOp):
     """FusedMoE layer for MoE models.

     This layer contains both MergedColumnParallel weights (gate_up_proj /

vllm/model_executor/layers/linear.py

@ -16,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
@@ -226,7 +227,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
         return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)


-class LinearBase(torch.nn.Module):
+class LinearBase(CustomOp):
     """Base linear layer.

     Args:
@@ -269,12 +270,8 @@ class LinearBase(torch.nn.Module):
                          prefix=prefix)
         self.return_bias = return_bias

-    def forward(
-        self, x: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        raise NotImplementedError
-

+@CustomOp.register("replicated_linear")
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.
@ -443,6 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
param[shard_offset:shard_offset + shard_size] = loaded_weight param[shard_offset:shard_offset + shard_size] = loaded_weight
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase): class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
@@ -1229,6 +1227,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         param_data.copy_(loaded_weight)


+@CustomOp.register("row_parallel_linear")
 class RowParallelLinear(LinearBase):
     """Linear layer with row parallelism.
@@ -1405,6 +1404,7 @@ class RowParallelLinear(LinearBase):
         return s


+@CustomOp.register("qkv_cross_parallel_linear")
 class QKVCrossParallelLinear(LinearBase):
     """Linear layers for efficient cross-attention's QKV transformation.

vllm/model_executor/layers/vocab_parallel_embedding.py

@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
 from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
@ -159,7 +160,8 @@ def get_masked_input_and_mask(
return input_, ~vocab_mask return input_, ~vocab_mask
class VocabParallelEmbedding(torch.nn.Module): @CustomOp.register("vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp):
"""Embedding parallelized in the vocabulary dimension. """Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
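
A practical consequence of giving each op a registered name is that it can be toggled per-name at startup. A sketch, assuming vLLM's usual custom_ops list on the compilation config (the model name and the exact toggles below are illustrative, not taken from this PR):

from vllm import LLM

# "+fused_moe" opts the FusedMoE layer into its custom-op path;
# "-vocab_parallel_embedding" forces that op back to the native
# (torch.compile-friendly) path. Model name is a placeholder.
llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    compilation_config={
        "custom_ops": ["+fused_moe", "-vocab_parallel_embedding"],
    },
)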