[MoE Refactor][4/N] Marlin Fp8 Mk (#31036)
parent 93cabc417c
commit b471092d3a
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.fp8 import (
     Fp8Config,
     Fp8KVCacheMethod,
     Fp8LinearMethod,
+    Fp8MoeBackend,
     Fp8MoEMethod,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -324,7 +325,10 @@ def test_fp8_reloading(
         weight_loader=default_weight_loader,
     )

+    # Fp8LinearMethod uses use_marlin
+    # Fp8MoEMethod uses fp8_backend
     method.use_marlin = use_marlin
+    method.fp8_backend = Fp8MoeBackend.MARLIN if use_marlin else None

     # capture weights format during loading
     original_metadata = [
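Note: the hunk above reflects the split introduced by this PR: Fp8LinearMethod still gates Marlin with a boolean, while Fp8MoEMethod now selects it via the Fp8MoeBackend enum. A minimal sketch of how a test can force the Marlin path for either method type (force_marlin is a hypothetical helper, not part of this commit):

from vllm.model_executor.layers.quantization.fp8 import Fp8MoeBackend


def force_marlin(method, use_marlin: bool) -> None:
    # Fp8LinearMethod keeps the boolean use_marlin flag.
    if hasattr(method, "use_marlin"):
        method.use_marlin = use_marlin
    # Fp8MoEMethod now keys kernel selection off the Fp8MoeBackend enum.
    if hasattr(method, "fp8_backend"):
        method.fp8_backend = Fp8MoeBackend.MARLIN if use_marlin else None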
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
     OCP_MX_Scheme,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 from vllm.utils.import_utils import has_triton_kernels
 from vllm.utils.math_utils import cdiv
@@ -39,6 +40,7 @@ if has_triton_kernels():
 def _get_config_dtype_str(
     dtype: torch.dtype,
     use_fp8_w8a8: bool = False,
+    use_fp8_w8a16: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
     ocp_mx_scheme: str | None = None,
@@ -50,6 +52,8 @@ def _get_config_dtype_str(
     """
     if use_fp8_w8a8:
         return "fp8_w8a8"
+    elif use_fp8_w8a16:
+        return "fp8_w8a16"
     elif use_int8_w8a16:
         return "int8_w8a16"
     elif use_int4_w4a16:
@@ -319,6 +323,10 @@ class FusedMoEQuantConfig:
     def use_int8_w8a16(self) -> bool:
         return self._a1.dtype is None and self._w1.dtype == torch.int8

+    @property
+    def use_fp8_w8a16(self) -> bool:
+        return self._a1.dtype is None and self._w1.dtype == current_platform.fp8_dtype()
+
     @property
     def use_int4_w4a16(self) -> bool:
         return self._a1.dtype is None and self._w1.dtype == "int4"
@@ -362,6 +370,7 @@ class FusedMoEQuantConfig:
         """
         return _get_config_dtype_str(
             use_fp8_w8a8=self.use_fp8_w8a8,
+            use_fp8_w8a16=self.use_fp8_w8a16,
             use_int8_w8a16=self.use_int8_w8a16,
             use_int4_w4a16=self.use_int4_w4a16,
             ocp_mx_scheme=self.ocp_mx_scheme,
@@ -680,7 +689,6 @@ def int4_w4a16_moe_quant_config(
 ) -> FusedMoEQuantConfig:
     """
     Construct a quant config for 16-bit float activations and int4 weights.
-    Note: Activations are pre-quantized.
     """
     group_shape = GroupShape(*block_shape) if block_shape is not None else None
     return FusedMoEQuantConfig(
@@ -691,6 +699,27 @@ def int4_w4a16_moe_quant_config(
     )


+def fp8_w8a16_moe_quant_config(
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    block_shape: list[int] | None = None,
+) -> FusedMoEQuantConfig:
+    """
+    Construct a quant config for 16-bit float activations and fp8 weights.
+    """
+    group_shape = GroupShape(*block_shape) if block_shape is not None else None
+    return FusedMoEQuantConfig(
+        _a1=FusedMoEQuantDesc(),
+        _a2=FusedMoEQuantDesc(),
+        _w1=FusedMoEQuantDesc(
+            current_platform.fp8_dtype(), group_shape, w1_scale, None, None
+        ),
+        _w2=FusedMoEQuantDesc(
+            current_platform.fp8_dtype(), group_shape, w2_scale, None, None
+        ),
+    )
+
+
 def int8_w8a16_moe_quant_config(
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
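A minimal usage sketch of the new fp8_w8a16_moe_quant_config helper (the scale tensors and their shapes below are placeholders; real scales come from the loaded checkpoint):

import torch

from vllm.model_executor.layers.fused_moe.config import fp8_w8a16_moe_quant_config

# Placeholder per-expert weight scales; shapes are illustrative only.
w13_scale = torch.ones(8, 1, dtype=torch.float32)
w2_scale = torch.ones(8, 1, dtype=torch.float32)

cfg = fp8_w8a16_moe_quant_config(
    w1_scale=w13_scale,
    w2_scale=w2_scale,
    block_shape=None,
)
# Weight-only fp8: weights carry the platform fp8 dtype, activations stay 16-bit.
assert cfg.use_fp8_w8a16 and not cfg.use_fp8_w8a8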
@@ -700,7 +729,6 @@ def int8_w8a16_moe_quant_config(
 ) -> FusedMoEQuantConfig:
     """
     Construct a quant config for 16-bit float activations and int8 weights.
-    Note: Activations are pre-quantized.
     """
     group_shape = GroupShape(*block_shape) if block_shape is not None else None
     return FusedMoEQuantConfig(
@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     batched_moe_align_block_size,
     moe_align_block_size,
 )
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate,
     TopKWeightAndReduceNoOP,
@@ -26,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_moe_intermediate_size,
     marlin_quant_input,
 )
+from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types


@@ -542,9 +540,11 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
         is_k_full: bool = True,
     ):
         # TODO (varun) : Enable activation quantization
-        assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
-            "Supports only mxfp4_w4a16 or int4_w4a16"
-        )
+        assert (
+            quant_config.use_mxfp4_w4a16
+            or quant_config.use_int4_w4a16
+            or quant_config.use_fp8_w8a16
+        ), "Supports only mxfp4_w4a16, int4_w4a16 or fp8_w8a16"
         self.w13_g_idx = w13_g_idx
         self.w2_g_idx = w2_g_idx
         self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
@@ -555,11 +555,17 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
     @property
     def quant_type_id(self) -> int:
         # uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
-        return (
-            scalar_types.uint4b8.id
-            if self.quant_config.use_int4_w4a16
-            else scalar_types.float4_e2m1f.id
-        )
+        if self.quant_config.use_int4_w4a16:
+            return scalar_types.uint4b8.id
+        elif self.quant_config.use_mxfp4_w4a16:
+            return scalar_types.float4_e2m1f.id
+        elif (
+            self.quant_config.use_fp8_w8a16
+            and current_platform.fp8_dtype() == torch.float8_e4m3fn
+        ):
+            return scalar_types.float8_e4m3fn.id
+        else:
+            raise NotImplementedError("Unsupported quantization type.")

     def moe_problem_size(
         self,
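For reference, the three weight formats the property now distinguishes, using the same scalar types referenced in the hunk (the dictionary below is illustrative, not part of the commit):

from vllm.scalar_type import scalar_types

MARLIN_WEIGHT_TYPES = {
    "int4_w4a16": scalar_types.uint4b8.id,        # int4 weights, 16-bit activations
    "mxfp4_w4a16": scalar_types.float4_e2m1f.id,  # mxfp4 weights, 16-bit activations
    "fp8_w8a16": scalar_types.float8_e4m3fn.id,   # fp8 weights, e4m3fn platforms only
}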
@@ -711,16 +717,6 @@ class MarlinExperts(MarlinExpertsBase):
         ops.moe_sum(input, output)


-def modular_marlin_fused_moe(
-    quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
-) -> mk.FusedMoEModularKernel:
-    return mk.FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(),
-        MarlinExperts(quant_config),
-        shared_experts,
-    )
-
-
 class BatchedMarlinExperts(MarlinExpertsBase):
     def __init__(
         self,
@@ -32,8 +32,8 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     RoutingMethodType,
     fp8_w8a8_moe_quant_config,
+    fp8_w8a16_moe_quant_config,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.linear import (
     LinearBase,
@@ -95,7 +95,6 @@ from vllm.model_executor.parameter import (
 )
 from vllm.model_executor.utils import replace_parameter, set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
@@ -729,7 +728,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )

         self.marlin_input_dtype = None
-        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
         self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
         if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
             self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
@@ -1048,7 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
             layer.w13_weight.data = w13_weight.data

-        if self.use_marlin:
+        if self.fp8_backend == Fp8MoeBackend.MARLIN:
             prepare_moe_fp8_layer_for_marlin(
                 layer, False, input_dtype=self.marlin_input_dtype
             )
@@ -1091,10 +1089,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             self.use_inplace = False

-        elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]:
+        elif self.fp8_backend in [
+            Fp8MoeBackend.DEEPGEMM,
+            Fp8MoeBackend.TRITON,
+            Fp8MoeBackend.MARLIN,
+        ]:
             from vllm.model_executor.layers.fused_moe import (
                 TritonOrDeepGemmExperts,
             )
+            from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+                MarlinExperts,
+            )
             from vllm.model_executor.layers.fused_moe.prepare_finalize import (
                 MoEPrepareAndFinalizeNoEP,
             )
@@ -1102,12 +1107,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             config = self.get_fused_moe_quant_config(layer)
             assert config is not None
             self.moe_quant_config = config
-            self.kernel = mk.FusedMoEModularKernel(
-                MoEPrepareAndFinalizeNoEP(),
-                TritonOrDeepGemmExperts(
-                    quant_config=self.moe_quant_config,
-                    allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
-                ),
-            )
+            use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
+            allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
+            moe_kernel = (
+                MarlinExperts(quant_config=self.moe_quant_config)
+                if use_marlin
+                else TritonOrDeepGemmExperts(
+                    quant_config=self.moe_quant_config,
+                    allow_deep_gemm=allow_deep_gemm,
+                )
+            )
+
+            self.kernel = mk.FusedMoEModularKernel(
+                MoEPrepareAndFinalizeNoEP(), moe_kernel
+            )
             self.use_inplace = True

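With this change the Marlin experts run through the same modular-kernel path as Triton/DeepGEMM, replacing the removed modular_marlin_fused_moe helper. A minimal sketch of the equivalent assembly (build_marlin_kernel is hypothetical; the modular_kernel import path for mk is an assumption based on the mk alias used in the diff):

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import MoEPrepareAndFinalizeNoEP


def build_marlin_kernel(quant_config) -> mk.FusedMoEModularKernel:
    # Prepare/finalize without expert parallelism, Marlin for the expert GEMMs.
    return mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        MarlinExperts(quant_config=quant_config),
    )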
@@ -1116,9 +1128,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
-            current_platform.is_xpu()
-            or self.rocm_aiter_moe_enabled
-            or self.use_marlin
+            self.rocm_aiter_moe_enabled
+            or self.fp8_backend == Fp8MoeBackend.MARLIN
             or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
             return None
@@ -1150,7 +1161,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             TritonOrDeepGemmExperts,
         )

-        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
+        assert (
+            self.fp8_backend != Fp8MoeBackend.MARLIN
+        ) and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet."
         )

@@ -1207,8 +1220,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        if self.use_marlin:
-            return None
+        if self.fp8_backend == Fp8MoeBackend.MARLIN:
+            return fp8_w8a16_moe_quant_config(
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                block_shape=self.weight_block_size,
+            )

         return fp8_w8a8_moe_quant_config(
             w1_scale=(
@@ -1314,29 +1331,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 expert_map=layer.expert_map,
                 quant_config=self.moe_quant_config,
             )
-        elif self.use_marlin:
-            # TODO(rob): convert this to MK.
-            assert layer.activation == "silu", (
-                f"{layer.activation} not supported for Marlin MoE."
-            )
-            result = fused_marlin_moe(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                None,
-                None,
-                layer.w13_weight_scale,
-                layer.w2_weight_scale,
-                router_logits,
-                topk_weights,
-                topk_ids,
-                quant_type_id=scalar_types.float8_e4m3fn.id,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-                global_num_experts=layer.global_num_experts,
-                expert_map=layer.expert_map,
-                input_dtype=self.marlin_input_dtype,
-                workspace=layer.workspace,
-            )
         else:
             result = self.kernel(
                 x,
@@ -1495,7 +1489,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
         replace_parameter(layer, "w2_weight", shuffled_w2)

         # Reshuffle weights for MARLIN if needed.
-        if self.use_marlin:
+        if self.fp8_backend == Fp8MoeBackend.MARLIN:
             prepare_moe_fp8_layer_for_marlin(
                 layer, False, input_dtype=self.marlin_input_dtype
             )