From 7e082bc14e431af0311186dd18d4f4da7a757f3a Mon Sep 17 00:00:00 2001
From: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Date: Wed, 12 Nov 2025 21:40:45 -0800
Subject: [PATCH] Support DeepEP for Kimi-k2-thinking through enabling gemm
 selection for compressed-tensor marlin wna16 (#28574)

Signed-off-by: Lu Fang
---
 .../layers/fused_moe/fused_marlin_moe.py      | 75 +++++++++++++++++--
 .../compressed_tensors_moe.py                 | 52 ++++++++++++-
 2 files changed, 118 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 3b0df6c416a0..0b0f59f67318 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -499,11 +499,35 @@ def batched_fused_marlin_moe(
 
 
 class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
+    ):
         # TODO (varun) : Enable activation quantization
-        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
+            "Supports only mxfp4_w4a16 or int4_w4a16"
+        )
+        self.w13_g_idx = w13_g_idx
+        self.w2_g_idx = w2_g_idx
+        self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
+        self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
+        self.is_k_full = is_k_full
         super().__init__(quant_config)
 
+    @property
+    def quant_type_id(self) -> int:
+        # uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
+        return (
+            scalar_types.uint4b8.id
+            if self.quant_config.use_int4_w4a16
+            else scalar_types.float4_e2m1f.id
+        )
+
     def moe_problem_size(
         self,
         a1: torch.Tensor,
@@ -533,8 +557,23 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
 
 
 class MarlinExperts(MarlinExpertsBase):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
-        super().__init__(quant_config)
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
+    ):
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )
 
     def supports_expert_map(self) -> bool:
         return True
@@ -616,7 +655,7 @@ class MarlinExperts(MarlinExpertsBase):
             gating_output=None,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
@@ -628,6 +667,11 @@ class MarlinExperts(MarlinExpertsBase):
             # output buffer allocation. Please refer to workspace_shapes().
             intermediate_cache13=workspace2,
             intermediate_cache2=workspace13,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
+            is_k_full=self.is_k_full,
         )
 
     def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
@@ -650,8 +694,20 @@ class BatchedMarlinExperts(MarlinExpertsBase):
         max_num_tokens: int,
         num_dispatchers: int,
         quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
     ):
-        super().__init__(quant_config)
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
 
@@ -720,7 +776,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
             w1_scale=self.w1_scale,
             w2_scale=self.w2_scale,
             gating_output=None,
-            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
@@ -728,4 +784,9 @@ class BatchedMarlinExperts(MarlinExpertsBase):
             output=output,
             intermediate_cache13=workspace13,
             intermediate_cache2=workspace2,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
+            is_k_full=self.is_k_full,
         )
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index bda94cee9e42..06ee96d55419 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -35,7 +35,11 @@ from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    BatchedMarlinExperts,
+    MarlinExperts,
+    fused_marlin_moe,
+)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,
@@ -1578,7 +1582,51 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return None
+        if self.num_bits != 4:
+            return None
+        return int4_w4a16_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size],
+        )
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        layer: torch.nn.Module,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
+        assert self.num_bits == 4, "only supporting w4"
+        layer.w13_weight = layer.w13_weight_packed
+        layer.w2_weight = layer.w2_weight_packed
+        assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
+        assert self.moe_quant_config is not None
+        if (
+            prepare_finalize.activation_format
+            == mk.FusedMoEActivationFormat.BatchedExperts
+        ):
+            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
+            assert max_num_tokens_per_rank is not None
+            return BatchedMarlinExperts(
+                max_num_tokens=max_num_tokens_per_rank,
+                num_dispatchers=prepare_finalize.num_dispatchers(),
+                quant_config=self.moe_quant_config,
+                w13_g_idx=layer.w13_weight_g_idx,
+                w2_g_idx=layer.w2_weight_g_idx,
+                w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
+        else:
+            return MarlinExperts(
+                quant_config=self.moe_quant_config,
+                w13_g_idx=layer.w13_weight_g_idx,
+                w2_g_idx=layer.w2_weight_g_idx,
+                w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
 
     def apply(
         self,