mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 02:15:57 +08:00
517 lines
20 KiB
Python
517 lines
20 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
|
|
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
|
from vllm.model_executor.layers.fused_moe.config import (
|
|
FusedMoEQuantConfig,
|
|
int4_w4a16_moe_quant_config,
|
|
int8_w8a16_moe_quant_config,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe.layer import (
|
|
FusedMoE,
|
|
FusedMoEConfig,
|
|
FusedMoEMethodBase,
|
|
FusedMoeWeightScaleSupported,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
|
UnquantizedFusedMoEMethod,
|
|
)
|
|
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig,
|
|
QuantizeMethodBase,
|
|
)
|
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
|
check_marlin_supports_layer,
|
|
)
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
class MoeWNA16Config(QuantizationConfig):
    """Quantization config for MoE checkpoints with W8A16 / W4A16 weights.

    FusedMoE expert layers are handled by ``MoeWNA16Method``; ordinary
    linear layers are delegated to a GPTQ / AWQ config (or its Marlin
    variant) rebuilt from the same checkpoint config dict.
    """

    def __init__(
        self,
        linear_quant_method: str,
        weight_bits: int,
        group_size: int,
        has_zp: bool,
        lm_head_quantized: bool,
        modules_to_not_convert: list[str] | None,
        full_config: dict[str, Any],
    ) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.has_zp = has_zp
        # Quantized values per uint8 byte: 8 // 4 == 2, 8 // 8 == 1.
        self.bit8_pack_factor = 8 // self.weight_bits
        self.lm_head_quantized = lm_head_quantized
        self.linear_quant_method = linear_quant_method
        self.full_config = full_config
        self.use_marlin = False

        # Imported here rather than at module level to avoid a circular import.
        from vllm.model_executor.layers.quantization.awq import AWQConfig
        from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
        from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig

        method = self.linear_quant_method
        if method == "gptq":
            self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
        elif method in ("awq", "awq_marlin"):
            capability = current_platform.get_device_capability()
            current = -1 if capability is None else capability.to_int()
            required = AWQConfig.get_min_capability()
            if current < required:
                raise ValueError(
                    "The quantization method moe_wna16 + awq is not supported "
                    "for the current GPU. "
                    f"Minimum capability: {required}. "
                    f"Current capability: {current}."
                )
            self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(full_config)
        else:
            raise ValueError("moe_wna16 only support gptq and awq.")

        # Normalise "nothing to skip" to an empty list.
        self.modules_to_not_convert = (
            [] if modules_to_not_convert is None else modules_to_not_convert
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Registry name of this quantization method."""
        return "moe_wna16"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Activation dtypes accepted by this method."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum device capability (as an int) required by this method."""
        return 70

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Checkpoint files that may hold the quantization config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
        """Build a config from a GPTQ- or AWQ-style quantization dict.

        Raises:
            ValueError: if ``quant_method`` is neither GPTQ nor AWQ.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        bits = cls.get_from_keys(config, ["bits"])
        group = cls.get_from_keys(config, ["group_size"])
        lm_head = cls.get_from_keys_or(config, ["lm_head"], default=False)

        if quant_method == "gptq":
            # GPTQ encodes asymmetry as ``sym == False``.
            zero_point = not cls.get_from_keys(config, ["sym"])
            skipped = []
        elif quant_method in ("awq", "awq_marlin"):
            zero_point = cls.get_from_keys(config, ["zero_point"])
            skipped = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        else:
            raise ValueError("moe_wna16 only support gptq and awq.")

        return cls(quant_method, bits, group, zero_point, lm_head, skipped, config)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Select moe_wna16 only when the user requested it and the config fits."""
        compatible = cls.is_moe_wna16_compatible(hf_quant_cfg)
        if not (compatible and user_quant == "moe_wna16"):
            return None
        return cls.get_name()

    @classmethod
    def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
        """Report whether a HF quantization config can be served by moe_wna16.

        GPTQ qualifies with 4 or 8 bits and ``desc_act`` falsy; AWQ qualifies
        with 4 bits on a device meeting AWQ's minimum capability.
        """
        method = quant_config.get("quant_method", "").lower()
        bits = quant_config.get("bits")
        desc_act = quant_config.get("desc_act")

        capability = current_platform.get_device_capability()
        current = -1 if capability is None else capability.to_int()

        # Imported here rather than at module level to avoid a circular import.
        from vllm.model_executor.layers.quantization.awq import AWQConfig

        required = AWQConfig.get_min_capability()

        if method == "gptq":
            return (not desc_act) and bits in [4, 8]
        if method == "awq":
            return bits == 4 and current >= required
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for *layer*.

        Skipped modules get unquantized methods; linear layers are delegated
        to GPTQ/AWQ (Marlin when enabled); FusedMoE layers use
        ``MoeWNA16Method``; anything else gets ``None``.
        """
        if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
            if isinstance(layer, FusedMoE):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return UnquantizedLinearMethod()

        if isinstance(layer, LinearBase):
            # Imported here rather than at module level to avoid a circular import.
            from vllm.model_executor.layers.quantization.awq import AWQConfig
            from vllm.model_executor.layers.quantization.awq_marlin import (
                AWQMarlinConfig,
            )
            from vllm.model_executor.layers.quantization.gptq import GPTQConfig
            from vllm.model_executor.layers.quantization.gptq_marlin import (
                GPTQMarlinConfig,
            )

            method = self.linear_quant_method
            if method == "gptq":
                delegate_cls = GPTQMarlinConfig if self.use_marlin else GPTQConfig
            elif method in ("awq", "awq_marlin"):
                # Marlin is used only if enabled AND it supports this layer.
                marlin_ok = self.use_marlin and check_marlin_supports_layer(
                    layer, self.group_size
                )
                delegate_cls = AWQMarlinConfig if marlin_ok else AWQConfig
            else:
                raise ValueError("moe_wna16 only support gptq and awq.")
            delegate = delegate_cls.from_config(self.full_config)
            return delegate.get_quant_method(layer, prefix)

        if isinstance(layer, FusedMoE):
            return MoeWNA16Method(self, layer.moe_config)

        return None
|
|
|
|
|
|
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]):
    """Return True if any entry of *modules_to_not_convert* occurs in *prefix*.

    Matching is plain substring containment (not exact or dotted-path
    matching), and the scan stops at the first hit.
    """
    for skipped in modules_to_not_convert:
        if skipped in prefix:
            return True
    return False
|
|
|
|
|
|
class MoeWNA16Method(FusedMoEMethodBase):
    """FusedMoE method for MoE WNA16 (W8A16/W4A16) quantization.

    Args:
        quant_config: The MoE WNA16 (W8A16/W4A16) quantization config.
        moe: The FusedMoE layer config, forwarded to the base class.
    """

    def __init__(self, quant_config: MoeWNA16Config, moe: "FusedMoEConfig") -> None:
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed weights, scales and (optionally) zero points on *layer*.

        Quantized weights and zero points are uint8 tensors holding
        ``bit8_pack_factor`` values per byte; scales use *params_dtype*.
        The group size may be halved repeatedly until it divides both
        ``hidden_size`` and ``intermediate_size_per_partition``. The
        resulting ``layer.group_size_div_factor`` tells the weight loader
        how many times to repeat each loaded scale/zero-point group.
        """
        layer.quant_config = self.quant_config
        bit8_pack_factor = self.quant_config.bit8_pack_factor
        group_size = self.quant_config.group_size
        group_size_div_factor = 1

        # make intermediate_size and hidden_size divisible by group_size
        # we reduce the group size to ensure that
        # and we would repeat the loaded_weight later
        while intermediate_size_per_partition % group_size or hidden_size % group_size:
            group_size = group_size // 2
            group_size_div_factor *= 2
            assert group_size >= 32
        layer.group_size = group_size
        layer.group_size_div_factor = group_size_div_factor

        strategy = FusedMoeWeightScaleSupported.GROUP.value
        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": False})

        # Wrap the caller's loader so checkpoint tensors are converted first.
        assert "weight_loader" in extra_weight_attrs
        weight_loader = extra_weight_attrs["weight_loader"]
        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(layer, weight_loader)
        extra_weight_attrs["weight_loader"] = wrapped_weight_loader

        # Fused gate_up_proj (column parallel)
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // bit8_pack_factor,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // bit8_pack_factor,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        w13_scales = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        if self.quant_config.has_zp:
            # Zero points are packed along the output-channel axis (dim 1).
            w13_qzeros = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    2 * intermediate_size_per_partition // bit8_pack_factor,
                    hidden_size // group_size,
                    dtype=torch.uint8,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_qzeros", w13_qzeros)
            set_weight_attrs(w13_qzeros, extra_weight_attrs)

            w2_qzeros = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    hidden_size // bit8_pack_factor,
                    intermediate_size_per_partition // group_size,
                    dtype=torch.uint8,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w2_qzeros", w2_qzeros)
            set_weight_attrs(w2_qzeros, extra_weight_attrs)

        if self.quant_config.linear_quant_method == "gptq":
            # some param are unused, but we need to init them in order to
            # load weights
            invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
            if not self.quant_config.has_zp:
                invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
            for key in invalid_param_keys:
                param = torch.nn.Parameter(
                    torch.empty((0,), dtype=torch.int32), requires_grad=False
                )
                layer.register_parameter(key, param)
                set_weight_attrs(param, extra_weight_attrs)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the scales/zeros on *layer*."""
        weight_bits = self.quant_config.weight_bits
        has_zp = self.quant_config.has_zp
        assert weight_bits == 4 or weight_bits == 8
        config_builder = (
            int4_w4a16_moe_quant_config
            if weight_bits == 4
            else int8_w8a16_moe_quant_config
        )

        return config_builder(
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
            w1_zp=layer.w13_qzeros if has_zp else None,
            w2_zp=layer.w2_qzeros if has_zp else None,
            block_shape=[0, layer.group_size],
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route *x* to its top-k experts and run the fused quantized experts.

        Only SiLU activation is supported (asserted).
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        assert layer.activation == "silu", "Only SiLU activation is supported."
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        return fused_experts(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            quant_config=self.moe_quant_config,
        )

    @staticmethod
    def get_weight_loader(layer, weight_loader):
        """Wrap *weight_loader* so GPTQ/AWQ checkpoint tensors are converted first.

        The returned loader converts qweight/qzeros/scales into this method's
        layout, repeats scales/qzeros when the group size was reduced, and
        writes qzeros itself. All other tensors go to *weight_loader*.
        """

        def convert_awq_tensor(tensor, tensor_type):
            # convert awq qweight/qzeros to a standard format (assume int4)
            # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
            # qzeros: (k // group_size, n // pack_factor_bit32) ->
            # (n // pack_factor_bit8, k // group_size)
            # pack_factor_bit32 = 32 // weight_bits
            # pack_factor_bit8 = 8 // weight_bits

            # 0. suppose origin shape (a, b), dtype int32
            # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
            size0 = tensor.size(0)
            tensor = tensor.view(torch.uint8)

            # 2. unpack to uint4 (only when weight_bits == 4)
            # shape (a, 4 * b) -> (a, 4 * b, 2)
            shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF

            # 3. change order, see
            # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
            # shape -> (a, 4 * b * pack_factor_bit8)
            reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
            tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
            tensor = tensor.view(size0, -1)

            # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
            tensor = tensor.T.contiguous()

            # 5. repack (only when weight_bits == 4)
            # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
            # qzeros shape -> (4 * b, a)

            if tensor_type == "qweight":
                tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
            elif tensor_type == "qzeros":
                tensor = tensor[1::2, :] * 16 + tensor[::2, :]
            return tensor

        def convert_gptq_int4_qzeros(tensor):
            # Unpack the two 4-bit zero points in each byte, add 1 to each
            # (see the "align with awq" note at the call site), and repack
            # them with the first value in the low nibble.
            tensor = tensor.view(torch.uint8)
            shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF
            tensor = tensor + 1
            tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
            return tensor

        def moe_wna16_weight_loader(
            param: torch.nn.Parameter,
            loaded_weight: torch.Tensor,
            weight_name: str,
            shard_id: str,
            expert_id: int,
            return_success: bool = False,
        ):
            # g_idx is never loaded, and qzeros are ignored for symmetric quant.
            if "g_idx" in weight_name:
                return False if return_success else None
            if not layer.quant_config.has_zp and "qzeros" in weight_name:
                return False if return_success else None

            device = get_tp_group().device
            tp_rank = get_tensor_model_parallel_rank()
            loaded_weight = loaded_weight.to(device)
            shard_size = layer.intermediate_size_per_partition

            # convert gptq and awq weight to a standard format
            # awq_marlin uses the same weight format as awq
            if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"):
                assert layer.quant_config.weight_bits == 4
                if "weight" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
                elif "zeros" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
                else:
                    loaded_weight = loaded_weight.T
            elif layer.quant_config.linear_quant_method == "gptq":
                assert layer.quant_config.weight_bits in [4, 8]
                if "weight" in weight_name:
                    loaded_weight = loaded_weight.T.contiguous().view(torch.uint8)
                elif "zeros" in weight_name:
                    # add 1 to gptq qzeros to align with awq
                    loaded_weight = loaded_weight.view(torch.uint8)
                    if layer.quant_config.weight_bits == 4:
                        loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T
                    else:
                        loaded_weight = loaded_weight.T + 1
                else:
                    loaded_weight = loaded_weight.T

            # Repeat qzeros/scales along the group axis to fit the group size
            # chosen in create_weights. Parenthesised explicitly: the bare
            # `A and B or C` form parsed as `(A and B) or C`, which ran
            # repeat_interleave(1) on scales even when no repeat was needed.
            if layer.group_size_div_factor > 1 and (
                "qzeros" in weight_name or "scales" in weight_name
            ):
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor, 1
                )

            if "w13_qzeros" in weight_name:
                # w13_qzeros packs bit8_pack_factor output channels per row,
                # so each of the w1 / w3 shards spans
                # shard_size // bit8_pack_factor rows. Hard-coding `// 2` was
                # only correct for int4 and broke 8-bit asymmetric GPTQ.
                packed_shard_size = shard_size // layer.quant_config.bit8_pack_factor
                tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
                    tp_rank
                ]
                if shard_id == "w1":
                    param.data[expert_id, :packed_shard_size] = tensor
                else:
                    param.data[expert_id, packed_shard_size:] = tensor
                return True if return_success else None
            elif "w2_qzeros" in weight_name:
                # Row-parallel: this rank's slice of the group axis (dim 1).
                param.data[expert_id] = loaded_weight.view(
                    loaded_weight.size(0), layer.tp_size, -1
                )[:, tp_rank]
                return True if return_success else None
            else:
                # Delegate to the original loader, passing return_success
                return weight_loader(
                    param,
                    loaded_weight,
                    weight_name,
                    shard_id,
                    expert_id,
                    return_success=return_success,
                )

        return moe_wna16_weight_loader
|