[Feature] AWQ marlin quantization support for fused moe with lora (#30442)
Signed-off-by: princepride <wangzhipeng628@gmail.com>
parent 8781cd6b88
commit 0e71eaa644
vllm/model_executor/layers/fused_moe/config.py

@@ -700,6 +700,42 @@ def int4_w4afp8_moe_quant_config(
     )
 
 
+def awq_marlin_moe_quant_config(
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    w1_zp: torch.Tensor | None,
+    w2_zp: torch.Tensor | None,
+    weight_bits: int,
+    group_size: int,
+    w1_bias: torch.Tensor | None = None,
+    w2_bias: torch.Tensor | None = None,
+) -> FusedMoEQuantConfig:
+    """
+    Construct a quant config for awq marlin quantization.
+    """
+    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+
+    w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)
+
+    # Activations are NOT quantized for AWQ (fp16/bf16)
+    a_shape = w_shape  # Same as weight shape for alignment
+
+    # Determine weight dtype
+    if weight_bits == 4:
+        weight_dtype = "int4"
+    elif weight_bits == 8:
+        weight_dtype = torch.int8
+    else:
+        raise ValueError(f"Unsupported weight_bits: {weight_bits}")
+
+    return FusedMoEQuantConfig(
+        _a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
+        _a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
+        _w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
+        _w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
+    )
+
+
 def biased_moe_quant_config(
     w1_bias: torch.Tensor | None,
     w2_bias: torch.Tensor | None,
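For orientation, a minimal usage sketch of the new helper for a 4-bit, group-size-128 AWQ MoE layer. The expert count E, intermediate size N, hidden size K, and the scale tensor shapes are illustrative assumptions, not values taken from this commit:

    import torch

    from vllm.model_executor.layers.fused_moe.config import awq_marlin_moe_quant_config

    # Hypothetical sizes: E experts, intermediate size N, hidden size K, group size 128.
    E, N, K, group = 8, 1024, 4096, 128
    w13_scales = torch.empty(E, K // group, 2 * N, dtype=torch.half)
    w2_scales = torch.empty(E, N // group, K, dtype=torch.half)

    cfg = awq_marlin_moe_quant_config(
        w1_scale=w13_scales,
        w2_scale=w2_scales,
        w1_zp=None,  # zero points omitted in this sketch
        w2_zp=None,
        weight_bits=4,
        group_size=group,
    )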
vllm/model_executor/layers/quantization/awq_marlin.py

@@ -470,6 +470,11 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
             }
         )
 
+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full", intermediate_size_per_partition
+        )
+        self.is_k_full = intermediate_size_per_partition == intermediate_size_full
+
         w13_qweight = Parameter(
             torch.empty(
                 num_experts,
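The is_k_full flag records whether this rank sees the full reduction (K) dimension of the down projection, which is sharded under tensor parallelism. A short arithmetic illustration with made-up numbers, not taken from the commit:

    # Hypothetical: intermediate size 14336 sharded across 2 tensor-parallel ranks.
    intermediate_size_full = 14336
    intermediate_size_per_partition = 14336 // 2  # 7168 per rank
    is_k_full = intermediate_size_per_partition == intermediate_size_full  # False
    # With a single rank the two values match and is_k_full is True.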
@@ -597,6 +602,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
 
+        # The modular kernel expects w13_weight and w2_weight,
+        # but AWQ uses w13_qweight and w2_qweight
+        # Alias for modular kernel
+        layer.w13_weight = layer.w13_qweight
+        # Alias for modular kernel
+        layer.w2_weight = layer.w2_qweight
+
         # Why does this take the intermediate size for size_k?
         marlin_w13_scales = marlin_moe_permute_scales(
             s=layer.w13_scales,
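The aliases added above are attribute references, not copies, so the modular kernel and AWQ's own path operate on the same repacked tensors. A toy, self-contained check of that behavior, with a SimpleNamespace standing in for the layer:

    import types

    import torch

    layer = types.SimpleNamespace()
    layer.w13_qweight = torch.zeros(4)
    layer.w13_weight = layer.w13_qweight  # same object, no extra memory
    assert layer.w13_weight is layer.w13_qweight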
@@ -661,7 +673,88 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return None
+        from vllm.model_executor.layers.fused_moe.config import (
+            awq_marlin_moe_quant_config,
+        )
+
+        return awq_marlin_moe_quant_config(
+            w1_scale=layer.w13_scales,
+            w2_scale=layer.w2_scales,
+            weight_bits=self.quant_config.weight_bits,
+            group_size=self.quant_config.group_size,
+            w1_zp=getattr(layer, "w13_qzeros", None)
+            if self.quant_config.zero_point
+            else None,
+            w2_zp=getattr(layer, "w2_qzeros", None)
+            if self.quant_config.zero_point
+            else None,
+            w1_bias=getattr(layer, "w13_bias", None),
+            w2_bias=getattr(layer, "w2_bias", None),
+        )
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize,
+        layer: torch.nn.Module,
+    ):
+        """
+        Select the GEMM implementation for AWQ-Marlin MoE.
+        Returns MarlinExperts configured for AWQ quantization.
+        This is ONLY used when LoRA is enabled.
+        Without LoRA, AWQ uses its own apply() method.
+        """
+        # Only use modular kernels when LoRA is enabled
+        # Without LoRA, AWQ's own apply() method works fine and is more efficient
+        if not self.moe.is_lora_enabled:
+            raise NotImplementedError(
+                "AWQ-Marlin uses its own apply() method when LoRA is not enabled. "
+                "Modular kernels are only used for LoRA support."
+            )
+
+        from vllm.model_executor.layers.fused_moe import modular_kernel as mk
+        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+            BatchedMarlinExperts,
+            MarlinExperts,
+        )
+
+        # Ensure quant config is initialized
+        assert self.moe_quant_config is not None, (
+            "moe_quant_config must be initialized before select_gemm_impl"
+        )
+
+        w13_g_idx = getattr(layer, "w13_g_idx", None)
+        w2_g_idx = getattr(layer, "w2_g_idx", None)
+        w13_g_idx_sort_indices = getattr(layer, "w13_g_idx_sort_indices", None)
+        w2_g_idx_sort_indices = getattr(layer, "w2_g_idx_sort_indices", None)
+
+        # Check if using batched expert format (for Expert Parallelism)
+        if (
+            prepare_finalize.activation_format
+            == mk.FusedMoEActivationFormat.BatchedExperts
+        ):
+            # For batched format, use BatchedMarlinExperts
+            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
+            assert max_num_tokens_per_rank is not None
+            return BatchedMarlinExperts(
+                max_num_tokens=max_num_tokens_per_rank,
+                num_dispatchers=prepare_finalize.num_dispatchers(),
+                quant_config=self.moe_quant_config,
+                w13_g_idx=w13_g_idx,
+                w2_g_idx=w2_g_idx,
+                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
+        else:
+            # Standard Marlin experts for AWQ
+            return MarlinExperts(
+                quant_config=self.moe_quant_config,
+                w13_g_idx=w13_g_idx,
+                w2_g_idx=w2_g_idx,
+                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
 
     def apply(
         self,
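For context, a hedged end-to-end sketch of the case this diff targets: serving an AWQ-quantized MoE checkpoint with LoRA enabled, which is when select_gemm_impl (rather than AWQ-Marlin's own apply()) is exercised. The model name and adapter path below are placeholders, not values from this commit:

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(
        model="some-org/awq-moe-model",  # placeholder AWQ MoE checkpoint
        quantization="awq_marlin",
        enable_lora=True,
    )
    outputs = llm.generate(
        ["Hello"],
        SamplingParams(max_tokens=16),
        lora_request=LoRARequest("my_adapter", 1, "/path/to/lora"),  # placeholder adapter
    )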