mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:45:01 +08:00
[Platform] Custom ops support for FusedMoe (#22509)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
parent
d94e3026de
commit
0b1bdac6af
@ -682,7 +682,8 @@ def determine_expert_map(
|
|||||||
return (local_num_experts, expert_map)
|
return (local_num_experts, expert_map)
|
||||||
|
|
||||||
|
|
||||||
class FusedMoE(torch.nn.Module):
|
@CustomOp.register("fused_moe")
|
||||||
|
class FusedMoE(CustomOp):
|
||||||
"""FusedMoE layer for MoE models.
|
"""FusedMoE layer for MoE models.
|
||||||
|
|
||||||
This layer contains both MergedColumnParallel weights (gate_up_proj /
|
This layer contains both MergedColumnParallel weights (gate_up_proj /
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
|||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig, QuantizeMethodBase)
|
QuantizationConfig, QuantizeMethodBase)
|
||||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||||
@ -226,7 +227,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
|||||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||||
|
|
||||||
|
|
||||||
class LinearBase(torch.nn.Module):
|
class LinearBase(CustomOp):
|
||||||
"""Base linear layer.
|
"""Base linear layer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -269,12 +270,8 @@ class LinearBase(torch.nn.Module):
|
|||||||
prefix=prefix)
|
prefix=prefix)
|
||||||
self.return_bias = return_bias
|
self.return_bias = return_bias
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, x: torch.Tensor
|
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
|
@CustomOp.register("replicated_linear")
|
||||||
class ReplicatedLinear(LinearBase):
|
class ReplicatedLinear(LinearBase):
|
||||||
"""Replicated linear layer.
|
"""Replicated linear layer.
|
||||||
|
|
||||||
@ -443,6 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
|
|||||||
param[shard_offset:shard_offset + shard_size] = loaded_weight
|
param[shard_offset:shard_offset + shard_size] = loaded_weight
|
||||||
|
|
||||||
|
|
||||||
|
@CustomOp.register("column_parallel_linear")
|
||||||
class ColumnParallelLinear(LinearBase):
|
class ColumnParallelLinear(LinearBase):
|
||||||
"""Linear layer with column parallelism.
|
"""Linear layer with column parallelism.
|
||||||
|
|
||||||
@ -1229,6 +1227,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
@CustomOp.register("row_parallel_linear")
|
||||||
class RowParallelLinear(LinearBase):
|
class RowParallelLinear(LinearBase):
|
||||||
"""Linear layer with row parallelism.
|
"""Linear layer with row parallelism.
|
||||||
|
|
||||||
@ -1405,6 +1404,7 @@ class RowParallelLinear(LinearBase):
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
@CustomOp.register("qkv_cross_parallel_linear")
|
||||||
class QKVCrossParallelLinear(LinearBase):
|
class QKVCrossParallelLinear(LinearBase):
|
||||||
"""Linear layers for efficient cross-attention's QKV transformation.
|
"""Linear layers for efficient cross-attention's QKV transformation.
|
||||||
|
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
|
|||||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
||||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||||
@ -159,7 +160,8 @@ def get_masked_input_and_mask(
|
|||||||
return input_, ~vocab_mask
|
return input_, ~vocab_mask
|
||||||
|
|
||||||
|
|
||||||
class VocabParallelEmbedding(torch.nn.Module):
|
@CustomOp.register("vocab_parallel_embedding")
|
||||||
|
class VocabParallelEmbedding(CustomOp):
|
||||||
"""Embedding parallelized in the vocabulary dimension.
|
"""Embedding parallelized in the vocabulary dimension.
|
||||||
|
|
||||||
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
|
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user