Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 00:55:40 +08:00)
[Misc] Add CustomOp Interface to UnquantizedFusedMoEMethod (#6289)
parent 3dee97b05f
commit ec9933f4a5
@@ -7,7 +7,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -36,7 +36,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         raise NotImplementedError
 
 
-class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -61,7 +61,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-    def apply(self,
+    def apply(
+        self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
@@ -69,11 +70,28 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
         renormalize: bool = True,
         use_grouped_topk: bool = False,
         num_expert_group: Optional[int] = None,
-        topk_group: Optional[int] = None) -> torch.Tensor:
+        topk_group: Optional[int] = None,
+    ) -> torch.Tensor:
+        return self.forward(x, layer.w13_weight, layer.w2_weight,
+                            router_logits, top_k, renormalize,
+                            use_grouped_topk, num_expert_group, topk_group)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
         return fused_moe(x,
-                         layer.w13_weight,
-                         layer.w2_weight,
+                         w1,
+                         w2,
                          router_logits,
                          top_k,
                          renormalize=renormalize,
@@ -82,6 +100,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
                          num_expert_group=num_expert_group,
                          topk_group=topk_group)
 
+    def forward_cpu(self, *args, **kwargs):
+        raise NotImplementedError(
+            "The CPU backend currently does not support MoE.")
+
 
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
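With CustomOp added to the class's bases, apply() no longer calls fused_moe directly: it routes through self.forward(...), the CUDA-specific work moves into forward_cuda, and forward_cpu raises until the CPU backend supports MoE. The sketch below illustrates the dispatch pattern this relies on. It is a simplified stand-in, not vLLM's actual CustomOp: the class name, the torch.cuda.is_available() backend check, and the forward_cpu-defaults-to-CUDA behavior are illustrative assumptions.

import torch


class CustomOpSketch(torch.nn.Module):
    """Simplified stand-in for a CustomOp-style base class.

    forward() is resolved once to a backend-specific method, so a
    subclass only implements forward_cuda / forward_cpu, and callers
    (like UnquantizedFusedMoEMethod.apply) just call self.forward(...).
    """

    def __init__(self):
        super().__init__()
        # Illustrative backend check; the real CustomOp resolves the
        # platform through vLLM's own detection, not this call.
        if torch.cuda.is_available():
            self._forward_method = self.forward_cuda
        else:
            self._forward_method = self.forward_cpu

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cpu(self, *args, **kwargs):
        # Default to the CUDA implementation; subclasses may override,
        # as UnquantizedFusedMoEMethod does to raise NotImplementedError.
        return self.forward_cuda(*args, **kwargs)

A subclass overrides only the backend methods it cares about (here, forward_cuda with the fused_moe call and forward_cpu with an explicit NotImplementedError), while call sites such as apply() stay backend-agnostic. The commit's second file, DefaultModelLoader, drops a Mixtral-specific fallback that a FIXME had already marked for removal once Mixtral switched to quant_method: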
@@ -279,10 +279,6 @@ class DefaultModelLoader(BaseModelLoader):
                 quant_method = getattr(module, "quant_method", None)
                 if quant_method is not None:
                     quant_method.process_weights_after_loading(module)
-                # FIXME: Remove this after Mixtral is updated
-                # to use quant_method.
-                if hasattr(module, "process_weights_after_loading"):
-                    module.process_weights_after_loading()
         return model.eval()
 
 
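After this removal, every module's post-load fixup flows through its quant_method. A minimal sketch of the resulting pass, using a hypothetical finalize_weights helper rather than the loader's real method, looks like this:

import torch


def finalize_weights(model: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical helper mirroring the simplified loader loop."""
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            # All post-load processing is routed through the quant
            # method; the old per-module process_weights_after_loading
            # fallback (kept only for Mixtral) is no longer consulted.
            quant_method.process_weights_after_loading(module)
    return model.eval()

Because FusedMoEMethodBase derives from QuantizeMethodBase, MoE layers now receive their post-load processing through this same quant_method path, which is presumably what makes the Mixtral-specific fallback safe to delete.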