[Kernels] Isolate modular kernel code from FusedMoEMethodBase subclasses. (#27123)
This commit is contained in:
parent e4ee658672
commit 938772af03
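In outline, this change moves modular-kernel setup out of the individual FusedMoEMethodBase subclasses: the base class now builds and returns a FusedMoEModularKernel from maybe_init_modular_kernel(), and the FusedMoE layer swaps its quant_method for a FusedMoEModularMethod wrapper when a kernel is available, instead of each quant method stashing self.fused_experts and branching on it in apply(). A rough, self-contained sketch of that hand-off (names follow the diff below; the bodies are illustrative stand-ins, not the actual vLLM implementations):

# Minimal sketch of the ownership change, with simplified signatures.
class FusedMoEModularKernel:
    # Stand-in for the modular prepare/finalize + experts kernel.
    def __init__(self, prepare_finalize, experts, shared_experts=None):
        self.prepare_finalize = prepare_finalize
        self.experts = experts
        self.shared_experts = shared_experts


class FusedMoEMethodBase:
    def maybe_make_prepare_finalize(self):
        # Subclasses return a prepare/finalize object when a modular kernel
        # applies (e.g. EP all2all); None means "keep the plain path".
        return None

    def select_gemm_impl(self, prepare_finalize, layer):
        raise NotImplementedError

    # Previously init_prepare_finalize() stored the kernel on the quant
    # method; now the kernel is returned and owned by the layer instead.
    def maybe_init_modular_kernel(self, layer):
        prepare_finalize = self.maybe_make_prepare_finalize()
        if prepare_finalize is None:
            return None
        experts = self.select_gemm_impl(prepare_finalize, layer)
        return FusedMoEModularKernel(prepare_finalize, experts, layer.shared_experts)


class FusedMoEModularMethod(FusedMoEMethodBase):
    # Wrapper that replaces the original quant method once a kernel exists.
    def __init__(self, old_quant_method, fused_experts):
        self.old_quant_method = old_quant_method
        self.fused_experts = fused_experts


class FusedMoE:
    def __init__(self, quant_method, shared_experts=None):
        self.quant_method = quant_method
        self.shared_experts = shared_experts

    # Called once by prepare_communication_buffer_for_model, after weight
    # loading, so swapping the quant method here is safe.
    def maybe_init_modular_kernel(self):
        mk = self.quant_method.maybe_init_modular_kernel(self)
        if mk is not None:
            self.quant_method = FusedMoEModularMethod(self.quant_method, mk)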
@@ -266,14 +266,14 @@ class DeviceCommunicatorBase:
module
for module in model.modules()
# TODO(bnell): Should use isinstance but can't. Maybe search for
# presence of quant_method.init_prepare_finalize?
# presence of quant_method.maybe_init_modular_kernel?
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module)
module.maybe_init_modular_kernel()

def dispatch(
self,
@@ -117,10 +117,8 @@ class FusedMoeWeightScaleSupported(Enum):
class FusedMoEMethodBase(QuantizeMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.moe = moe
self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
self.fused_experts: FusedMoEModularKernel | None = None
self.topk_indices_dtype = None

@abstractmethod
def create_weights(

@@ -245,9 +243,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
else:
return None

# Note: init_prepare_finalize should only be called by
# prepare_communication_buffer_for_model.
def init_prepare_finalize(self, layer: torch.nn.Module):
def maybe_init_modular_kernel(
self, layer: torch.nn.Module
) -> FusedMoEModularKernel | None:
assert self.moe is not None

# We must get the quant config here so that the layer is

@@ -261,17 +259,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
logger.debug(
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
)
assert self.topk_indices_dtype is None
assert self.fused_experts is None, (
f"Attempt to override experts for {id(self)}!"
)
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, layer)
self.fused_experts = FusedMoEModularKernel(
return FusedMoEModularKernel(
prepare_finalize,
experts,
layer.shared_experts,
)
else:
return None

def select_gemm_impl(
self,

@@ -292,8 +287,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError

@property
def using_modular_kernel(self) -> bool:
return self.fused_experts is not None
def topk_indices_dtype(self) -> torch.dtype | None:
return None

@property
def supports_eplb(self) -> bool:
return False

@property
def allow_inplace(self) -> bool:
return False

@abstractmethod
def apply(
@@ -322,6 +325,138 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError


@CustomOp.register("modular_fused_moe")
class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def __init__(
self,
old_quant_method: FusedMoEMethodBase,
fused_experts: FusedMoEModularKernel,
):
super().__init__(old_quant_method.moe)
# Find better way to copy attributes? Should we even copy attributes?
# self.__dict__.update(old_quant_method.__dict__)
self.moe_quant_config = old_quant_method.moe_quant_config
self.fused_experts = fused_experts
self.disable_expert_map = getattr(
old_quant_method,
"disable_expert_map",
not fused_experts.supports_expert_map(),
)
self.old_quant_method = old_quant_method
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)

@property
def topk_indices_dtype(self) -> torch.dtype | None:
return self.fused_experts.prepare_finalize.topk_indices_dtype()

@property
def supports_eplb(self) -> bool:
return self.old_quant_method.supports_eplb

@property
def allow_inplace(self) -> bool:
return self.old_quant_method.allow_inplace

def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError

def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return self.moe_quant_config

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# Is getattr needed?
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)

if enable_eplb:
if self.supports_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
else:
raise NotImplementedError(
"EPLB is not supported for "
f"{self.old_quant_method.__class__.__name__}."
)

topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
)

result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=self.allow_inplace,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else expert_map,
)

if zero_expert_num != 0 and zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
return result, zero_expert_result
else:
return result


@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
@@ -378,6 +513,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
self.flashinfer_cutlass_moe = None # type: ignore

@property
def supports_eplb(self) -> bool:
return True

@property
def allow_inplace(self) -> bool:
return True

def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
if self.rocm_aiter_moe_enabled:
return None

@@ -650,7 +793,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)

if self.rocm_aiter_moe_enabled:
assert self.fused_experts is None
result = self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,

@@ -671,21 +813,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.fused_experts is not None:
result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
else:
assert fused_experts is not None
result = fused_experts(
hidden_states=x,
w1=layer.w13_weight,

@@ -1267,7 +1395,7 @@ class FusedMoE(CustomOp):
"Only softmax scoring function is supported for non-grouped topk."
)

moe = FusedMoEConfig(
self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,

@@ -1279,24 +1407,26 @@ class FusedMoE(CustomOp):
is_act_and_mul=is_act_and_mul,
is_lora_enabled=vllm_config.lora_config is not None,
)
self.moe_config: FusedMoEConfig = moe

self.moe_quant_config: FusedMoEQuantConfig | None = None
self.quant_config = quant_config

def _get_quant_method() -> FusedMoEMethodBase:
"""
Helper method to ensure self.quant_method is never None and
of the proper type.
"""
quant_method = None
if self.quant_config is not None:
quant_method = self.quant_config.get_quant_method(self, prefix)
if quant_method is None:
quant_method = UnquantizedFusedMoEMethod(self.moe_config)
assert isinstance(quant_method, FusedMoEMethodBase)
return quant_method

# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: QuantizeMethodBase | None = None
quant_method = (
UnquantizedFusedMoEMethod(moe)
if quant_config is None
else quant_config.get_quant_method(self, prefix)
)
if quant_method is None:
quant_method = UnquantizedFusedMoEMethod(moe)

assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
self.quant_method: FusedMoEMethodBase = _get_quant_method()

if not self.moe_config.is_act_and_mul:
# Avoid circular import

@@ -1305,7 +1435,7 @@ class FusedMoE(CustomOp):
)

if not isinstance(
quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
):
raise NotImplementedError(
"is_act_and_mul=False is supported only for unquantized "

@@ -1316,20 +1446,18 @@ class FusedMoE(CustomOp):
"is_act_and_mul=False is supported only for CUDA for now"
)

if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod

if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError(
"EPLB is only supported for FP8 quantization for now."
)
if self.enable_eplb and not self.quant_method.supports_eplb:
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError(
f"EPLB is not supported {self.quant_method.__class__.__name__}. "
"EPLB is only supported for FP8 quantization for now."
)

moe_quant_params = {
"num_experts": self.local_num_experts,
@@ -1353,6 +1481,15 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None

# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it
# should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None:
mk = self.quant_method.maybe_init_modular_kernel(self)
if mk is not None:
self.quant_method = FusedMoEModularMethod(self.quant_method, mk)

@property
def shared_experts(self) -> torch.nn.Module | None:
return None

@@ -2167,7 +2304,7 @@ class FusedMoE(CustomOp):
"""
assert self.quant_method is not None
return (
self.quant_method.fused_experts is not None
isinstance(self.quant_method, FusedMoEModularMethod)
and self.quant_method.fused_experts.output_is_reduced()
)

@@ -2403,7 +2540,7 @@ class FusedMoE(CustomOp):
self.ensure_dp_chunking_init()

has_separate_shared_experts = (
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
not isinstance(self.quant_method, FusedMoEModularMethod)
and self.shared_experts is not None
)

@@ -2430,8 +2567,8 @@ class FusedMoE(CustomOp):
hidden_states, router_logits, has_separate_shared_experts
)

do_naive_dispatch_combine: bool = (
self.dp_size > 1 and not self.quant_method.using_modular_kernel
do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance(
self.quant_method, FusedMoEModularMethod
)

# If there are shared experts but we are not using a modular kernel, the
@@ -707,6 +707,12 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.activation_formats[0]}"
)

def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps.
"""
return self.fused_experts.supports_expert_map()

def output_is_reduced(self) -> bool:
"""
Indicates whether or not the output of fused MoE kernel
@@ -617,8 +617,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None

if enable_eplb:
raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.")

@@ -518,12 +518,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts

assert self.fused_experts is None

if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `BitsAndBytesMoEMethod` yet."
)

topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -462,12 +462,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
indices_type=self.topk_indices_dtype,
)

#
# Note: the order here is important. self.fused_experts can override
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin.
#
if self.use_marlin:
assert self.fused_experts is None
return fused_marlin_moe(
x,
layer.w13_weight,

@@ -488,24 +483,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
workspace=layer.workspace,
)

elif self.fused_experts is not None:
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight
), "Flashinfer CUTLASS Fused MoE not applicable!"

return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)

# FlashInfer fused experts path
elif self.allow_flashinfer:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
@@ -1066,13 +1043,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL

#
# Note: the order here is important. self.fused_experts can override
# cutlass fp8 or fused_experts but not marlin or rocm.
#
if self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
assert self.fused_experts is None
return fused_marlin_moe(
x,
layer.w13_weight,

@@ -1098,7 +1070,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
assert self.fused_experts is None
return rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,

@@ -1111,18 +1082,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
quant_config=self.moe_quant_config,
)

elif self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
)

# cutlass path
elif self.use_cutlass:
assert self.moe_quant_config is not None
@@ -1318,8 +1277,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None

if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet."

@@ -1636,8 +1593,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None

if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet."

@@ -1901,8 +1856,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None

if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet."

@@ -158,8 +158,6 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None

if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ExpertsInt8MoEMethod` yet."
@@ -703,9 +703,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None

self.fused_experts: mk.FusedMoEModularKernel | None = None # type: ignore

self.fp8_backend = get_fp8_moe_backend(self.block_quant)

self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN

@@ -1181,6 +1178,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
block_shape=self.weight_block_size,
)

@property
def supports_eplb(self) -> bool:
return True

@property
def allow_inplace(self) -> bool:
return True

def apply(
self,
layer: torch.nn.Module,

@@ -1210,10 +1215,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)

if (
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
and self.fused_experts is None
):
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)

@@ -1290,10 +1292,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_fused_shared_experts=layer.num_fused_shared_experts,
)

#
# Note: the order of checks is important since self.fused_experts
# can override fused_experts or cutlass but not rocm or marlin.
#
topk_weights, topk_ids, zero_expert_result = select_result

if self.rocm_aiter_moe_enabled:

@@ -1301,7 +1299,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
rocm_aiter_fused_experts,
)

assert self.fused_experts is None
result = rocm_aiter_fused_experts(
x,
layer.w13_weight,

@@ -1315,7 +1312,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
elif self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
assert self.fused_experts is None
result = fused_marlin_moe(
x,
layer.w13_weight,

@@ -1333,19 +1329,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_map=expert_map,
workspace=layer.workspace,
)
elif self.fused_experts:
result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert not self.block_quant
assert not renormalize and custom_routing_function is not None
@@ -585,8 +585,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None

if enable_eplb:
raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.")

@@ -742,8 +742,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None

if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `GPTQMarlinMoEMethod` yet."
@@ -18,9 +18,6 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,

@@ -605,7 +602,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)

if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
assert self.fused_experts is None
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)

@@ -638,24 +634,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
indices_type=self.topk_indices_dtype,
)

#
# Note: the order here is important. self.fused_experts can override
# cutlass or fused_experts.
#
if self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert not renormalize
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"

@@ -1647,8 +1626,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):

from vllm.model_executor.models.llama4 import Llama4MoE

assert self.fused_experts is None

a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = (
flashinfer.fp4_quantize(

@@ -1720,13 +1697,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
indices_type=self.topk_indices_dtype,
)

#
# Note: the order here is important. self.fused_experts can override
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
# trtllm.
#
if self.use_marlin:
assert self.fused_experts is None
return fused_marlin_moe(
x,
layer.w13_weight,

@@ -1747,23 +1718,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
workspace=layer.workspace,
)

elif self.fused_experts is not None:
assert (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
elif (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
)

assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight
), "Flashinfer CUTLASS Fused MoE not applicable!"
assert self.moe_quant_config is not None

return self.fused_experts(
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
quant_config=self.moe_quant_config,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
@@ -226,7 +226,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
self.moe = layer
layer.quant_config = self.quant_config
bit8_pack_factor = self.quant_config.bit8_pack_factor
group_size = self.quant_config.group_size

@@ -381,7 +380,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.")
@@ -197,8 +197,6 @@ class Mxfp4Config(QuantizationConfig):
class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.topk_indices_dtype = None
self.moe = moe
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size

@@ -815,6 +813,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"EP batched experts format"
)
else:
layer.w13_weight = (
self.w13_weight_triton_tensor
if layer.w13_weight is None
else layer.w13_weight
)
layer.w2_weight = (
self.w2_weight_triton_tensor
if layer.w2_weight is None
else layer.w2_weight
)
assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])

assert self.moe_quant_config is not None
if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM

@@ -838,71 +848,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
)

def _route_and_experts(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)

topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)

w13_weight = (
self.w13_weight_triton_tensor
if layer.w13_weight is None
else layer.w13_weight
)
w2_weight = (
self.w2_weight_triton_tensor if layer.w2_weight is None else layer.w2_weight
)
assert all([w is not None for w in [w13_weight, w2_weight]])

return self.fused_experts(
hidden_states=x,
w1=w13_weight,
w2=w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@property
def allow_inplace(self) -> bool:
return True

def apply(
self,

@@ -930,29 +878,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")

if self.fused_experts is not None:
return self._route_and_experts(
layer,
x,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
global_num_experts,
expert_map,
custom_routing_function,
scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
expert_load_view,
logical_to_physical_map,
logical_replica_count,
)

if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
@@ -310,7 +310,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
# Property to determine if AITER is used
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
rocm_aiter_fused_experts,
shuffle_weights,
)

@@ -322,17 +321,11 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
self.fused_experts_func = None
else:
from vllm.model_executor.layers.fused_moe import fused_experts

self.fused_experts_func = fused_experts

def get_fused_moe_quant_config(
self, layer: torch.nn.Module

@@ -369,8 +362,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None

if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet."

@@ -392,7 +383,11 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
)

if self.rocm_aiter_moe_enabled:
return self.rocm_aiter_fused_experts_func(
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)

return rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,

@@ -403,7 +398,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
quant_config=self.moe_quant_config,
expert_map=expert_map,
)
if self.use_marlin:
elif self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
return fused_marlin_moe(
x,

@@ -421,22 +416,22 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
global_num_experts=global_num_experts,
expert_map=expert_map,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts

assert self.fused_experts_func is not None

return self.fused_experts_func(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)


class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
@@ -601,6 +596,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
block_shape=None,
)

@property
def allow_inplace(self) -> bool:
return True

def apply(
self,
layer: torch.nn.Module,

@@ -624,8 +623,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None

if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."

@@ -377,8 +377,6 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None

if enable_eplb:
raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")
@@ -13,7 +13,7 @@ import vllm.envs as envs
from vllm.distributed.parallel_state import get_dp_group
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,

@@ -160,8 +160,8 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
):
return False

if not isinstance(module.quant_method.fused_experts, FusedMoEModularKernel):
# fused_experts could invoke deep_gemm_moe_fp8
if not isinstance(module.quant_method, FusedMoEModularMethod):
# modular kernels could invoke deep_gemm_moe_fp8
return True

mk: FusedMoEModularKernel = module.quant_method.fused_experts