Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-23 21:54:36 +08:00)
[Kernels] Split up fused_moe/layer.py, isolate more modular kernel code (#28064)
This commit is contained in:
parent fa1970201d, commit a1448b4b69
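The hunks below mostly re-point imports from the old monolithic fused_moe/layer.py to the new split-out modules. A minimal before/after sketch of a call site, using only import paths that appear in this diff:

# Before this commit (symbols imported from layer.py):
# from vllm.model_executor.layers.fused_moe.layer import (
#     FusedMoEMethodBase,
#     FusedMoEModularMethod,
#     TritonExperts,
# )

# After this commit (per the hunks below):
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.all2all_utils import maybe_make_prepare_finalize
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import FusedMoEMethodBase
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
    FusedMoEModularMethod,
)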
@@ -6,6 +6,10 @@ import torch

# Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts,
)
@@ -21,7 +25,6 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts,
    NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
@@ -399,9 +402,7 @@ def make_prepare_finalize(
    quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
    if backend != "naive" and backend is not None:
        prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
            moe, quant_config
        )
        prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
        assert prepare_finalize is not None
        return prepare_finalize
    elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
@@ -25,7 +25,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
    modular_triton_fused_moe,
    try_get_optimal_moe_config,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoEModularMethod
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
    FusedMoEModularMethod,
)


class FusedMoEWithLoRA(BaseLayerWithLoRA):
@@ -5,9 +5,11 @@ from contextlib import contextmanager
from typing import Any

from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
vllm/model_executor/layers/fused_moe/all2all_utils.py (new file, 160 lines)
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.distributed import (
    get_ep_group,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPrepareAndFinalize,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_pplx

if current_platform.is_cuda_alike():
    if has_pplx():
        from .pplx_prepare_finalize import (
            PplxPrepareAndFinalize,
            pplx_hidden_dim_scale_bytes,
        )
    if has_deep_ep():
        from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
        from .deepep_ll_prepare_finalize import (
            DEEPEP_QUANT_BLOCK_SHAPE,
            DeepEPLLPrepareAndFinalize,
        )


def maybe_roundup_layer_hidden_size(
    hidden_size: int,
    act_dtype: torch.dtype,
    moe_parallel_config: FusedMoEParallelConfig,
) -> int:
    """
    Given layer hidden size and MoE configurations, round up hidden_size
    if necessary.

    Args:
        hidden_size: Layer hidden-size
        act_dtype: Data type of the layer activations.
        moe_parallel_config: Fused MoE parallelization strategy configuration.

    Return:
        Rounded up hidden_size if rounding up is required based on the configs
        and all2all backend.
        Original hidden size otherwise.
    """
    if moe_parallel_config.use_deepep_ht_kernels:
        hidden_size = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size(
            hidden_size, act_dtype
        )

    if moe_parallel_config.use_deepep_ll_kernels:
        hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(
            hidden_size
        )

    return hidden_size


def maybe_make_prepare_finalize(
    moe: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig | None,
) -> FusedMoEPrepareAndFinalize | None:
    if not moe.moe_parallel_config.use_all2all_kernels:
        return None

    all2all_manager = get_ep_group().device_communicator.all2all_manager
    assert all2all_manager is not None

    prepare_finalize: FusedMoEPrepareAndFinalize | None = None

    # TODO: could allow this now
    assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py"

    if moe.use_pplx_kernels:
        assert quant_config is not None

        hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
            moe.max_num_tokens,
            moe.hidden_dim,
            moe.in_dtype,
            quant_config.quant_dtype,
            per_act_token_quant=quant_config.per_act_token_quant,
            block_shape=quant_config.block_shape,
        )

        all_to_all_args = dict(
            max_num_tokens=moe.max_num_tokens,
            num_experts=moe.num_experts,
            experts_per_token=moe.experts_per_token,  # topk
            rank=all2all_manager.rank,
            world_size=all2all_manager.world_size,
            # dp_size actually means tp_size, bug in pplx kernels
            dp_size=all2all_manager.tp_group.world_size,
            hidden_dim=moe.hidden_dim,
            hidden_dim_bytes=hidden_dim_bytes,
            hidden_dim_scale_bytes=hidden_scale_bytes,
        )

        num_dispatchers = (
            all2all_manager.world_size // all2all_manager.tp_group.world_size
        )

        # Intranode pplx a2a takes a group name while internode does not.
        if not all2all_manager.internode:
            all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name

        handle = all2all_manager.get_handle(all_to_all_args)

        prepare_finalize = PplxPrepareAndFinalize(
            handle,
            max_num_tokens=moe.max_num_tokens,
            num_local_experts=moe.num_local_experts,
            num_dispatchers=num_dispatchers,
        )
    elif moe.use_deepep_ht_kernels:
        assert moe.dp_size == all2all_manager.dp_world_size

        all_to_all_args = dict()
        handle = all2all_manager.get_handle(all_to_all_args)
        prepare_finalize = DeepEPHTPrepareAndFinalize(
            handle,
            num_dispatchers=all2all_manager.world_size,
            dp_size=all2all_manager.dp_world_size,
            rank_expert_offset=all2all_manager.rank * moe.num_local_experts,
        )

    elif moe.use_deepep_ll_kernels:
        assert quant_config is not None
        all_to_all_args = dict(
            max_num_tokens_per_dp_rank=moe.max_num_tokens,
            token_hidden_size=moe.hidden_dim,
            num_ep_ranks=all2all_manager.world_size,
            num_global_experts=moe.num_experts,
            num_local_experts=moe.num_experts // all2all_manager.world_size,
        )
        handle = all2all_manager.get_handle(all_to_all_args)

        # Note: We may want to use FP8 dispatch just to reduce
        # data movement.
        use_fp8_dispatch = (
            quant_config.quant_dtype == current_platform.fp8_dtype()
            and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE
        )

        prepare_finalize = DeepEPLLPrepareAndFinalize(
            handle,
            max_tokens_per_rank=moe.max_num_tokens,
            num_dispatchers=all2all_manager.world_size,
            use_fp8_dispatch=use_fp8_dispatch,
        )

    return prepare_finalize
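A minimal usage sketch of the two helpers above; the caller, its moe/quant_config objects, and the surrounding function are assumptions for illustration, only the helper signatures come from the file itself:

# Hypothetical caller-side sketch (not part of this commit).
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
    maybe_roundup_layer_hidden_size,
)


def build_prepare_finalize(moe, quant_config, act_dtype):
    # Round the layer hidden size up if the active DeepEP backend needs it.
    hidden_size = maybe_roundup_layer_hidden_size(
        moe.hidden_dim, act_dtype, moe.moe_parallel_config
    )
    # None means no all2all kernels are in use; the caller keeps its
    # default (non-EP) prepare/finalize path in that case.
    prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
    return hidden_size, prepare_finalize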
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py (new file, 112 lines)
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import abstractmethod
from collections.abc import Callable

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase,
)

logger = init_logger(__name__)


class FusedMoEMethodBase(QuantizeMethodBase):
    def __init__(self, moe: FusedMoEConfig):
        super().__init__()
        self.moe: FusedMoEConfig = moe
        self.moe_quant_config: FusedMoEQuantConfig | None = None

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        raise NotImplementedError

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        Returns True if this quantization method uses 'weight_scale_2' pattern
        for per-tensor weight scales (e.g., FP4 variants), False otherwise.

        This method should be overridden by subclasses that use the
        'weight_scale_2' pattern instead of the standard 'weight_scale' pattern.
        """
        return False

    def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
        from .all2all_utils import maybe_make_prepare_finalize

        return maybe_make_prepare_finalize(self.moe, self.moe_quant_config)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        # based on the all2all implementation, select the appropriate
        # gemm implementation
        raise NotImplementedError(
            f"{self.__class__.__name__} must select appropriate gemm "
            "implementation based on the prepare_finalize"
        )

    @abstractmethod
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        raise NotImplementedError

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        return None

    @property
    def supports_eplb(self) -> bool:
        return False

    @property
    def allow_inplace(self) -> bool:
        return False

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
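For orientation, a bare-bones (hypothetical) subclass of the interface above; the class name and its behaviour are illustrative, only the overridden hooks come from the base class:

# Illustrative sketch, not part of this commit.
import torch


class ToyMoEMethod(FusedMoEMethodBase):
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Register the expert weight tensors this method needs on `layer`.
        ...

    def get_fused_moe_quant_config(self, layer: torch.nn.Module):
        # Unquantized toy method: nothing to report.
        return None

    def apply(self, layer, x, router_logits, top_k, renormalize, **kwargs):
        # Route tokens and run the experts; a real method would dispatch to
        # a fused kernel here.
        raise NotImplementedError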
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py (new file, 164 lines)
@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable

import torch

from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel,
    FusedMoEPrepareAndFinalize,
)

logger = init_logger(__name__)


@CustomOp.register("modular_fused_moe")
class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
    def __init__(
        self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel
    ):
        super().__init__(old_quant_method.moe)
        self.moe_quant_config = old_quant_method.moe_quant_config
        self.fused_experts = experts
        self.disable_expert_map = getattr(
            old_quant_method,
            "disable_expert_map",
            not self.fused_experts.supports_expert_map(),
        )
        self.old_quant_method = old_quant_method
        logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)

    @staticmethod
    def make(
        moe_layer: torch.nn.Module,
        old_quant_method: FusedMoEMethodBase,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        shared_experts: torch.nn.Module | None,
    ) -> "FusedMoEModularMethod":
        return FusedMoEModularMethod(
            old_quant_method,
            FusedMoEModularKernel(
                prepare_finalize,
                old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                shared_experts,
            ),
        )

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        return self.fused_experts.prepare_finalize.topk_indices_dtype()

    @property
    def supports_eplb(self) -> bool:
        return self.old_quant_method.supports_eplb

    @property
    def allow_inplace(self) -> bool:
        return self.old_quant_method.allow_inplace

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        raise NotImplementedError

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        return self.moe_quant_config

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # Is getattr needed?
        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        if enable_eplb:
            if self.supports_eplb:
                assert expert_load_view is not None
                assert logical_to_physical_map is not None
                assert logical_replica_count is not None
            else:
                raise NotImplementedError(
                    "EPLB is not supported for "
                    f"{self.old_quant_method.__class__.__name__}."
                )

        topk_weights, topk_ids, zero_expert_result = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
        )

        result = self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=self.allow_inplace,
            activation=activation,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=None if self.disable_expert_map else expert_map,
        )

        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
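A hedged sketch of how the make factory above might be used to swap a layer's quantization method for the modular wrapper; the layer.quant_method attribute and the surrounding helper are assumptions, while the factory call itself matches the signature shown above:

# Hypothetical wiring sketch (not part of this commit).
def maybe_swap_to_modular(layer, shared_experts=None):
    old_method = layer.quant_method  # assumed: the layer's current FusedMoEMethodBase
    prepare_finalize = old_method.maybe_make_prepare_finalize()
    if prepare_finalize is None:
        return old_method  # no all2all backend; keep the original method
    return FusedMoEModularMethod.make(
        moe_layer=layer,
        old_quant_method=old_method,
        prepare_finalize=prepare_finalize,
        shared_experts=shared_experts,
    )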
File diff suppressed because it is too large
@@ -38,7 +38,7 @@ class SharedFusedMoE(FusedMoE):
            and not (
                # TODO(wentao): find the root cause and remove this condition
                self.enable_eplb
                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
                or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
            )
            and self._shared_experts is not None
        )
@@ -0,0 +1,578 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable

import torch
import torch.nn.functional as F

import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG,
    FusedMoEConfig,
    FusedMoEQuantConfig,
    biased_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEActivationFormat,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

if current_platform.is_cuda_alike():
    from .fused_batched_moe import BatchedTritonExperts
    from .fused_moe import TritonExperts, fused_experts
else:
    fused_experts = None  # type: ignore

if current_platform.is_tpu():
    from .moe_pallas import fused_moe as fused_moe_pallas
else:
    fused_moe_pallas = None  # type: ignore

logger = init_logger(__name__)


@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def __init__(self, moe: FusedMoEConfig):
        super().__init__(moe)

        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
        if self.rocm_aiter_moe_enabled:
            from .rocm_aiter_fused_moe import rocm_aiter_fused_experts

            self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
        else:
            self.rocm_aiter_fused_experts = None  # type: ignore

        # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
        self.flashinfer_cutlass_moe_enabled = (
            has_flashinfer_cutlass_fused_moe()
            and envs.VLLM_USE_FLASHINFER_MOE_FP16
            and self.moe.moe_parallel_config.use_ep
            and self.moe.moe_parallel_config.dp_size == 1
            and current_platform.get_device_capability()[0] >= 9
        )
        if self.flashinfer_cutlass_moe_enabled:
            logger.info_once(
                "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
            )
            from functools import partial

            from .flashinfer_cutlass_moe import flashinfer_cutlass_moe

            self.flashinfer_cutlass_moe = partial(
                flashinfer_cutlass_moe,
                quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
                tp_rank=self.moe.moe_parallel_config.tp_rank,
                tp_size=self.moe.moe_parallel_config.tp_size,
                ep_rank=self.moe.moe_parallel_config.ep_rank,
                ep_size=self.moe.moe_parallel_config.ep_size,
            )
        else:
            if (
                self.moe.moe_parallel_config.use_ep
                and self.moe.moe_parallel_config.dp_size == 1
            ):
                logger.info_once(
                    "FlashInfer CUTLASS MoE is available for EP"
                    " but not enabled, consider setting"
                    " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
                    scope="local",
                )
            elif self.moe.moe_parallel_config.dp_size > 1:
                logger.info_once(
                    "FlashInfer CUTLASS MoE is currently not available for DP.",
                    scope="local",
                )
            self.flashinfer_cutlass_moe = None  # type: ignore

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
        if self.rocm_aiter_moe_enabled:
            return None
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        assert self.moe_quant_config is not None
        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            logger.debug("BatchedTritonExperts %s", self.moe)
            return BatchedTritonExperts(
                max_num_tokens=self.moe.max_num_tokens,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        else:
            logger.debug("TritonExperts %s", self.moe)
            return TritonExperts(self.moe_quant_config)
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        if self.moe.has_bias:
            w13_bias = torch.nn.Parameter(
                torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)
        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
        if self.moe.has_bias:
            w2_bias = torch.nn.Parameter(
                torch.zeros(num_experts, hidden_size, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (
            envs.VLLM_ROCM_MOE_PADDING
            and current_platform.is_rocm()
            and weight.stride(-1) == 1
            and (weight.stride(-2) * weight.element_size()) % 512 == 0
        ):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()

        return weight

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)

        # Padding the weight for better performance on ROCm
        layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
        layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)

        if self.rocm_aiter_moe_enabled:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight.data, layer.w2_weight.data
            )

            layer.w13_weight.data = shuffled_w13
            layer.w2_weight.data = shuffled_w2

        if self.flashinfer_cutlass_moe_enabled:
            # Swap halves to arrange as [w3; w1] (kernel expectation)
            w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
            layer.w13_weight.data = w13_weight_swapped.contiguous()

        if current_platform.is_xpu():
            import intel_extension_for_pytorch as ipex

            ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
            layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
                layer.w13_weight,
                layer.w2_weight,
                use_prepack=True,
                experts_start_id=ep_rank_start,
            )
        elif current_platform.is_cpu():
            from vllm.model_executor.layers.fused_moe import cpu_fused_moe

            if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                from vllm.model_executor.layers.utils import check_cpu_sgl_kernel

                dtype_w13 = layer.w13_weight.dtype
                _, n_w13, k_w13 = layer.w13_weight.size()
                dtype_w2 = layer.w2_weight.dtype
                _, n_w2, k_w2 = layer.w2_weight.size()
                if (
                    envs.VLLM_CPU_SGL_KERNEL
                    and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13)
                    and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)
                ):
                    packed_w13_weight = torch.ops._C.convert_weight_packed(
                        layer.w13_weight
                    )
                    assert packed_w13_weight.size() == layer.w13_weight.size()
                    layer.w13_weight.copy_(packed_w13_weight)
                    del packed_w13_weight
                    packed_w2_weight = torch.ops._C.convert_weight_packed(
                        layer.w2_weight
                    )
                    assert packed_w2_weight.size() == layer.w2_weight.size()
                    layer.w2_weight.copy_(packed_w2_weight)
                    layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer)
                else:
                    layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
            else:
                layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None

        return self.forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            enable_eplb=enable_eplb,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        if self.moe.has_bias:
            return biased_moe_quant_config(
                layer.w13_bias,
                layer.w2_bias,
            )
        else:
            return FUSED_MOE_UNQUANTIZED_CONFIG

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        topk_weights, topk_ids, zero_expert_result = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
            num_fused_shared_experts=layer.num_fused_shared_experts,
        )

        if self.rocm_aiter_moe_enabled:
            result = self.rocm_aiter_fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                expert_map=expert_map,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif self.flashinfer_cutlass_moe_enabled:
            return self.flashinfer_cutlass_moe(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
            )

        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
    def forward_cpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if (
            enable_eplb is not False
            or expert_load_view is not None
            or logical_to_physical_map is not None
            or logical_replica_count is not None
        ):
            raise NotImplementedError("Expert load balancing is not supported for CPU.")
        return layer.cpu_fused_moe(
            layer,
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            global_num_experts,
            expert_map,
            custom_routing_function,
            scoring_func,
            routed_scaling_factor,
            e_score_correction_bias,
            apply_router_weight_on_input,
            activation,
        )

    def forward_xpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if (
            enable_eplb is not False
            or expert_load_view is not None
            or logical_to_physical_map is not None
            or logical_replica_count is not None
        ):
            raise NotImplementedError("Expert load balancing is not supported for XPU.")
        return layer.ipex_fusion(
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            custom_routing_function=custom_routing_function,
        )

    def forward_tpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        assert custom_routing_function is None
        assert apply_router_weight_on_input is False
        if scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax scoring function is supported for TPU."
            )
        if e_score_correction_bias is not None:
            raise NotImplementedError(
                "Expert score correction bias is not supported for TPU."
            )
        assert activation == "silu", f"{activation} is not supported for TPU."
        assert routed_scaling_factor == 1.0, (
            f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU."
        )
        if (
            enable_eplb is not False
            or expert_load_view is not None
            or logical_to_physical_map is not None
            or logical_replica_count is not None
        ):
            raise NotImplementedError("Expert load balancing is not supported for TPU.")
        return fused_moe_pallas(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk=top_k,
            gating_output=router_logits,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            renormalize=renormalize,
        )

    if current_platform.is_tpu():
        forward_native = forward_tpu
    elif current_platform.is_cpu():
        forward_native = forward_cpu
    elif current_platform.is_xpu():
        forward_native = forward_xpu
    else:
        forward_native = forward_cuda
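The _maybe_pad_weight helper above pads the last dimension and then slices the padding back off, so the values and logical shape are unchanged while the row stride grows; a small self-contained illustration of that effect (shapes are arbitrary):

# Standalone illustration of the ROCm padding trick; not part of this commit.
import torch
import torch.nn.functional as F

w = torch.zeros(2, 64, 512, dtype=torch.bfloat16)
num_pad = 256 // w.element_size()  # 128 elements for bf16
padded = F.pad(w, (0, num_pad), "constant", 0)[..., :-num_pad]

assert padded.shape == w.shape                    # same logical shape and values
assert padded.stride(-2) == w.size(-1) + num_pad  # rows now spaced further apart in memory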
@@ -741,15 +741,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
            )

            self.w13_weight_triton_tensor = w13_weight
            self.w2_weight_triton_tensor = w2_weight

            # need to delete the original weights to save memory on single GPU
            del layer.w13_weight
            del layer.w2_weight
            layer.w13_weight = None
            layer.w2_weight = None
            torch.cuda.empty_cache()
            self.w13_weight = w13_weight
            self.w2_weight = w2_weight
            layer.w13_weight = w13_weight
            layer.w2_weight = w2_weight
        else:
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
@@ -824,18 +819,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                "EP batched experts format"
            )
        else:
            layer.w13_weight = (
                self.w13_weight_triton_tensor
                if layer.w13_weight is None
                else layer.w13_weight
            )
            layer.w2_weight = (
                self.w2_weight_triton_tensor
                if layer.w2_weight is None
                else layer.w2_weight
            )
            assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])

        assert self.moe_quant_config is not None
        if (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
@@ -1070,8 +1053,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):

        return triton_kernel_moe_forward(
            hidden_states=x,
            w1=self.w13_weight_triton_tensor,
            w2=self.w2_weight_triton_tensor,
            w1=self.w13_weight,
            w2=self.w2_weight,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,