[ROCM] MoE fp4 CK kernel (#26545)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>

parent 99722d5f0e
commit 0925b28a8e
@@ -46,6 +46,11 @@ def is_rocm_aiter_moe_enabled() -> bool:
     )
 
 
+@cache
+def use_mxfp4_aiter_moe() -> bool:
+    return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
+
+
 @cache
 def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
     return (
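For reference, a minimal standalone sketch of the gate this hunk introduces. It mirrors use_mxfp4_aiter_moe() but reads the environment variable directly instead of going through vllm.envs and current_platform, so it runs outside vLLM; the is_rocm argument and the "0" default for VLLM_ROCM_USE_AITER are assumptions made for the sketch.

import os
from functools import cache


@cache
def use_mxfp4_aiter_moe_sketch(is_rocm: bool = True) -> bool:
    # Stand-in for the new helper above: True only on ROCm with the aiter
    # flag enabled. vLLM itself reads this flag via vllm.envs.
    return is_rocm and os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"


os.environ["VLLM_ROCM_USE_AITER"] = "1"
print(use_mxfp4_aiter_moe_sketch())  # True when treated as a ROCm platform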
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled,
+    use_mxfp4_aiter_moe,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_moe_fp8_layer_for_marlin,
@@ -472,22 +473,22 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
                 "not implemented. Please open an issue."
             )
 
-        if not current_platform.supports_mx():
-            self.emulate = True
+        self.emulate = not current_platform.supports_mx() or not (
+            use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
+        )
+        if self.emulate:
             logger.warning_once(
-                "The current platform does not support native MXFP4/MXFP6 "
+                f"The current mode (supports_mx={current_platform.supports_mx()}, "
+                f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
+                f"ocp_mx_scheme={self.ocp_mx_scheme}) "
+                "does not support native MXFP4/MXFP6 "
                 "computation. Simulated weight dequantization and activation "
                 "QDQ (quantize and dequantize) will be used, with the linear "
                 "layers computed in high precision."
             )
         else:
-            self.emulate = True
             logger.warning_once(
-                "The current platform supports native MXFP4/MXFP6 "
-                "computation, but kernels are not yet integrated in vLLM. "
-                "Simulated weight dequantization and activation "
-                "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision."
+                "The current mode supports native MoE MXFP4 computation"
             )
 
     def get_packed_dim(self, dim: int, quant_dtype: str):
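The key behavioral change in this hunk is that self.emulate is no longer unconditionally True: the native path is taken only when the platform supports MX formats, the aiter MoE flag is on, and the scheme is exactly "w_mxfp4_a_mxfp4". A small sketch of that decision in isolation (the function name is illustrative, not part of the diff):

def should_emulate(supports_mx: bool, mxfp4_aiter_moe: bool, ocp_mx_scheme: str) -> bool:
    # Same boolean expression as the new self.emulate assignment above.
    return not supports_mx or not (mxfp4_aiter_moe and ocp_mx_scheme == "w_mxfp4_a_mxfp4")


assert should_emulate(False, True, "w_mxfp4_a_mxfp4")      # no MX hardware -> emulate
assert should_emulate(True, False, "w_mxfp4_a_mxfp4")      # aiter MoE disabled -> emulate
assert should_emulate(True, True, "w_mxfp6_a_mxfp6")       # non-fp4 scheme -> emulate
assert not should_emulate(True, True, "w_mxfp4_a_mxfp4")   # native CK fp4 path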
@@ -568,6 +569,24 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
+    def process_weights_after_loading(self, layer):
+        if self.emulate:
+            return
+
+        from aiter.utility.fp4_utils import e8m0_shuffle
+
+        # Pre-shuffle weight scales
+        s0, s1, _ = layer.w13_weight_scale.shape
+        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
+        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
+        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
+
+        s0, s1, _ = layer.w2_weight_scale.shape
+        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
+        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
+        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
+        torch.cuda.empty_cache()
+
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
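The pre-shuffle flattens each scale tensor from (num_experts, rows, cols) to 2D, runs aiter's e8m0_shuffle, and views the result back to 3D. Below is a sketch of that reshape round trip with a no-op stand-in for e8m0_shuffle so it runs without aiter installed; the stub and the shapes are assumptions for illustration only.

import torch


def e8m0_shuffle_stub(scales_2d: torch.Tensor) -> torch.Tensor:
    # Placeholder for aiter.utility.fp4_utils.e8m0_shuffle: the real call
    # reorders the e8m0 scales into the layout the CK kernel expects while
    # keeping the element count unchanged.
    return scales_2d.contiguous()


w13_weight_scale = torch.randint(0, 255, (8, 256, 32), dtype=torch.uint8)  # (experts, rows, cols)
s0, s1, _ = w13_weight_scale.shape
shuffled = e8m0_shuffle_stub(w13_weight_scale.view(s0 * s1, -1))
w13_weight_scale = shuffled.view(s0, s1, -1)
assert w13_weight_scale.shape == (s0, s1, 32)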
@@ -611,8 +630,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
                 "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
             )
 
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -628,17 +645,44 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
             indices_type=self.topk_indices_dtype,
         )
 
-        out = fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
-            global_num_experts=global_num_experts,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
-            quant_config=self.moe_quant_config,
-        )
+        if not self.emulate:
+            from aiter import ActivationType, QuantType
+            from aiter.fused_moe import fused_moe
+
+            aiter_acts = {
+                ActivationType.No.name.lower(): ActivationType.No,
+                ActivationType.Silu.name.lower(): ActivationType.Silu,
+                ActivationType.Gelu.name.lower(): ActivationType.Gelu,
+            }
+            assert activation in aiter_acts, (
+                f"Aiter CK fp4 MoE doesn't support activation {activation}"
+            )
+            out = fused_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                quant_type=QuantType.per_1x32,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                activation=aiter_acts[activation],
+                doweight_stage1=False,
+            )
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+
+            out = fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                expert_map=expert_map,
+                quant_config=self.moe_quant_config,
+            )
         return out
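The non-emulated branch maps vLLM's lowercase activation string onto aiter's ActivationType enum via name.lower(). A sketch of that mapping with a stand-in enum so it runs without aiter; the member values here are assumptions, only the names matter for the lookup.

from enum import Enum


class ActivationType(Enum):  # stand-in for aiter.ActivationType
    No = 0
    Silu = 1
    Gelu = 2


aiter_acts = {act.name.lower(): act for act in ActivationType}
assert aiter_acts["silu"] is ActivationType.Silu
assert "relu" not in aiter_acts  # would trip the assert guarding the CK fp4 path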