Update CT WNA16MarlinMoE integration (#16666)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 5c4c08f6f1
commit 22481fbfa3
@@ -2,7 +2,7 @@
 import enum
 from enum import Enum
-from typing import Callable, List, Optional
+from typing import Callable, Optional

 import torch
 from compressed_tensors import CompressionFormat
@@ -14,9 +14,12 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    WNA16_SUPPORTED_BITS)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
+    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    check_moe_marlin_supports_layer, marlin_make_workspace_new,
+    marlin_moe_permute_scales)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -54,18 +57,19 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             "input_activations")

         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            # Prefer to use the non-marlin kernel when:
-            # 1. Many experts (MarlinMoE gives poor performance when >= 16)
-            # 2. Non-FP16 dtype (MarlinMoE only supports FP16)
-            # 3. Actorder is not group/dynamic (g_idx is unsupported)
-            # 4. Scaled are grouped (channelwise is unsupported)
-            if ((layer.local_num_experts >= 16
-                 or layer.params_dtype != torch.float16) and
-                    weight_quant.actorder not in (ActivationOrdering.GROUP,
-                                                  ActivationOrdering.DYNAMIC)
-                    and weight_quant.strategy in QuantizationStrategy.GROUP):
+            # Prefer to use the MarlinMoE kernel when it is supported.
+            if not check_moe_marlin_supports_layer(layer,
+                                                   weight_quant.group_size):
+                if (weight_quant.strategy in QuantizationStrategy.GROUP and
+                        weight_quant.actorder in (ActivationOrdering.GROUP,
+                                                  ActivationOrdering.DYNAMIC)):
+                    raise ValueError(
+                        "WNA16MoE is not supported with actorder=group/dynamic."
+                    )
                 logger.info_once("Using CompressedTensorsWNA16MoEMethod")
                 return CompressedTensorsWNA16MoEMethod(quant_config)
             else:
                 logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                 return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
               and layer.activation == "silu"):
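For orientation, a minimal standalone sketch of the selection flow this hunk introduces. ToyWeightQuant and toy_marlin_supported are illustrative stand-ins only, not vLLM's config classes or the real check_moe_marlin_supports_layer: fall back to the non-Marlin WNA16 method when the Marlin support check fails, and reject group/dynamic activation ordering on that fallback path.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyWeightQuant:
    strategy: str            # "group" or "channel"
    actorder: Optional[str]  # "group", "dynamic", or None
    group_size: int


def toy_marlin_supported(group_size: int) -> bool:
    # Stand-in for check_moe_marlin_supports_layer; the real helper also
    # inspects the layer's shapes, this toy only looks at the group size.
    return group_size in (-1, 32, 64, 128)


def select_wna16_method(wq: ToyWeightQuant) -> str:
    if not toy_marlin_supported(wq.group_size):
        if wq.strategy == "group" and wq.actorder in ("group", "dynamic"):
            raise ValueError(
                "WNA16MoE is not supported with actorder=group/dynamic.")
        return "CompressedTensorsWNA16MoEMethod"
    return "CompressedTensorsWNA16MarlinMoEMethod"


print(select_wna16_method(ToyWeightQuant("group", None, 128)))  # Marlin path
print(select_wna16_method(ToyWeightQuant("group", None, 48)))   # fallback path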
@@ -705,15 +709,12 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
                 f"{CompressionFormat.pack_quantized.value} ",
                 "is supported for the following bits: ",
                 f"{WNA16_SUPPORTED_BITS}")
+        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):

-        assert params_dtype == torch.float16, (
-            "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
-        )
-
         intermediate_size_full = extra_weight_attrs.pop(
             "intermediate_size_full")

@@ -837,50 +838,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         layer.marlin_state = GPTQMarlinState.REPACK

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-
-        def replace_tensor(name, new_t):
-            # It is important to use resize_() here since it ensures
-            # the same buffer is reused
-            getattr(layer, name).resize_(new_t.shape)
-            getattr(layer, name).copy_(new_t)
-            del new_t
-
-        def get_scale_perms(num_bits: int):
-            scale_perm: List[int] = []
-            for i in range(8):
-                scale_perm.extend([i + 8 * j for j in range(8)])
-            scale_perm_single: List[int] = []
-            for i in range(4):
-                scale_perm_single.extend(
-                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
-            return scale_perm, scale_perm_single
-
-        def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
-                                  group_size: int, num_bits: int):
-            scale_perm, scale_perm_single = get_scale_perms(num_bits)
-            if group_size < size_k and group_size != -1:
-                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
-            else:
-                s = s.reshape((-1, len(scale_perm_single)))[:,
-                                                            scale_perm_single]
-            s = s.reshape((-1, size_n)).contiguous()
-            return s
-
-        def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
-                                      size_n: int, group_size: int,
-                                      num_bits: int):
-            num_experts = s.shape[0]
-            output = torch.empty((num_experts, s.shape[1], s.shape[2]),
-                                 device=s.device,
-                                 dtype=s.dtype)
-            for e in range(num_experts):
-                output[e] = marlin_permute_scales(s[e], size_k, size_n,
-                                                  group_size, num_bits)
-            return output
-
-        size_k2 = layer.w2_weight_packed.shape[2]
-        size_k13 = layer.w13_weight_packed.shape[2]
-
         num_experts = layer.w13_weight_g_idx.shape[0]
         device = layer.w13_weight_g_idx.device

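The deleted helpers now come from shared utilities (marlin_moe_permute_scales from marlin_utils, the in-place swap from replace_parameter). Below is a small self-contained sketch of the buffer-reuse trick described by the removed replace_tensor comment; the tensors and shapes are made up for illustration and are not vLLM code.

import torch

# Illustration only: resize_() followed by copy_() keeps the registered
# tensor pointing at the same storage, instead of rebinding the attribute
# to a brand-new tensor.
param = torch.zeros(4, 8)
repacked = torch.randn(2, 16)  # same element count, different layout

ptr_before = param.data_ptr()
param.resize_(repacked.shape)  # reshape in place, reusing the buffer
param.copy_(repacked)          # fill it with the repacked values
assert param.data_ptr() == ptr_before  # same underlying buffer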
@@ -938,7 +895,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
             layer.w13_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_tensor("w13_weight_packed", marlin_w13_qweight)
+        replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
         marlin_w2_qweight = ops.gptq_marlin_moe_repack(
             layer.w2_weight_packed,
             layer.w2_g_idx_sort_indices,
@@ -946,25 +903,25 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_tensor("w2_weight_packed", marlin_w2_qweight)
+        replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
         # Repack scales
         marlin_w13_scales = marlin_moe_permute_scales(
-            layer.w13_weight_scale,
-            size_k13,
-            layer.w13_weight_scale.shape[2],
-            self.group_size,
-            self.num_bits,
+            s=layer.w13_weight_scale,
+            size_k=layer.w13_weight_packed.shape[2],
+            size_n=layer.w13_weight_scale.shape[2],
+            group_size=self.group_size,
         )
-        replace_tensor("w13_weight_scale", marlin_w13_scales)
+        replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
         marlin_w2_scales = marlin_moe_permute_scales(
-            layer.w2_weight_scale,
-            layer.w2_weight_scale.shape[1] *
+            s=layer.w2_weight_scale,
+            size_k=layer.w2_weight_scale.shape[1] *
             (self.group_size if self.group_size != -1 else self.packed_factor),
-            size_k2,
-            self.group_size,
-            self.num_bits,
+            size_n=layer.w2_weight_scale.shape[2],
+            group_size=self.group_size,
         )
-        replace_tensor("w2_weight_scale", marlin_w2_scales)
+        replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
+
+        layer.workspace = marlin_make_workspace_new(device, 4)

     def apply(
         self,
@@ -985,10 +942,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         activation: str = "silu",
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
-        if expert_map is not None:
-            raise NotImplementedError(
-                "Expert Parallelism is not supported for "
-                "fused Marlin MoE method.")
         if apply_router_weight_on_input:
             raise NotImplementedError(
                 "Apply router weight on input is not supported for "
@@ -1015,11 +968,14 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
             router_logits,
             topk_weights,
             topk_ids,
+            quant_type_id=self.quant_type.id,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
             g_idx1=layer.w13_weight_g_idx,
             g_idx2=layer.w2_weight_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
-            num_bits=self.num_bits,
+            workspace=layer.workspace,
             is_k_full=self.is_k_full)


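With the expert_map guard removed above and expert_map/global_num_experts now forwarded to the fused Marlin call, here is a small illustration of the expert-map idea under expert parallelism. The numbers are toy values, and the -1 convention for non-local experts is an assumption for illustration, not taken from this diff.

import torch

# Each rank owns a slice of the experts; expert_map translates global expert
# ids to local ids, with -1 (assumed convention) marking experts that live
# on another rank.
global_num_experts = 8
local_experts = [4, 5, 6, 7]  # experts owned by this rank
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
expert_map[local_experts] = torch.arange(len(local_experts), dtype=torch.int32)

topk_ids = torch.tensor([[1, 6], [4, 7]])  # router's picks per token
print(expert_map[topk_ids])
# tensor([[-1,  2],
#         [ 0,  3]], dtype=torch.int32)  -> only experts 4, 6 and 7 run locally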
@@ -1203,7 +1159,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         activation: str = "silu",
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
+        assert activation == "silu", "Only SiLU activation is supported."

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -1223,6 +1179,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
             use_int4_w4a16=self.num_bits == 4,
             use_int8_w8a16=self.num_bits == 8,
             global_num_experts=global_num_experts,