From 2e25bb12a8f3b6fd8c2edcb5ae21c6c1cd315dbe Mon Sep 17 00:00:00 2001
From: bnellnm <49004751+bnellnm@users.noreply.github.com>
Date: Wed, 2 Jul 2025 22:07:43 -0400
Subject: [PATCH] [Bugfix] Fix import of CutlassExpertsFp8 in
 compressed_tensors_moe.py (#20381)

Signed-off-by: Bill Nell
---
 .../compressed_tensors/compressed_tensors_moe.py |  8 +++++---
 vllm/model_executor/layers/quantization/fp8.py   | 10 ++++++----
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index fa011266cf2fe..5d7e00c2b81be 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -14,9 +14,9 @@ import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (
-    CutlassExpertsFp8, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig,
-    FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
-    FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported, fused_experts)
+    FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
+    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
+    FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
@@ -570,6 +570,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             del layer.w2_input_scale
             self.fused_experts_func = None
         else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
             self.fused_experts_func = fused_experts

     def apply(
@@ -826,6 +827,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
         prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
     ) -> FusedMoEPermuteExpertsUnpermute:
+        from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
         use_batched_format = (prepare_finalize.activation_format ==
                               FusedMoEActivationFormat.BatchedExperts)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 0295f5e2a1c88..f879d0ad091a8 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -14,10 +14,9 @@ from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (
-    BatchedTritonOrDeepGemmExperts, FusedMoE, FusedMoEActivationFormat,
-    FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
-    FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported,
-    TritonOrDeepGemmExperts)
+    FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
+    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
+    FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -785,6 +784,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
     ) -> FusedMoEPermuteExpertsUnpermute:
+        from vllm.model_executor.layers.fused_moe import (
+            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)
+
         assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet.")