[MoE Refactor][4/N] Marlin Fp8 Mk (#31036)

Robert Shaw 2025-12-21 12:37:42 -05:00 committed by GitHub
parent 93cabc417c
commit b471092d3a
4 changed files with 85 additions and 63 deletions


@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
Fp8KVCacheMethod,
Fp8LinearMethod,
Fp8MoeBackend,
Fp8MoEMethod,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -324,7 +325,10 @@ def test_fp8_reloading(
weight_loader=default_weight_loader,
)
# Fp8LinearMethod uses use_marlin
# Fp8MoEMethod uses fp8_backend
method.use_marlin = use_marlin
method.fp8_backend = Fp8MoeBackend.MARLIN if use_marlin else None
# capture weights format during loading
original_metadata = [

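As a sketch of what the test override above amounts to (the `SimpleNamespace` stand-in and the loop are illustrative, not part of the real test): `Fp8LinearMethod` keeps its boolean `use_marlin` switch, while `Fp8MoEMethod` is now driven by the `Fp8MoeBackend` enum.

```python
from types import SimpleNamespace

from vllm.model_executor.layers.quantization.fp8 import Fp8MoeBackend

# Stand-in for the quant-method object the real test patches; only the two
# attributes touched by test_fp8_reloading are modeled here.
method = SimpleNamespace(use_marlin=False, fp8_backend=None)

for use_marlin in (False, True):
    # Fp8LinearMethod still keys off a boolean flag ...
    method.use_marlin = use_marlin
    # ... while Fp8MoEMethod now keys off the Fp8MoeBackend enum.
    method.fp8_backend = Fp8MoeBackend.MARLIN if use_marlin else None
```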

@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_Scheme,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import cdiv
@@ -39,6 +40,7 @@ if has_triton_kernels():
def _get_config_dtype_str(
dtype: torch.dtype,
use_fp8_w8a8: bool = False,
use_fp8_w8a16: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
@@ -50,6 +52,8 @@ def _get_config_dtype_str(
"""
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_fp8_w8a16:
return "fp8_w8a16"
elif use_int8_w8a16:
return "int8_w8a16"
elif use_int4_w4a16:
@@ -319,6 +323,10 @@ class FusedMoEQuantConfig:
def use_int8_w8a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == torch.int8
@property
def use_fp8_w8a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == current_platform.fp8_dtype()
@property
def use_int4_w4a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == "int4"
@@ -362,6 +370,7 @@ class FusedMoEQuantConfig:
"""
return _get_config_dtype_str(
use_fp8_w8a8=self.use_fp8_w8a8,
use_fp8_w8a16=self.use_fp8_w8a16,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
ocp_mx_scheme=self.ocp_mx_scheme,
@@ -680,7 +689,6 @@ def int4_w4a16_moe_quant_config(
) -> FusedMoEQuantConfig:
"""
Construct a quant config for 16-bit float activations and int4 weights.
Note: Activations are pre-quantized.
"""
group_shape = GroupShape(*block_shape) if block_shape is not None else None
return FusedMoEQuantConfig(
@@ -691,6 +699,27 @@
)
def fp8_w8a16_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for 16-bit float activations and fp8 weights.
"""
group_shape = GroupShape(*block_shape) if block_shape is not None else None
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(),
_a2=FusedMoEQuantDesc(),
_w1=FusedMoEQuantDesc(
current_platform.fp8_dtype(), group_shape, w1_scale, None, None
),
_w2=FusedMoEQuantDesc(
current_platform.fp8_dtype(), group_shape, w2_scale, None, None
),
)
def int8_w8a16_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
@@ -700,7 +729,6 @@ def int8_w8a16_moe_quant_config(
) -> FusedMoEQuantConfig:
"""
Construct a quant config for 16-bit float activations and int8 weights.
Note: Activations are pre-quantized.
"""
group_shape = GroupShape(*block_shape) if block_shape is not None else None
return FusedMoEQuantConfig(

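A minimal usage sketch for the new `fp8_w8a16_moe_quant_config` helper, assuming a vLLM build that includes this change; the scale tensors below are placeholders rather than real per-expert scales.

```python
import torch

from vllm.model_executor.layers.fused_moe.config import fp8_w8a16_moe_quant_config

# Placeholder weight scales for an 8-expert layer; a real layer passes its
# w13/w2 scale tensors and, for blockwise quantization, a block_shape.
w1_scale = torch.ones(8, 1, 1, dtype=torch.float32)
w2_scale = torch.ones(8, 1, 1, dtype=torch.float32)

cfg = fp8_w8a16_moe_quant_config(w1_scale=w1_scale, w2_scale=w2_scale)

# Activations stay in 16-bit float (no a1/a2 dtype) while weights are fp8,
# so only the new flag is set; _get_config_dtype_str maps it to "fp8_w8a16".
assert cfg.use_fp8_w8a16
assert not cfg.use_fp8_w8a8 and not cfg.use_int8_w8a16
```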

@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size,
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP,
@@ -26,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_moe_intermediate_size,
marlin_quant_input,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
@@ -542,9 +540,11 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
is_k_full: bool = True,
):
# TODO (varun) : Enable activation quantization
assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
"Supports only mxfp4_w4a16 or int4_w4a16"
)
assert (
quant_config.use_mxfp4_w4a16
or quant_config.use_int4_w4a16
or quant_config.use_fp8_w8a16
), "Supports only mxfp4_w4a16, int4_w4a16 or fp8_w8a16"
self.w13_g_idx = w13_g_idx
self.w2_g_idx = w2_g_idx
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
@@ -555,11 +555,17 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
@property
def quant_type_id(self) -> int:
# uint4b8 is used for int4 weights, float4_e2m1f for mxfp4, and float8_e4m3fn for fp8
return (
scalar_types.uint4b8.id
if self.quant_config.use_int4_w4a16
else scalar_types.float4_e2m1f.id
)
if self.quant_config.use_int4_w4a16:
return scalar_types.uint4b8.id
elif self.quant_config.use_mxfp4_w4a16:
return scalar_types.float4_e2m1f.id
elif (
self.quant_config.use_fp8_w8a16
and current_platform.fp8_dtype() == torch.float8_e4m3fn
):
return scalar_types.float8_e4m3fn.id
else:
raise NotImplementedError("Unsupported quantization type.")
def moe_problem_size(
self,
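For reference, the `quant_type_id` change above reduces to the mapping sketched below (the helper name `marlin_quant_type_id` is mine, not vLLM's; the platform check mirrors the guard in the hunk and excludes the float8 fnuz variant used on some ROCm GPUs).

```python
import torch

from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types


def marlin_quant_type_id(
    use_int4_w4a16: bool, use_mxfp4_w4a16: bool, use_fp8_w8a16: bool
) -> int:
    # Same precedence as MarlinExpertsBase.quant_type_id: int4, then mxfp4,
    # then fp8 weight-only, which is only reported on e4m3fn platforms.
    if use_int4_w4a16:
        return scalar_types.uint4b8.id
    if use_mxfp4_w4a16:
        return scalar_types.float4_e2m1f.id
    if use_fp8_w8a16 and current_platform.fp8_dtype() == torch.float8_e4m3fn:
        return scalar_types.float8_e4m3fn.id
    raise NotImplementedError("Unsupported quantization type.")
```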
@@ -711,16 +717,6 @@ class MarlinExperts(MarlinExpertsBase):
ops.moe_sum(input, output)
def modular_marlin_fused_moe(
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config),
shared_experts,
)
class BatchedMarlinExperts(MarlinExpertsBase):
def __init__(
self,


@@ -32,8 +32,8 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
LinearBase,
@@ -95,7 +95,6 @@ from vllm.model_executor.parameter import (
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
@@ -729,7 +728,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
self.marlin_input_dtype = None
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
@@ -1048,7 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
layer.w13_weight.data = w13_weight.data
if self.use_marlin:
if self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
@@ -1091,10 +1089,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
self.use_inplace = False
elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]:
elif self.fp8_backend in [
Fp8MoeBackend.DEEPGEMM,
Fp8MoeBackend.TRITON,
Fp8MoeBackend.MARLIN,
]:
from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
@@ -1102,12 +1107,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
config = self.get_fused_moe_quant_config(layer)
assert config is not None
self.moe_quant_config = config
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrDeepGemmExperts(
use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
moe_kernel = (
MarlinExperts(quant_config=self.moe_quant_config)
if use_marlin
else TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
),
allow_deep_gemm=allow_deep_gemm,
)
)
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), moe_kernel
)
self.use_inplace = True
@@ -1116,9 +1128,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if (
current_platform.is_xpu()
or self.rocm_aiter_moe_enabled
or self.use_marlin
self.rocm_aiter_moe_enabled
or self.fp8_backend == Fp8MoeBackend.MARLIN
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return None
@@ -1150,7 +1161,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
TritonOrDeepGemmExperts,
)
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
assert (
self.fp8_backend != Fp8MoeBackend.MARLIN
) and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet."
)
@@ -1207,8 +1220,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if self.use_marlin:
return None
if self.fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
block_shape=self.weight_block_size,
)
return fp8_w8a8_moe_quant_config(
w1_scale=(
@@ -1314,29 +1331,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
elif self.use_marlin:
# TODO(rob): convert this to MK.
assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
result = fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
else:
result = self.kernel(
x,
@@ -1495,7 +1489,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
replace_parameter(layer, "w2_weight", shuffled_w2)
# Reshuffle weights for MARLIN if needed.
if self.use_marlin:
if self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
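Taken together, these hunks fold Marlin into the same modular-kernel construction used by the Triton and DeepGEMM backends, removing the separate `fused_marlin_moe` branch in `apply`. The condensed paraphrase below is a sketch, not the literal `Fp8MoEMethod` code: the function name `build_fp8_moe_kernel` is mine, the `modular_kernel` import path is inferred from the `mk` alias used in these files, and FlashInfer / ROCm AITER / all2all handling is omitted.

```python
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import TritonOrDeepGemmExperts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.fp8 import Fp8MoeBackend


def build_fp8_moe_kernel(
    fp8_backend: Fp8MoeBackend, quant_config: FusedMoEQuantConfig
) -> mk.FusedMoEModularKernel:
    """Condensed paraphrase of the kernel selection added to Fp8MoEMethod.

    MARLIN expects a weight-only fp8 (w8a16) quant_config from
    fp8_w8a16_moe_quant_config; the other backends expect a w8a8 config.
    """
    use_marlin = fp8_backend == Fp8MoeBackend.MARLIN
    allow_deep_gemm = fp8_backend == Fp8MoeBackend.DEEPGEMM
    experts = (
        MarlinExperts(quant_config=quant_config)
        if use_marlin
        else TritonOrDeepGemmExperts(
            quant_config=quant_config, allow_deep_gemm=allow_deep_gemm
        )
    )
    return mk.FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), experts)
```

Routing Marlin through the modular kernel this way also explains why the removed `modular_marlin_fused_moe` helper and the direct `fused_marlin_moe` call are no longer needed: the Marlin experts now reuse the shared prepare/finalize plumbing like every other fp8 MoE backend.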