mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
[Misc] Add CustomOp Interface to UnquantizedFusedMoEMethod (#6289)
This commit is contained in:
parent
3dee97b05f
commit
ec9933f4a5
@ -7,7 +7,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@ -36,7 +36,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
@ -61,19 +61,37 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None) -> torch.Tensor:
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.forward(x, layer.w13_weight, layer.w2_weight,
|
||||
router_logits, top_k, renormalize,
|
||||
use_grouped_topk, num_expert_group, topk_group)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
|
||||
return fused_moe(x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
w1,
|
||||
w2,
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
@ -82,6 +100,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group)
|
||||
|
||||
def forward_cpu(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"The CPU backend currently does not support MoE.")
|
||||
|
||||
|
||||
class FusedMoE(torch.nn.Module):
|
||||
"""FusedMoE layer for MoE models.
|
||||
|
||||
@ -279,10 +279,6 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if quant_method is not None:
|
||||
quant_method.process_weights_after_loading(module)
|
||||
# FIXME: Remove this after Mixtral is updated
|
||||
# to use quant_method.
|
||||
if hasattr(module, "process_weights_after_loading"):
|
||||
module.process_weights_after_loading()
|
||||
return model.eval()
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user