diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 78685538ea1b3..a86fb3d309525 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from enum import Enum
-from functools import partial
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
@@ -51,7 +50,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     FlashinferMoeBackend,
     apply_flashinfer_per_tensor_scale_fp8,
     build_flashinfer_fp8_cutlass_moe_prepare_finalize,
-    flashinfer_cutlass_moe_fp8,
     get_flashinfer_moe_backend,
     register_moe_scaling_factors,
     rotate_flashinfer_fp8_moe_weights,
@@ -728,18 +726,28 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
         elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
             self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
-            if self.block_quant:
-                assert self.weight_block_size == [128, 128], (
-                    f"Only support weight_block_size == [128, 128], "
-                    f"got {self.weight_block_size}"
+            if self.block_quant and self.weight_block_size != [128, 128]:
+                raise NotImplementedError(
+                    "FlashInfer CUTLASS FP8 MoE backend only supports "
+                    f"weight_block_size == [128, 128], got {self.weight_block_size}."
+                )
+            if not self.block_quant:
+                if layer.renormalize or layer.custom_routing_function is not None:
+                    raise NotImplementedError(
+                        "FlashInfer CUTLASS FP8 MoE backend does not support "
+                        "renormalization or a custom routing function, but got "
+                        f"{layer.renormalize} and {layer.custom_routing_function}."
+                    )
+                if layer.scoring_func != "sigmoid":
+                    raise NotImplementedError(
+                        "FlashInfer CUTLASS FP8 MoE backend only supports the "
+                        f"'sigmoid' scoring function, but got {layer.scoring_func}."
+                    )
+                if layer.activation != "silu":
+                    raise NotImplementedError(
+                        "FlashInfer CUTLASS FP8 MoE backend only supports the SiLU "
+                        f"activation function, but got {layer.activation}."
                 )
-            self.flashinfer_moe_fn = partial(
-                flashinfer_cutlass_moe_fp8,
-                moe=self.moe,
-                use_deepseek_fp8_block_scale=self.block_quant,
-            )
-
-        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
 
     def create_weights(
         self,
@@ -928,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         # DeepGemm scales need to be transposed and aligned. We try to do
         # it ahead of time for performance reasons.
-        if self.allow_deep_gemm:
+        if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            dg_w13_weight, dg_w13_weight_scale_inv = (
                deepgemm_post_process_fp8_weight_block(
                    wq=layer.w13_weight.data,
@@ -1039,6 +1047,61 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale
 
+        # NOTE(rob): this is a WIP refactor. We are first migrating
+        # all of the kernels in the TP case to use mk. Once this is
+        # done, we will initialize the TP case and the DP/EP case
+        # via the same code path (i.e. via maybe_init_modular_kernel).
+        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+                FlashInferExperts,
+            )
+            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
+                FlashInferAllGatherMoEPrepareAndFinalize,
+            )
+
+            config = self.get_fused_moe_quant_config(layer)
+            assert config is not None
+            self.moe_quant_config = config
+
+            self.kernel = mk.FusedMoEModularKernel(
+                FlashInferAllGatherMoEPrepareAndFinalize(
+                    use_dp=(self.moe.dp_size > 1),
+                    use_deepseek_fp8_block_scale=self.block_quant,
+                ),
+                FlashInferExperts(
+                    out_dtype=torch.get_default_dtype(),
+                    quant_config=self.moe_quant_config,
+                    ep_rank=self.moe.ep_rank,
+                    ep_size=self.moe.ep_size,
+                    tp_rank=self.moe.tp_rank,
+                    tp_size=self.moe.tp_size,
+                    use_dp=(self.moe.dp_size > 1),
+                    use_deepseek_fp8_block_scale=self.block_quant,
+                ),
+            )
+            self.use_inplace = False
+
+        elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]:
+            from vllm.model_executor.layers.fused_moe import (
+                TritonOrDeepGemmExperts,
+            )
+            from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+                MoEPrepareAndFinalizeNoEP,
+            )
+
+            config = self.get_fused_moe_quant_config(layer)
+            assert config is not None
+            self.moe_quant_config = config
+            self.kernel = mk.FusedMoEModularKernel(
+                MoEPrepareAndFinalizeNoEP(),
+                TritonOrDeepGemmExperts(
+                    quant_config=self.moe_quant_config,
+                    allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
+                ),
+            )
+            self.use_inplace = True
+
     def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -1091,7 +1154,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert max_num_tokens_per_rank is not None
 
             experts_impl = (
-                BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
+                BatchedDeepGemmExperts
+                if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
+                else BatchedTritonExperts
             )
             logger.debug(
                 "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
@@ -1126,7 +1191,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             return TritonOrDeepGemmExperts(
                 quant_config=self.moe_quant_config,
-                allow_deep_gemm=self.allow_deep_gemm,
+                allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
             )
 
     def get_fused_moe_quant_config(
@@ -1164,6 +1229,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+            # TODO(rob): convert this to MK.
             if layer.enable_eplb:
                 raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
             assert layer.activation == "silu", (
@@ -1228,6 +1294,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 rocm_aiter_fused_experts,
             )
 
+            # TODO(rob): convert this to MK.
             result = rocm_aiter_fused_experts(
                 x,
                 layer.w13_weight,
@@ -1240,6 +1307,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 quant_config=self.moe_quant_config,
             )
         elif self.use_marlin:
+            # TODO(rob): convert this to MK.
             assert layer.activation == "silu", (
                 f"{layer.activation} not supported for Marlin MoE."
             )
@@ -1261,47 +1329,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 input_dtype=self.marlin_input_dtype,
                 workspace=layer.workspace,
             )
-        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
-            assert layer.activation == "silu", (
-                f"Expected 'silu' activation but got {layer.activation}"
-            )
-            if not self.block_quant:
-                assert (
-                    not layer.renormalize and layer.custom_routing_function is not None
-                )
-                assert layer.scoring_func == "sigmoid", (
-                    f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
-                )
-            # Delegate to CUTLASS FlashInfer path; function already bound with
-            # use_deepseek_fp8_block_scale for block-quant when applicable
-            result = self.flashinfer_moe_fn(
+        else:
+            result = self.kernel(
                 x,
-                layer,
+                layer.w13_weight,
+                layer.w2_weight,
                 topk_weights,
                 topk_ids,
-                inplace=False,
+                inplace=self.use_inplace,
                 activation=layer.activation,
                 global_num_experts=layer.global_num_experts,
                 expert_map=layer.expert_map,
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )
-        else:
-            from vllm.model_executor.layers.fused_moe import fused_experts
-
-            result = fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=True,
-                activation=layer.activation,
-                global_num_experts=layer.global_num_experts,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-                expert_map=layer.expert_map,
-                quant_config=self.moe_quant_config,
-                allow_deep_gemm=self.allow_deep_gemm,
-            )
 
         if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
             assert not isinstance(result, tuple), (
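For context, below is a minimal sketch (not part of the patch) of the modular-kernel composition that the WIP refactor converges on for the plain TP case: a FusedMoEModularKernel is built from a prepare/finalize object plus an experts implementation, and then invoked uniformly from apply(). The class names, module paths, and call signature are taken from the diff above; the helper function, its parameters, and the `modular_kernel as mk` import alias are illustrative assumptions rather than part of this PR.

# Sketch only: mirrors the DEEPGEMM/TRITON branch added in this diff.
import vllm.model_executor.layers.fused_moe.modular_kernel as mk  # assumed alias
from vllm.model_executor.layers.fused_moe import TritonOrDeepGemmExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)


def build_tp_fp8_kernel(quant_config, use_deep_gemm: bool) -> mk.FusedMoEModularKernel:
    # Hypothetical helper: compose the no-EP prepare/finalize step with the
    # Triton/DeepGEMM experts implementation, as the new TP init path does.
    return mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        TritonOrDeepGemmExperts(
            quant_config=quant_config,
            allow_deep_gemm=use_deep_gemm,
        ),
    )


# apply() then calls the composed kernel uniformly, mirroring the new `else`
# branch above:
# result = kernel(x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids,
#                 inplace=True, activation=layer.activation,
#                 global_num_experts=layer.global_num_experts,
#                 expert_map=layer.expert_map,
#                 apply_router_weight_on_input=layer.apply_router_weight_on_input)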