From 97000a2be7e318be1a3eb172f9abf2d67dbe73bf Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:45:55 -0500 Subject: [PATCH] [Bug] Fix compressed tensor not using deepgemm (#30820) Signed-off-by: yewentao256 --- vllm/model_executor/layers/fused_moe/fused_moe.py | 1 - .../compressed_tensors/compressed_tensors_moe.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 20782e2712f27..37f8e7780f999 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1696,7 +1696,6 @@ def fused_experts( and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2)) ): assert quant_config is not None - assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( hidden_states=hidden_states, w1=w1, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index c302e465aedb7..fc359a3067a9c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -96,6 +96,7 @@ from vllm.utils.deep_gemm import ( get_col_major_tma_aligned_tensor, get_mk_alignment_for_contiguous_layout, is_deep_gemm_e8m0_used, + is_deep_gemm_supported, ) from vllm.utils.import_utils import has_deep_gemm @@ -716,6 +717,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): get_marlin_input_dtype(layer_name) if self.use_marlin else None ) + self.allow_deep_gemm = ( + self.block_quant + and envs.VLLM_MOE_USE_DEEP_GEMM + and is_deep_gemm_supported() + and list(self.weight_block_size) == get_mk_alignment_for_contiguous_layout() + ) + def create_weights( self, layer: torch.nn.Module, @@ -1231,6 +1239,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): if self.disable_expert_map else layer.expert_map, # ??? quant_config=self.moe_quant_config, + allow_deep_gemm=self.allow_deep_gemm, ) else: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( @@ -1272,6 +1281,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, quant_config=self.moe_quant_config, + allow_deep_gemm=self.allow_deep_gemm, ) @property