Support DeepEP for Kimi-K2-Thinking by enabling GEMM selection for compressed-tensors Marlin WNA16 (#28574)
Signed-off-by: Lu Fang <fanglu@fb.com>
parent: dbbe0c756a
commit: 7e082bc14e
@@ -499,11 +499,35 @@ def batched_fused_marlin_moe(


 class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
+    ):
         # TODO (varun) : Enable activation quantization
-        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
+            "Supports only mxfp4_w4a16 or int4_w4a16"
+        )
+        self.w13_g_idx = w13_g_idx
+        self.w2_g_idx = w2_g_idx
+        self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
+        self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
+        self.is_k_full = is_k_full
         super().__init__(quant_config)
+
+    @property
+    def quant_type_id(self) -> int:
+        # uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
+        return (
+            scalar_types.uint4b8.id
+            if self.quant_config.use_int4_w4a16
+            else scalar_types.float4_e2m1f.id
+        )

     def moe_problem_size(
         self,
         a1: torch.Tensor,
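Note for reviewers: the new quant_type_id property is what lets a single Marlin experts path serve both mxfp4 and compressed-tensors int4 (w4a16) checkpoints. A minimal standalone sketch of that dispatch follows; WeightTypeId and QuantConfigSketch are illustrative stand-ins for vllm's scalar_types and FusedMoEQuantConfig, not the real API.

# Minimal standalone sketch of the quant_type_id dispatch introduced above.
from dataclasses import dataclass
from enum import IntEnum


class WeightTypeId(IntEnum):
    UINT4B8 = 0       # stand-in for scalar_types.uint4b8: int4 stored with a +8 bias
    FLOAT4_E2M1F = 1  # stand-in for scalar_types.float4_e2m1f: 4-bit float (mxfp4)


@dataclass
class QuantConfigSketch:
    use_int4_w4a16: bool = False
    use_mxfp4_w4a16: bool = False


def quant_type_id(cfg: QuantConfigSketch) -> int:
    # Same branch structure as the new property: int4 checkpoints map to the
    # uint4b8 weight type, everything else (mxfp4) maps to float4_e2m1f.
    assert cfg.use_int4_w4a16 or cfg.use_mxfp4_w4a16
    return WeightTypeId.UINT4B8 if cfg.use_int4_w4a16 else WeightTypeId.FLOAT4_E2M1F


print(quant_type_id(QuantConfigSketch(use_int4_w4a16=True)) == WeightTypeId.UINT4B8)        # True
print(quant_type_id(QuantConfigSketch(use_mxfp4_w4a16=True)) == WeightTypeId.FLOAT4_E2M1F)  # True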
@@ -533,8 +557,23 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):


 class MarlinExperts(MarlinExpertsBase):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
-        super().__init__(quant_config)
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
+    ):
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )

     def supports_expert_map(self) -> bool:
         return True
@@ -616,7 +655,7 @@ class MarlinExperts(MarlinExpertsBase):
             gating_output=None,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
@@ -628,6 +667,11 @@ class MarlinExperts(MarlinExpertsBase):
             # output buffer allocation. Please refer to workspace_shapes().
             intermediate_cache13=workspace2,
             intermediate_cache2=workspace13,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
+            is_k_full=self.is_k_full,
         )

     def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
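Background on the g_idx1 / sort_indices1 kwargs threaded into the kernel call above: for act-order (desc_act) w4a16 checkpoints, g_idx maps each input channel to its quantization group, and the sort indices reorder channels so that each group is contiguous for the Marlin kernel; non-act-order checkpoints simply pass None. A small illustration of those two tensors, assuming GPTQ-style act-order semantics:

# Illustration only (assumed GPTQ-style act-order semantics, not vllm code).
import torch

k, group_size = 16, 4
num_groups = k // group_size

# Hypothetical act-order group assignment produced at quantization time:
# g_idx[c] is the quantization group of input channel c.
g_idx = torch.randint(0, num_groups, (k,), dtype=torch.int32)

# Kernel-friendly permutation: channel indices sorted by group id.
g_idx_sort_indices = torch.argsort(g_idx)
sorted_groups = g_idx[g_idx_sort_indices]

# After applying the permutation, group ids are non-decreasing, so every
# quantization group occupies a contiguous channel range.
assert bool((sorted_groups[:-1] <= sorted_groups[1:]).all())
print(g_idx.tolist())
print(g_idx_sort_indices.tolist())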
@@ -650,8 +694,20 @@ class BatchedMarlinExperts(MarlinExpertsBase):
         max_num_tokens: int,
         num_dispatchers: int,
         quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
     ):
-        super().__init__(quant_config)
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers

@@ -720,7 +776,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
             w1_scale=self.w1_scale,
             w2_scale=self.w2_scale,
             gating_output=None,
-            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
@@ -728,4 +784,9 @@ class BatchedMarlinExperts(MarlinExpertsBase):
             output=output,
             intermediate_cache13=workspace13,
             intermediate_cache2=workspace2,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
+            is_k_full=self.is_k_full,
         )

@@ -35,7 +35,11 @@ from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    BatchedMarlinExperts,
+    MarlinExperts,
+    fused_marlin_moe,
+)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,
@@ -1578,7 +1582,51 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return None
+        if self.num_bits != 4:
+            return None
+        return int4_w4a16_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size],
+        )
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        layer: torch.nn.Module,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
+        assert self.num_bits == 4, "only supporting w4"
+        layer.w13_weight = layer.w13_weight_packed
+        layer.w2_weight = layer.w2_weight_packed
+        assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
+        assert self.moe_quant_config is not None
+        if (
+            prepare_finalize.activation_format
+            == mk.FusedMoEActivationFormat.BatchedExperts
+        ):
+            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
+            assert max_num_tokens_per_rank is not None
+            return BatchedMarlinExperts(
+                max_num_tokens=max_num_tokens_per_rank,
+                num_dispatchers=prepare_finalize.num_dispatchers(),
+                quant_config=self.moe_quant_config,
+                w13_g_idx=layer.w13_weight_g_idx,
+                w2_g_idx=layer.w2_weight_g_idx,
+                w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
+        else:
+            return MarlinExperts(
+                quant_config=self.moe_quant_config,
+                w13_g_idx=layer.w13_weight_g_idx,
+                w2_g_idx=layer.w2_weight_g_idx,
+                w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )

     def apply(
         self,
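To make the DeepEP connection explicit: select_gemm_impl above chooses the batched Marlin implementation whenever the prepare/finalize stage delivers batched per-rank activations (as DeepEP-style dispatchers do) and falls back to the contiguous MarlinExperts otherwise. A standalone sketch of that selection follows; the classes and names here are illustrative stand-ins, not vllm's modular_kernel API.

# Illustrative stand-ins only; this mirrors the shape of select_gemm_impl,
# not vllm's actual modular_kernel types.
from dataclasses import dataclass
from enum import Enum, auto


class ActivationFormat(Enum):
    STANDARD = auto()
    BATCHED_EXPERTS = auto()  # e.g. what a DeepEP-style dispatcher produces


@dataclass
class PrepareFinalizeSketch:
    activation_format: ActivationFormat
    max_num_tokens_per_rank: int | None = None
    num_dispatchers: int = 1


def select_gemm_impl_sketch(pf: PrepareFinalizeSketch) -> str:
    # Batched activations -> batched experts implementation; otherwise the
    # contiguous implementation. In the real code both variants also receive
    # the quant config and the act-order (g_idx) tensors shown in the diff.
    if pf.activation_format is ActivationFormat.BATCHED_EXPERTS:
        assert pf.max_num_tokens_per_rank is not None
        return (
            f"BatchedMarlinExperts(max_num_tokens={pf.max_num_tokens_per_rank}, "
            f"num_dispatchers={pf.num_dispatchers})"
        )
    return "MarlinExperts()"


print(select_gemm_impl_sketch(PrepareFinalizeSketch(ActivationFormat.STANDARD)))
print(
    select_gemm_impl_sketch(
        PrepareFinalizeSketch(
            ActivationFormat.BATCHED_EXPERTS,
            max_num_tokens_per_rank=256,
            num_dispatchers=2,
        )
    )
)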