Support DeepEP for Kimi-k2-thinking through enabling gemm selection for compressed-tensor marlin wna16 (#28574)

Signed-off-by: Lu Fang <fanglu@fb.com>
This commit is contained in:
Lucia Fang 2025-11-12 21:40:45 -08:00 committed by GitHub
parent dbbe0c756a
commit 7e082bc14e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 118 additions and 9 deletions

View File

@ -499,11 +499,35 @@ def batched_fused_marlin_moe(
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig):
def __init__(
self,
quant_config: FusedMoEQuantConfig,
w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None,
w13_g_idx_sort_indices: torch.Tensor | None = None,
w2_g_idx_sort_indices: torch.Tensor | None = None,
is_k_full: bool = True,
):
# TODO (varun) : Enable activation quantization
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
"Supports only mxfp4_w4a16 or int4_w4a16"
)
self.w13_g_idx = w13_g_idx
self.w2_g_idx = w2_g_idx
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
self.is_k_full = is_k_full
super().__init__(quant_config)
@property
def quant_type_id(self) -> int:
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
return (
scalar_types.uint4b8.id
if self.quant_config.use_int4_w4a16
else scalar_types.float4_e2m1f.id
)
def moe_problem_size(
self,
a1: torch.Tensor,
@ -533,8 +557,23 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
class MarlinExperts(MarlinExpertsBase):
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)
def __init__(
self,
quant_config: FusedMoEQuantConfig,
w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None,
w13_g_idx_sort_indices: torch.Tensor | None = None,
w2_g_idx_sort_indices: torch.Tensor | None = None,
is_k_full: bool = True,
):
super().__init__(
quant_config,
w13_g_idx,
w2_g_idx,
w13_g_idx_sort_indices,
w2_g_idx_sort_indices,
is_k_full,
)
def supports_expert_map(self) -> bool:
return True
@ -616,7 +655,7 @@ class MarlinExperts(MarlinExpertsBase):
gating_output=None,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16
quant_type_id=self.quant_type_id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
@ -628,6 +667,11 @@ class MarlinExperts(MarlinExpertsBase):
# output buffer allocation. Please refer to workspace_shapes().
intermediate_cache13=workspace2,
intermediate_cache2=workspace13,
g_idx1=self.w13_g_idx,
g_idx2=self.w2_g_idx,
sort_indices1=self.w13_g_idx_sort_indices,
sort_indices2=self.w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
@ -650,8 +694,20 @@ class BatchedMarlinExperts(MarlinExpertsBase):
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None,
w13_g_idx_sort_indices: torch.Tensor | None = None,
w2_g_idx_sort_indices: torch.Tensor | None = None,
is_k_full: bool = True,
):
super().__init__(quant_config)
super().__init__(
quant_config,
w13_g_idx,
w2_g_idx,
w13_g_idx_sort_indices,
w2_g_idx_sort_indices,
is_k_full,
)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@ -720,7 +776,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
w1_scale=self.w1_scale,
w2_scale=self.w2_scale,
gating_output=None,
quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16
quant_type_id=self.quant_type_id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
@ -728,4 +784,9 @@ class BatchedMarlinExperts(MarlinExpertsBase):
output=output,
intermediate_cache13=workspace13,
intermediate_cache2=workspace2,
g_idx1=self.w13_g_idx,
g_idx2=self.w2_g_idx,
sort_indices1=self.w13_g_idx_sort_indices,
sort_indices2=self.w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)

View File

@ -35,7 +35,11 @@ from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
fused_marlin_moe,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
@ -1578,7 +1582,51 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if self.num_bits != 4:
return None
return int4_w4a16_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, self.group_size],
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.num_bits == 4, "only supporting w4"
layer.w13_weight = layer.w13_weight_packed
layer.w2_weight = layer.w2_weight_packed
assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
assert self.moe_quant_config is not None
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
w13_g_idx=layer.w13_weight_g_idx,
w2_g_idx=layer.w2_weight_g_idx,
w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
else:
return MarlinExperts(
quant_config=self.moe_quant_config,
w13_g_idx=layer.w13_weight_g_idx,
w2_g_idx=layer.w2_weight_g_idx,
w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
def apply(
self,