mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:34:57 +08:00
[Platform] Custom ops support for FusedMoe (#22509)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
parent
d94e3026de
commit
0b1bdac6af
@ -682,7 +682,8 @@ def determine_expert_map(
|
||||
return (local_num_experts, expert_map)
|
||||
|
||||
|
||||
class FusedMoE(torch.nn.Module):
|
||||
@CustomOp.register("fused_moe")
|
||||
class FusedMoE(CustomOp):
|
||||
"""FusedMoE layer for MoE models.
|
||||
|
||||
This layer contains both MergedColumnParallel weights (gate_up_proj /
|
||||
|
||||
@ -16,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
@ -226,7 +227,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||
|
||||
|
||||
class LinearBase(torch.nn.Module):
|
||||
class LinearBase(CustomOp):
|
||||
"""Base linear layer.
|
||||
|
||||
Args:
|
||||
@ -269,12 +270,8 @@ class LinearBase(torch.nn.Module):
|
||||
prefix=prefix)
|
||||
self.return_bias = return_bias
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@CustomOp.register("replicated_linear")
|
||||
class ReplicatedLinear(LinearBase):
|
||||
"""Replicated linear layer.
|
||||
|
||||
@ -443,6 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
|
||||
param[shard_offset:shard_offset + shard_size] = loaded_weight
|
||||
|
||||
|
||||
@CustomOp.register("column_parallel_linear")
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
@ -1229,6 +1227,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
@CustomOp.register("row_parallel_linear")
|
||||
class RowParallelLinear(LinearBase):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
@ -1405,6 +1404,7 @@ class RowParallelLinear(LinearBase):
|
||||
return s
|
||||
|
||||
|
||||
@CustomOp.register("qkv_cross_parallel_linear")
|
||||
class QKVCrossParallelLinear(LinearBase):
|
||||
"""Linear layers for efficient cross-attention's QKV transformation.
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
@ -159,7 +160,8 @@ def get_masked_input_and_mask(
|
||||
return input_, ~vocab_mask
|
||||
|
||||
|
||||
class VocabParallelEmbedding(torch.nn.Module):
|
||||
@CustomOp.register("vocab_parallel_embedding")
|
||||
class VocabParallelEmbedding(CustomOp):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user