mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:05:01 +08:00
Support DeepEP for Kimi-k2-thinking through enabling gemm selection for compressed-tensor marlin wna16 (#28574)
Signed-off-by: Lu Fang <fanglu@fb.com>
This commit is contained in:
parent
dbbe0c756a
commit
7e082bc14e
@ -499,11 +499,35 @@ def batched_fused_marlin_moe(
|
||||
|
||||
|
||||
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def __init__(self, quant_config: FusedMoEQuantConfig):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
w13_g_idx: torch.Tensor | None = None,
|
||||
w2_g_idx: torch.Tensor | None = None,
|
||||
w13_g_idx_sort_indices: torch.Tensor | None = None,
|
||||
w2_g_idx_sort_indices: torch.Tensor | None = None,
|
||||
is_k_full: bool = True,
|
||||
):
|
||||
# TODO (varun) : Enable activation quantization
|
||||
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
|
||||
assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
|
||||
"Supports only mxfp4_w4a16 or int4_w4a16"
|
||||
)
|
||||
self.w13_g_idx = w13_g_idx
|
||||
self.w2_g_idx = w2_g_idx
|
||||
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
|
||||
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
|
||||
self.is_k_full = is_k_full
|
||||
super().__init__(quant_config)
|
||||
|
||||
@property
|
||||
def quant_type_id(self) -> int:
|
||||
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
|
||||
return (
|
||||
scalar_types.uint4b8.id
|
||||
if self.quant_config.use_int4_w4a16
|
||||
else scalar_types.float4_e2m1f.id
|
||||
)
|
||||
|
||||
def moe_problem_size(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
@ -533,8 +557,23 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
|
||||
class MarlinExperts(MarlinExpertsBase):
|
||||
def __init__(self, quant_config: FusedMoEQuantConfig):
|
||||
super().__init__(quant_config)
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
w13_g_idx: torch.Tensor | None = None,
|
||||
w2_g_idx: torch.Tensor | None = None,
|
||||
w13_g_idx_sort_indices: torch.Tensor | None = None,
|
||||
w2_g_idx_sort_indices: torch.Tensor | None = None,
|
||||
is_k_full: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
quant_config,
|
||||
w13_g_idx,
|
||||
w2_g_idx,
|
||||
w13_g_idx_sort_indices,
|
||||
w2_g_idx_sort_indices,
|
||||
is_k_full,
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
@ -616,7 +655,7 @@ class MarlinExperts(MarlinExpertsBase):
|
||||
gating_output=None,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16
|
||||
quant_type_id=self.quant_type_id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
activation=activation,
|
||||
@ -628,6 +667,11 @@ class MarlinExperts(MarlinExpertsBase):
|
||||
# output buffer allocation. Please refer to workspace_shapes().
|
||||
intermediate_cache13=workspace2,
|
||||
intermediate_cache2=workspace13,
|
||||
g_idx1=self.w13_g_idx,
|
||||
g_idx2=self.w2_g_idx,
|
||||
sort_indices1=self.w13_g_idx_sort_indices,
|
||||
sort_indices2=self.w2_g_idx_sort_indices,
|
||||
is_k_full=self.is_k_full,
|
||||
)
|
||||
|
||||
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
@ -650,8 +694,20 @@ class BatchedMarlinExperts(MarlinExpertsBase):
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
w13_g_idx: torch.Tensor | None = None,
|
||||
w2_g_idx: torch.Tensor | None = None,
|
||||
w13_g_idx_sort_indices: torch.Tensor | None = None,
|
||||
w2_g_idx_sort_indices: torch.Tensor | None = None,
|
||||
is_k_full: bool = True,
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
super().__init__(
|
||||
quant_config,
|
||||
w13_g_idx,
|
||||
w2_g_idx,
|
||||
w13_g_idx_sort_indices,
|
||||
w2_g_idx_sort_indices,
|
||||
is_k_full,
|
||||
)
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
@ -720,7 +776,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
|
||||
w1_scale=self.w1_scale,
|
||||
w2_scale=self.w2_scale,
|
||||
gating_output=None,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16
|
||||
quant_type_id=self.quant_type_id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
activation=activation,
|
||||
@ -728,4 +784,9 @@ class BatchedMarlinExperts(MarlinExpertsBase):
|
||||
output=output,
|
||||
intermediate_cache13=workspace13,
|
||||
intermediate_cache2=workspace2,
|
||||
g_idx1=self.w13_g_idx,
|
||||
g_idx2=self.w2_g_idx,
|
||||
sort_indices1=self.w13_g_idx_sort_indices,
|
||||
sort_indices2=self.w2_g_idx_sort_indices,
|
||||
is_k_full=self.is_k_full,
|
||||
)
|
||||
|
||||
@ -35,7 +35,11 @@ from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
is_valid_flashinfer_cutlass_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
BatchedMarlinExperts,
|
||||
MarlinExperts,
|
||||
fused_marlin_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
||||
WNA16_SUPPORTED_BITS,
|
||||
WNA16_SUPPORTED_TYPES_MAP,
|
||||
@ -1578,7 +1582,51 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return None
|
||||
if self.num_bits != 4:
|
||||
return None
|
||||
return int4_w4a16_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
block_shape=[0, self.group_size],
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
assert self.num_bits == 4, "only supporting w4"
|
||||
layer.w13_weight = layer.w13_weight_packed
|
||||
layer.w2_weight = layer.w2_weight_packed
|
||||
assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
|
||||
assert self.moe_quant_config is not None
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
== mk.FusedMoEActivationFormat.BatchedExperts
|
||||
):
|
||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
|
||||
assert max_num_tokens_per_rank is not None
|
||||
return BatchedMarlinExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=self.moe_quant_config,
|
||||
w13_g_idx=layer.w13_weight_g_idx,
|
||||
w2_g_idx=layer.w2_weight_g_idx,
|
||||
w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
|
||||
w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
|
||||
is_k_full=self.is_k_full,
|
||||
)
|
||||
else:
|
||||
return MarlinExperts(
|
||||
quant_config=self.moe_quant_config,
|
||||
w13_g_idx=layer.w13_weight_g_idx,
|
||||
w2_g_idx=layer.w2_weight_g_idx,
|
||||
w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
|
||||
w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
|
||||
is_k_full=self.is_k_full,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user