Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[MoE Refactor][7/N] AITER MK (#31102)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
commit b57b967386 (parent 6d518ffbaa)
@@ -2132,6 +2132,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             torch.float16,
             torch.bfloat16,
             torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
         ]

         E, num_tokens, N, K, top_k_num = self.moe_problem_size(
@@ -2156,7 +2157,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             compute_type = tl.float16
         elif hidden_states.dtype == torch.float32:
             compute_type = tl.float32
-        elif hidden_states.dtype == torch.float8_e4m3fn:
+        elif (
+            hidden_states.dtype == torch.float8_e4m3fn
+            or hidden_states.dtype == torch.float8_e4m3fnuz
+        ):
             compute_type = tl.bfloat16
         else:
             raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
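
In short, the Triton path now treats the OCP fp8 type (e4m3fn) and the ROCm fnuz variant identically and runs the experts computation in bfloat16. A standalone sketch of that dispatch (not vLLM code; the helper name and the use of torch dtypes in place of tl.* constants are illustrative assumptions):

import torch

def triton_compute_dtype(activation_dtype: torch.dtype) -> torch.dtype:
    """Mirror of the dtype -> compute-type dispatch sketched in the hunk above."""
    if activation_dtype == torch.float16:
        return torch.float16
    if activation_dtype == torch.bfloat16:
        return torch.bfloat16
    if activation_dtype == torch.float32:
        return torch.float32
    # Both fp8 e4m3 flavors (OCP "fn" and ROCm "fnuz") compute in bfloat16.
    if activation_dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
        return torch.bfloat16
    raise ValueError(f"Unsupported compute_type: {activation_dtype}")

assert triton_compute_dtype(torch.float8_e4m3fnuz) == torch.bfloat16
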
@@ -13,6 +13,10 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input


 class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
+    def __init__(self, defer_input_quant: bool = False) -> None:
+        super().__init__()
+        self.defer_input_quant = defer_input_quant
+
     @property
     def activation_format(self) -> mk.FusedMoEActivationFormat:
         return mk.FusedMoEActivationFormat.Standard
@@ -48,6 +52,11 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
             # Note: do not use inplace for shared experts overlap
             a1 = a1 * topk_weights.to(a1.dtype)

+        # Defer input quant to moe kernel for backends (e.g. AITER, FI)
+        # which use a single kernel call for quant + experts.
+        if self.defer_input_quant:
+            return a1, None, None, None, None
+
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1,
             quant_config.a1_scale,
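
The idea: when an experts backend fuses activation quantization into its own kernel, the prepare step simply passes the raw activations through instead of quantizing them first. A simplified standalone sketch of that control flow (assumed; the generic per-tensor fp8 quantization shown here is not a claim about what moe_kernel_quantize_input does exactly):

import torch

def prepare_activations(a1: torch.Tensor, defer_input_quant: bool):
    if defer_input_quant:
        # Deferred path: the experts kernel (e.g. AITER) quantizes internally,
        # so hand back the raw activations with no scale.
        return a1, None
    # Eager path: a generic per-tensor fp8 quantization done up front.
    scale = a1.abs().amax().clamp(min=1e-12) / torch.finfo(torch.float8_e4m3fn).max
    a1q = (a1 / scale).to(torch.float8_e4m3fn)
    return a1q, scale
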
@@ -5,11 +5,15 @@ from functools import lru_cache

 import torch

+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceNoOP,
+)


 class QuantMethod(IntEnum):
@@ -263,3 +267,78 @@ def rocm_aiter_fused_experts(
         a2_scale=quant_config.a2_scale,
         doweight_stage1=apply_router_weight_on_input,
     )
+
+
+class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
+    def __init__(self, quant_config):
+        super().__init__(quant_config)
+
+    @property
+    def activation_formats(
+        self,
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (
+            mk.FusedMoEActivationFormat.Standard,
+            mk.FusedMoEActivationFormat.Standard,
+        )
+
+    def supports_expert_map(self):
+        return True
+
+    def supports_chunking(self):
+        return False
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        return TopKWeightAndReduceNoOP()
+
+    def workspace_shapes(
+        self,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+        # Workspaces are managed internally by AITER.
+        workspace1 = (0,)
+        workspace2 = (0,)
+        output = (M, K)
+        return (workspace1, workspace2, output)
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: torch.Tensor | None,
+        a1q_scale: torch.Tensor | None,
+        a2_scale: torch.Tensor | None,
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+        apply_router_weight_on_input: bool,
+    ):
+        assert a1q_scale is None
+        assert a2_scale is None
+        assert expert_tokens_meta is None
+
+        result = rocm_aiter_fused_experts(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            quant_config=self.quant_config,
+        )
+        assert result.shape == output.shape
+        output.copy_(result)
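
The fp8 hunks further down wire this class into the modular kernel together with the deferred-quant prepare/finalize. Roughly (a sketch assuming ROCm with AITER available; moe_quant_config is assumed to come from get_fused_moe_quant_config(layer)):

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import AiterExperts

# Quantization is deferred in prepare and done inside the AITER fused kernel.
kernel = mk.FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
    AiterExperts(quant_config=moe_quant_config),
)
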
@@ -117,6 +117,7 @@ class Fp8MoeBackend(Enum):
     DEEPGEMM = 3
     MARLIN = 4
     TRITON = 5
+    AITER = 6


 def get_fp8_moe_backend(
@@ -189,6 +190,10 @@ def get_fp8_moe_backend(
         logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
         return Fp8MoeBackend.DEEPGEMM

+    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
+        logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
+        return Fp8MoeBackend.AITER
+
     # default to Triton
     logger.info_once("Using Triton backend for FP8 MoE")
     return Fp8MoeBackend.TRITON
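
So the AITER backend is only selected when both ROCm AITER flags are enabled, and only after the earlier backend checks (FlashInfer, Marlin, DeepGEMM) have declined. A simplified standalone sketch of that tail of the selection logic (the enum values match the hunk above; the env-var defaults used here are assumptions, and the earlier checks are omitted):

import os
from enum import Enum


class Fp8MoeBackend(Enum):
    DEEPGEMM = 3
    MARLIN = 4
    TRITON = 5
    AITER = 6


def pick_fp8_moe_backend() -> Fp8MoeBackend:
    # Checks for CUTLASS / FlashInfer / Marlin / DeepGEMM would run before this.
    use_aiter = os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
    use_aiter_moe = os.environ.get("VLLM_ROCM_USE_AITER_MOE", "1") == "1"
    if use_aiter and use_aiter_moe:
        return Fp8MoeBackend.AITER
    # Default to Triton.
    return Fp8MoeBackend.TRITON
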
@@ -888,16 +893,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale = None
             layer.w2_input_scale = None

-        self.rocm_aiter_moe_enabled = False
-
     def process_weights_after_loading(self, layer: Module) -> None:
         if getattr(layer, "_already_called_process_weights_after_loading", False):
             return

-        # Lazy import to avoid importing triton too early.
-
-        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
             assert self.quant_config.activation_scheme == "dynamic"
@@ -932,7 +931,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
             replace_parameter(layer, "w2_weight", w2_weight)
             replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
-            if self.rocm_aiter_moe_enabled:
+            if self.fp8_backend == Fp8MoeBackend.AITER:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight.data, layer.w2_weight.data
@@ -1026,7 +1025,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 )
                 start += shard_size

-            if self.rocm_aiter_moe_enabled:
+            if self.fp8_backend == Fp8MoeBackend.AITER:
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight, layer.w2_weight
                 )
@@ -1072,6 +1071,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             self.moe_quant_config = config

             self.kernel = mk.FusedMoEModularKernel(
+                # TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
+                # with the changes to defer input quantization
                 FlashInferAllGatherMoEPrepareAndFinalize(
                     use_dp=(self.moe.dp_size > 1),
                     use_deepseek_fp8_block_scale=self.block_quant,
@@ -1093,6 +1094,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             Fp8MoeBackend.DEEPGEMM,
             Fp8MoeBackend.TRITON,
             Fp8MoeBackend.MARLIN,
+            Fp8MoeBackend.AITER,
         ]:
             from vllm.model_executor.layers.fused_moe import (
                 TritonOrDeepGemmExperts,
@@ -1103,24 +1105,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             from vllm.model_executor.layers.fused_moe.prepare_finalize import (
                 MoEPrepareAndFinalizeNoEP,
             )
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+                AiterExperts,
+            )

             config = self.get_fused_moe_quant_config(layer)
             assert config is not None
             self.moe_quant_config = config
-            use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
-            allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
-            moe_kernel = (
-                MarlinExperts(quant_config=self.moe_quant_config)
-                if use_marlin
-                else TritonOrDeepGemmExperts(
-                    quant_config=self.moe_quant_config,
-                    allow_deep_gemm=allow_deep_gemm,
-                )
-            )
-
-            self.kernel = mk.FusedMoEModularKernel(
-                MoEPrepareAndFinalizeNoEP(), moe_kernel
-            )
+            if self.fp8_backend == Fp8MoeBackend.AITER:
+                self.kernel = mk.FusedMoEModularKernel(
+                    # TODO: make defer_input_quant an attr of the AiterExperts
+                    MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+                    AiterExperts(quant_config=self.moe_quant_config),
+                )
+            elif self.fp8_backend == Fp8MoeBackend.MARLIN:
+                self.kernel = mk.FusedMoEModularKernel(
+                    MoEPrepareAndFinalizeNoEP(),
+                    MarlinExperts(quant_config=self.moe_quant_config),
+                )
+            else:
+                self.kernel = mk.FusedMoEModularKernel(
+                    MoEPrepareAndFinalizeNoEP(),
+                    TritonOrDeepGemmExperts(
+                        quant_config=self.moe_quant_config,
+                        allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
+                    ),
+                )
             self.use_inplace = True

     def maybe_make_prepare_finalize(
@@ -1128,7 +1139,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
-            self.rocm_aiter_moe_enabled
+            self.fp8_backend == Fp8MoeBackend.AITER
             or self.fp8_backend == Fp8MoeBackend.MARLIN
             or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
@@ -1161,11 +1172,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             TritonOrDeepGemmExperts,
         )

-        assert (
-            self.fp8_backend != Fp8MoeBackend.MARLIN
-        ) and not self.rocm_aiter_moe_enabled, (
-            "Marlin and ROCm AITER are not supported with all2all yet."
-        )
+        if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
+            raise NotImplementedError(
+                "Marlin and ROCm AITER are not supported with all2all yet."
+            )

         assert self.moe_quant_config is not None

@@ -1313,37 +1323,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             hidden_states=x,
             router_logits=router_logits,
         )

-        if self.rocm_aiter_moe_enabled:
-            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
-                rocm_aiter_fused_experts,
-            )
-
-            # TODO(rob): convert this to MK.
-            result = rocm_aiter_fused_experts(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                activation=layer.activation,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-                expert_map=layer.expert_map,
-                quant_config=self.moe_quant_config,
-            )
-        else:
-            result = self.kernel(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                inplace=self.use_inplace,
-                activation=layer.activation,
-                global_num_experts=layer.global_num_experts,
-                expert_map=layer.expert_map,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-            )
+        result = self.kernel(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            inplace=self.use_inplace,
+            activation=layer.activation,
+            global_num_experts=layer.global_num_experts,
+            expert_map=layer.expert_map,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
+        )

         return result

@@ -1456,15 +1447,10 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
             layer.w13_input_scale = None
             layer.w2_input_scale = None

-        self.rocm_aiter_moe_enabled = False
-
     def process_weights_after_loading(self, layer: Module) -> None:
         if getattr(layer, "_already_called_process_weights_after_loading", False):
             return

-        # Lazy import to avoid importing triton too early.
-        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-
         # If checkpoint is fp16, quantize in place.
         fp8_dtype = current_platform.fp8_dtype()
         w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
@@ -1481,7 +1467,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
         replace_parameter(layer, "w2_weight", w2_weight)

         # Reshuffle weights for AITER if needed.
-        if self.rocm_aiter_moe_enabled:
+        if self.fp8_backend == Fp8MoeBackend.AITER:
             shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                 layer.w13_weight, layer.w2_weight
             )
@@ -1489,7 +1475,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
             replace_parameter(layer, "w2_weight", shuffled_w2)

         # Rushuffle weights for MARLIN if needed.
-        if self.fp8_backend == Fp8MoeBackend.MARLIN:
+        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
             prepare_moe_fp8_layer_for_marlin(
                 layer, False, input_dtype=self.marlin_input_dtype
             )