[Bug] Fix compressed-tensors not using DeepGEMM (#30820)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye <zhyanwentao@126.com>
Date:   2025-12-18 14:45:55 -05:00 (committed via GitHub)
Commit: 97000a2be7 (parent d2dc5dfc6e)
2 files changed, 10 insertions(+), 1 deletion(-)

File 1 of 2 (fused_experts dispatch):

@@ -1696,7 +1696,6 @@ def fused_experts(
         and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))
     ):
         assert quant_config is not None
-        assert apply_router_weight_on_input is False
         return deep_gemm_moe_fp8(
             hidden_states=hidden_states,
             w1=w1,
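
Note: the single deletion in this commit removes the apply_router_weight_on_input
assertion from the DeepGEMM branch of fused_experts. A minimal sketch of the
dispatch shape this hunk sits in, where every name is an illustrative stand-in
rather than vLLM's real signature:

    # Sketch only: route to the DeepGEMM FP8 MoE kernel when the gate passes,
    # otherwise fall back to the default kernel path.
    def dispatch_fused_experts(allow_deep_gemm, e8m0_used, shapes_valid, quant_config):
        if allow_deep_gemm and (e8m0_used or shapes_valid):
            assert quant_config is not None  # DeepGEMM path needs a quant config
            return "deep_gemm_moe_fp8"       # DeepGEMM FP8 MoE kernel
        return "triton_fused_experts"        # non-DeepGEMM fallback

    print(dispatch_fused_experts(True, False, True, object()))  # deep_gemm_moe_fp8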

File 2 of 2 (CompressedTensorsW8A8Fp8MoEMethod):

@@ -96,6 +96,7 @@ from vllm.utils.deep_gemm import (
     get_col_major_tma_aligned_tensor,
     get_mk_alignment_for_contiguous_layout,
     is_deep_gemm_e8m0_used,
+    is_deep_gemm_supported,
 )
 from vllm.utils.import_utils import has_deep_gemm
@@ -716,6 +717,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             get_marlin_input_dtype(layer_name) if self.use_marlin else None
         )
+        self.allow_deep_gemm = (
+            self.block_quant
+            and envs.VLLM_MOE_USE_DEEP_GEMM
+            and is_deep_gemm_supported()
+            and list(self.weight_block_size)
+            == get_mk_alignment_for_contiguous_layout()
+        )
 
     def create_weights(
         self,
         layer: torch.nn.Module,
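
The hunk above is the core of the fix: CompressedTensorsW8A8Fp8MoEMethod now
computes an allow_deep_gemm flag in its constructor. A standalone sketch of the
gate, with every input hard-coded as an assumption (the real code reads
envs.VLLM_MOE_USE_DEEP_GEMM, is_deep_gemm_supported(), and
get_mk_alignment_for_contiguous_layout()):

    block_quant = True              # weights are block-wise FP8 quantized
    use_deep_gemm_env = True        # stand-in for envs.VLLM_MOE_USE_DEEP_GEMM
    deep_gemm_supported = True      # stand-in for is_deep_gemm_supported()
    weight_block_size = (128, 128)  # assumed per-layer quantization block shape
    mk_alignment = [128, 128]       # stand-in for the contiguous-layout alignment

    allow_deep_gemm = (
        block_quant
        and use_deep_gemm_env
        and deep_gemm_supported
        and list(weight_block_size) == mk_alignment
    )
    print(allow_deep_gemm)  # True only when all four conditions hold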
@@ -1231,6 +1239,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 if self.disable_expert_map
                 else layer.expert_map,  # ???
                 quant_config=self.moe_quant_config,
+                allow_deep_gemm=self.allow_deep_gemm,
             )
         else:
             from vllm.model_executor.layers.fused_moe.cutlass_moe import (
@@ -1272,6 +1281,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             global_num_experts=layer.global_num_experts,
             expert_map=layer.expert_map,
             quant_config=self.moe_quant_config,
+            allow_deep_gemm=self.allow_deep_gemm,
         )
 
     @property
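
The last two hunks forward the flag at both fused-experts call sites. A
simplified sketch (hypothetical signatures, not vLLM's API) of why that
matters: fused_experts takes allow_deep_gemm as a keyword that presumably
defaults to off, so a caller that never passes it, as this method did before
the fix, can never reach the DeepGEMM path:

    class MoEMethodSketch:
        def __init__(self, allow_deep_gemm: bool):
            self.allow_deep_gemm = allow_deep_gemm  # computed once in __init__

        def apply(self, fused_experts_fn, **kwargs):
            # Thread the precomputed decision through to the kernel dispatch.
            return fused_experts_fn(allow_deep_gemm=self.allow_deep_gemm, **kwargs)

    def fake_fused_experts(allow_deep_gemm=False, **kwargs):
        return "deep_gemm" if allow_deep_gemm else "triton"

    print(MoEMethodSketch(True).apply(fake_fused_experts))  # deep_gemm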