[Misc] Add CustomOp Interface to UnquantizedFusedMoEMethod (#6289)

Woosuk Kwon authored 2024-07-15 12:02:14 -07:00 · committed by GitHub
parent 3dee97b05f
commit ec9933f4a5
2 changed files with 35 additions and 17 deletions


@@ -7,7 +7,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -36,7 +36,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         raise NotImplementedError


-class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -61,19 +61,37 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool = True,
-              use_grouped_topk: bool = False,
-              num_expert_group: Optional[int] = None,
-              topk_group: Optional[int] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+    ) -> torch.Tensor:
+        return self.forward(x, layer.w13_weight, layer.w2_weight,
+                            router_logits, top_k, renormalize,
+                            use_grouped_topk, num_expert_group, topk_group)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
         return fused_moe(x,
-                         layer.w13_weight,
-                         layer.w2_weight,
+                         w1,
+                         w2,
                          router_logits,
                          top_k,
                          renormalize=renormalize,
@@ -82,6 +100,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
                          num_expert_group=num_expert_group,
                          topk_group=topk_group)

+    def forward_cpu(self, *args, **kwargs):
+        raise NotImplementedError(
+            "The CPU backend currently does not support MoE.")
+

 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.

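Why mixing in CustomOp matters: the base class supplies a forward() that dispatches to a backend-specific method, so apply() above can route every call through self.forward(...) and end up in forward_cuda() on GPU backends or forward_cpu() elsewhere. The snippet below is a minimal sketch of that dispatch pattern, not the actual vllm.model_executor.custom_op.CustomOp source; the torch.cuda.is_available() check merely stands in for vLLM's own platform detection, and CustomOpSketch is an illustrative name.

import torch


class CustomOpSketch(torch.nn.Module):
    """Sketch of a CustomOp-style base: forward() picks a backend method."""

    def forward(self, *args, **kwargs):
        # Dispatch to the per-backend implementation. The real CustomOp
        # uses vLLM's platform checks (CUDA, ROCm, CPU, ...); a plain
        # CUDA-availability test keeps this sketch short.
        if torch.cuda.is_available():
            return self.forward_cuda(*args, **kwargs)
        return self.forward_cpu(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cpu(self, *args, **kwargs):
        raise NotImplementedError

Under that reading, UnquantizedFusedMoEMethod only has to provide forward_cuda() (the fused_moe kernel, now imported lazily inside the method) and forward_cpu() (which raises, since the CPU backend does not yet support MoE), while callers keep going through apply().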

@@ -279,10 +279,6 @@ class DefaultModelLoader(BaseModelLoader):
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
                 quant_method.process_weights_after_loading(module)
-            # FIXME: Remove this after Mixtral is updated
-            # to use quant_method.
-            if hasattr(module, "process_weights_after_loading"):
-                module.process_weights_after_loading()
         return model.eval()
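On the loader side, dropping the FIXME fallback means post-load weight processing now goes exclusively through the quant_method hook rather than a module-level method. The sketch below restates that contract for illustration only: ExampleQuantizeMethod and finalize_weights are hypothetical names, not vLLM APIs, and only the loop mirrors the lines kept by the diff.

import torch


class ExampleQuantizeMethod:
    """Hypothetical quant method showing the post-loading hook."""

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Typical uses: repack or transpose weights, fuse scales, or free
        # temporary buffers once checkpoint loading has finished.
        pass


def finalize_weights(model: torch.nn.Module) -> torch.nn.Module:
    # Mirrors the loop retained by the diff: each module that carries a
    # quant_method gets exactly one post-loading callback.
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)
    return model.eval()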