Update CT WNA16MarlinMoE integration (#16666)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-05-09 11:19:45 -06:00 committed by GitHub
parent 5c4c08f6f1
commit 22481fbfa3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,7 @@
import enum
from enum import Enum
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from compressed_tensors import CompressionFormat
@ -14,9 +14,12 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
@ -54,18 +57,19 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
# Prefer to use the non-marlin kernel when:
# 1. Many experts (MarlinMoE gives poor performance when >= 16)
# 2. Non-FP16 dtype (MarlinMoE only supports FP16)
# 3. Actorder is not group/dynamic (g_idx is unsupported)
# 4. Scales are grouped (channelwise is unsupported)
if ((layer.local_num_experts >= 16
or layer.params_dtype != torch.float16) and
weight_quant.actorder not in (ActivationOrdering.GROUP,
ActivationOrdering.DYNAMIC)
and weight_quant.strategy in QuantizationStrategy.GROUP):
# Prefer to use the MarlinMoE kernel when it is supported.
if not check_moe_marlin_supports_layer(layer,
weight_quant.group_size):
if (weight_quant.strategy in QuantizationStrategy.GROUP and
weight_quant.actorder in (ActivationOrdering.GROUP,
ActivationOrdering.DYNAMIC)):
raise ValueError(
"WNA16MoE is not supported with actorder=group/dynamic."
)
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config)
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
and layer.activation == "silu"):
@ -705,15 +709,12 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
f"{CompressionFormat.pack_quantized.value} ",
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}")
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
assert params_dtype == torch.float16, (
"float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
)
intermediate_size_full = extra_weight_attrs.pop(
"intermediate_size_full")
@ -837,50 +838,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer.marlin_state = GPTQMarlinState.REPACK
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def replace_tensor(name, new_t):
    # Look the destination up once; resize_() reuses the existing
    # storage where possible, so the parameter buffer registered on
    # `layer` (from the enclosing scope) is updated in place rather
    # than replaced with a new tensor object.
    dst = getattr(layer, name)
    dst.resize_(new_t.shape)
    dst.copy_(new_t)
    del new_t
def get_scale_perms(num_bits: int):
    """Return the Marlin scale permutation tables.

    Produces ``(scale_perm, scale_perm_single)``: index lists used to
    shuffle quantization scales into the layout the Marlin kernel
    expects. ``scale_perm`` is used for grouped scales,
    ``scale_perm_single`` for single-group (channelwise-style) scales.
    ``num_bits`` is part of the signature but does not influence the
    result.
    """
    # Grouped permutation: column-major traversal of an 8x8 index grid.
    scale_perm: List[int] = [i + 8 * j for i in range(8) for j in range(8)]
    # Single-group permutation: interleaved pairs across four strides.
    scale_perm_single: List[int] = [
        2 * i + j for i in range(4) for j in (0, 1, 8, 9, 16, 17, 24, 25)
    ]
    return scale_perm, scale_perm_single
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int, num_bits: int):
    """Permute one expert's scale tensor into the Marlin layout.

    Grouped quantization (``-1 != group_size < size_k``) uses the
    grouped permutation table; otherwise the single-group table is
    applied. The result is reshaped to ``size_n`` columns and made
    contiguous for the kernel.
    """
    scale_perm, scale_perm_single = get_scale_perms(num_bits)
    use_grouped = group_size != -1 and group_size < size_k
    perm = scale_perm if use_grouped else scale_perm_single
    permuted = s.reshape((-1, len(perm)))[:, perm]
    return permuted.reshape((-1, size_n)).contiguous()
def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
                              size_n: int, group_size: int,
                              num_bits: int):
    """Apply ``marlin_permute_scales`` to every expert slice of ``s``.

    ``s`` is laid out as ``(num_experts, rows, cols)``; each expert's
    2-D scale tensor is permuted independently and written into a
    freshly allocated tensor of identical shape, dtype, and device.
    """
    out = torch.empty_like(s)
    for expert in range(s.shape[0]):
        out[expert] = marlin_permute_scales(s[expert], size_k, size_n,
                                            group_size, num_bits)
    return out
size_k2 = layer.w2_weight_packed.shape[2]
size_k13 = layer.w13_weight_packed.shape[2]
num_experts = layer.w13_weight_g_idx.shape[0]
device = layer.w13_weight_g_idx.device
@ -938,7 +895,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_packed.shape[2],
self.num_bits,
)
replace_tensor("w13_weight_packed", marlin_w13_qweight)
replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
@ -946,25 +903,25 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_packed.shape[2],
self.num_bits,
)
replace_tensor("w2_weight_packed", marlin_w2_qweight)
replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
layer.w13_weight_scale,
size_k13,
layer.w13_weight_scale.shape[2],
self.group_size,
self.num_bits,
s=layer.w13_weight_scale,
size_k=layer.w13_weight_packed.shape[2],
size_n=layer.w13_weight_scale.shape[2],
group_size=self.group_size,
)
replace_tensor("w13_weight_scale", marlin_w13_scales)
replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
layer.w2_weight_scale,
layer.w2_weight_scale.shape[1] *
s=layer.w2_weight_scale,
size_k=layer.w2_weight_scale.shape[1] *
(self.group_size if self.group_size != -1 else self.packed_factor),
size_k2,
self.group_size,
self.num_bits,
size_n=layer.w2_weight_scale.shape[2],
group_size=self.group_size,
)
replace_tensor("w2_weight_scale", marlin_w2_scales)
replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
layer.workspace = marlin_make_workspace_new(device, 4)
def apply(
self,
@ -985,10 +942,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for "
@ -1015,11 +968,14 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.num_bits,
workspace=layer.workspace,
is_k_full=self.is_k_full)
@ -1203,7 +1159,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -1223,6 +1179,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int4_w4a16=self.num_bits == 4,
use_int8_w8a16=self.num_bits == 8,
global_num_experts=global_num_experts,