diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index 62203186510ce..c9ab24fd439c5 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.fp8 import (
     Fp8Config,
     Fp8KVCacheMethod,
     Fp8LinearMethod,
+    Fp8MoeBackend,
     Fp8MoEMethod,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -324,7 +325,10 @@ def test_fp8_reloading(
         weight_loader=default_weight_loader,
     )
 
+    # Fp8LinearMethod uses use_marlin
+    # Fp8MoEMethod uses fp8_backend
     method.use_marlin = use_marlin
+    method.fp8_backend = Fp8MoeBackend.MARLIN if use_marlin else None
 
     # capture weights format during loading
     original_metadata = [
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index a9a2990ca2b53..d581e91f36d03 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
     OCP_MX_Scheme,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 from vllm.utils.import_utils import has_triton_kernels
 from vllm.utils.math_utils import cdiv
@@ -39,6 +40,7 @@ if has_triton_kernels():
 def _get_config_dtype_str(
     dtype: torch.dtype,
     use_fp8_w8a8: bool = False,
+    use_fp8_w8a16: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
     ocp_mx_scheme: str | None = None,
@@ -50,6 +52,8 @@ def _get_config_dtype_str(
     """
     if use_fp8_w8a8:
         return "fp8_w8a8"
+    elif use_fp8_w8a16:
+        return "fp8_w8a16"
     elif use_int8_w8a16:
         return "int8_w8a16"
     elif use_int4_w4a16:
@@ -319,6 +323,10 @@ class FusedMoEQuantConfig:
     def use_int8_w8a16(self) -> bool:
         return self._a1.dtype is None and self._w1.dtype == torch.int8
 
+    @property
+    def use_fp8_w8a16(self) -> bool:
+        return self._a1.dtype is None and self._w1.dtype == current_platform.fp8_dtype()
+
     @property
     def use_int4_w4a16(self) -> bool:
         return self._a1.dtype is None and self._w1.dtype == "int4"
@@ -362,6 +370,7 @@ class FusedMoEQuantConfig:
         """
         return _get_config_dtype_str(
             use_fp8_w8a8=self.use_fp8_w8a8,
+            use_fp8_w8a16=self.use_fp8_w8a16,
             use_int8_w8a16=self.use_int8_w8a16,
             use_int4_w4a16=self.use_int4_w4a16,
             ocp_mx_scheme=self.ocp_mx_scheme,
@@ -680,7 +689,6 @@ def int4_w4a16_moe_quant_config(
 ) -> FusedMoEQuantConfig:
     """
     Construct a quant config for 16-bit float activations and int4 weights.
-    Note: Activations are pre-quantized.
     """
     group_shape = GroupShape(*block_shape) if block_shape is not None else None
     return FusedMoEQuantConfig(
@@ -691,6 +699,27 @@ def int4_w4a16_moe_quant_config(
     )
 
 
+def fp8_w8a16_moe_quant_config(
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    block_shape: list[int] | None = None,
+) -> FusedMoEQuantConfig:
+    """
+    Construct a quant config for 16-bit float activations and fp8 weights.
+    """
+    group_shape = GroupShape(*block_shape) if block_shape is not None else None
+    return FusedMoEQuantConfig(
+        _a1=FusedMoEQuantDesc(),
+        _a2=FusedMoEQuantDesc(),
+        _w1=FusedMoEQuantDesc(
+            current_platform.fp8_dtype(), group_shape, w1_scale, None, None
+        ),
+        _w2=FusedMoEQuantDesc(
+            current_platform.fp8_dtype(), group_shape, w2_scale, None, None
+        ),
+    )
+
+
 def int8_w8a16_moe_quant_config(
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
@@ -700,7 +729,6 @@ def int8_w8a16_moe_quant_config(
 ) -> FusedMoEQuantConfig:
     """
     Construct a quant config for 16-bit float activations and int8 weights.
-    Note: Activations are pre-quantized.
     """
     group_shape = GroupShape(*block_shape) if block_shape is not None else None
     return FusedMoEQuantConfig(
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 92d72b75656cd..295a2a28156ed 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     batched_moe_align_block_size,
     moe_align_block_size,
 )
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate,
     TopKWeightAndReduceNoOP,
@@ -26,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_moe_intermediate_size,
     marlin_quant_input,
 )
+from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
 
@@ -542,9 +540,11 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
         is_k_full: bool = True,
     ):
         # TODO (varun) : Enable activation quantization
-        assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
-            "Supports only mxfp4_w4a16 or int4_w4a16"
-        )
+        assert (
+            quant_config.use_mxfp4_w4a16
+            or quant_config.use_int4_w4a16
+            or quant_config.use_fp8_w8a16
+        ), "Supports only mxfp4_w4a16, int4_w4a16 or fp8_w8a16"
         self.w13_g_idx = w13_g_idx
         self.w2_g_idx = w2_g_idx
         self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
@@ -555,11 +555,17 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
     @property
     def quant_type_id(self) -> int:
         # uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
-        return (
-            scalar_types.uint4b8.id
-            if self.quant_config.use_int4_w4a16
-            else scalar_types.float4_e2m1f.id
-        )
+        if self.quant_config.use_int4_w4a16:
+            return scalar_types.uint4b8.id
+        elif self.quant_config.use_mxfp4_w4a16:
+            return scalar_types.float4_e2m1f.id
+        elif (
+            self.quant_config.use_fp8_w8a16
+            and current_platform.fp8_dtype() == torch.float8_e4m3fn
+        ):
+            return scalar_types.float8_e4m3fn.id
+        else:
+            raise NotImplementedError("Unsupported quantization type.")
 
     def moe_problem_size(
         self,
@@ -711,16 +717,6 @@ class MarlinExperts(MarlinExpertsBase):
             ops.moe_sum(input, output)
 
 
-def modular_marlin_fused_moe(
-    quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
-) -> mk.FusedMoEModularKernel:
-    return mk.FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(),
-        MarlinExperts(quant_config),
-        shared_experts,
-    )
-
-
 class BatchedMarlinExperts(MarlinExpertsBase):
     def __init__(
         self,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 4b2438133dd6b..1f770d6d89a13 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -32,8 +32,8 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     RoutingMethodType,
     fp8_w8a8_moe_quant_config,
+    fp8_w8a16_moe_quant_config,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.linear import (
     LinearBase,
@@ -95,7 +95,6 @@ from vllm.model_executor.parameter import (
 )
 from vllm.model_executor.utils import replace_parameter, set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
@@ -729,7 +728,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
 
         self.marlin_input_dtype = None
-        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
         self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
         if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
             self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
@@ -1048,7 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
             layer.w13_weight.data = w13_weight.data
 
-        if self.use_marlin:
+        if self.fp8_backend == Fp8MoeBackend.MARLIN:
             prepare_moe_fp8_layer_for_marlin(
                 layer, False, input_dtype=self.marlin_input_dtype
             )
@@ -1091,10 +1089,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             self.use_inplace = False
 
-        elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]:
+        elif self.fp8_backend in [
+            Fp8MoeBackend.DEEPGEMM,
+            Fp8MoeBackend.TRITON,
+            Fp8MoeBackend.MARLIN,
+        ]:
             from vllm.model_executor.layers.fused_moe import (
                 TritonOrDeepGemmExperts,
             )
+            from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+                MarlinExperts,
+            )
             from vllm.model_executor.layers.fused_moe.prepare_finalize import (
                 MoEPrepareAndFinalizeNoEP,
             )
@@ -1102,12 +1107,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             config = self.get_fused_moe_quant_config(layer)
             assert config is not None
             self.moe_quant_config = config
-            self.kernel = mk.FusedMoEModularKernel(
-                MoEPrepareAndFinalizeNoEP(),
-                TritonOrDeepGemmExperts(
+            use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
+            allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
+            moe_kernel = (
+                MarlinExperts(quant_config=self.moe_quant_config)
+                if use_marlin
+                else TritonOrDeepGemmExperts(
                     quant_config=self.moe_quant_config,
-                    allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
-                ),
+                    allow_deep_gemm=allow_deep_gemm,
+                )
+            )
+
+            self.kernel = mk.FusedMoEModularKernel(
+                MoEPrepareAndFinalizeNoEP(), moe_kernel
             )
             self.use_inplace = True
 
@@ -1116,9 +1128,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
-            current_platform.is_xpu()
-            or self.rocm_aiter_moe_enabled
-            or self.use_marlin
+            self.rocm_aiter_moe_enabled
+            or self.fp8_backend == Fp8MoeBackend.MARLIN
             or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
             return None
@@ -1150,7 +1161,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             TritonOrDeepGemmExperts,
         )
 
-        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
+        assert (
+            self.fp8_backend != Fp8MoeBackend.MARLIN
+        ) and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet."
         )
 
@@ -1207,8 +1220,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        if self.use_marlin:
-            return None
+        if self.fp8_backend == Fp8MoeBackend.MARLIN:
+            return fp8_w8a16_moe_quant_config(
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                block_shape=self.weight_block_size,
+            )
 
         return fp8_w8a8_moe_quant_config(
             w1_scale=(
@@ -1314,29 +1331,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 expert_map=layer.expert_map,
                 quant_config=self.moe_quant_config,
             )
-        elif self.use_marlin:
-            # TODO(rob): convert this to MK.
-            assert layer.activation == "silu", (
-                f"{layer.activation} not supported for Marlin MoE."
-            )
-            result = fused_marlin_moe(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                None,
-                None,
-                layer.w13_weight_scale,
-                layer.w2_weight_scale,
-                router_logits,
-                topk_weights,
-                topk_ids,
-                quant_type_id=scalar_types.float8_e4m3fn.id,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-                global_num_experts=layer.global_num_experts,
-                expert_map=layer.expert_map,
-                input_dtype=self.marlin_input_dtype,
-                workspace=layer.workspace,
-            )
         else:
             result = self.kernel(
                 x,
@@ -1495,7 +1489,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
             replace_parameter(layer, "w2_weight", shuffled_w2)
 
         # Rushuffle weights for MARLIN if needed.
-        if self.use_marlin:
+        if self.fp8_backend == Fp8MoeBackend.MARLIN:
             prepare_moe_fp8_layer_for_marlin(
                 layer, False, input_dtype=self.marlin_input_dtype
            )
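
Reviewer note (not part of the patch): a minimal sketch of how the Marlin fp8 MoE path is assembled after this change, now that fused_marlin_moe()/modular_marlin_fused_moe() are removed and MarlinExperts is driven through the modular kernel with the new fp8_w8a16_moe_quant_config. The expert count and scale tensors below are hypothetical placeholders; real scales come from fp8 weight loading in Fp8MoEMethod.

# Illustrative sketch only; shapes and values are made up.
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import fp8_w8a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)

num_experts = 8  # hypothetical expert count
# Placeholder per-expert weight scales; real values come from checkpoint
# loading or online quantization.
w13_weight_scale = torch.ones(num_experts, dtype=torch.float32)
w2_weight_scale = torch.ones(num_experts, dtype=torch.float32)

# Weight-only fp8 config: fp8 weights, unquantized 16-bit activations.
quant_config = fp8_w8a16_moe_quant_config(
    w1_scale=w13_weight_scale,
    w2_scale=w2_weight_scale,
    block_shape=None,
)
assert quant_config.use_fp8_w8a16

# Marlin experts now run through the modular kernel, mirroring what
# Fp8MoEMethod does when fp8_backend == Fp8MoeBackend.MARLIN.
kernel = mk.FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(),
    MarlinExperts(quant_config=quant_config),
)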