[Feature] AWQ marlin quantization support for fused moe with lora (#30442)

Signed-off-by: princepride <wangzhipeng628@gmail.com>
Authored by 汪志鹏 on 2025-12-12 02:03:32 +08:00; committed by GitHub
parent 8781cd6b88
commit 0e71eaa644
2 changed files with 130 additions and 1 deletion


@@ -700,6 +700,42 @@ def int4_w4afp8_moe_quant_config(
    )
def awq_marlin_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: torch.Tensor | None,
    w2_zp: torch.Tensor | None,
    weight_bits: int,
    group_size: int,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for AWQ-Marlin quantization.
    """
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)
    # Activations are NOT quantized for AWQ (they stay fp16/bf16); the
    # activation descriptors reuse the weight group shape for alignment.
    a_shape = w_shape
    # Determine the weight dtype marker. PyTorch has no native 4-bit integer
    # dtype, so 4-bit weights are tagged with the string "int4".
    if weight_bits == 4:
        weight_dtype = "int4"
    elif weight_bits == 8:
        weight_dtype = torch.int8
    else:
        raise ValueError(f"Unsupported weight_bits: {weight_bits}")
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
        _a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
        _w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
        _w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
    )
def biased_moe_quant_config(
    w1_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
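For illustration, here is a minimal sketch of calling the new helper directly. The tensor shapes and sizes below are hypothetical placeholders (the real scale layouts come from AWQ-Marlin weight loading); the sketch only shows how the arguments are packed into a FusedMoEQuantConfig.

import torch

from vllm.model_executor.layers.fused_moe.config import awq_marlin_moe_quant_config

# Hypothetical toy sizes, for illustration only.
num_experts, hidden_size, intermediate_size, group_size = 8, 1024, 4096, 128

w13_scales = torch.rand(
    num_experts, hidden_size // group_size, 2 * intermediate_size, dtype=torch.float16
)
w2_scales = torch.rand(
    num_experts, intermediate_size // group_size, hidden_size, dtype=torch.float16
)

quant_config = awq_marlin_moe_quant_config(
    w1_scale=w13_scales,
    w2_scale=w2_scales,
    w1_zp=None,  # zero_point disabled in this sketch
    w2_zp=None,
    weight_bits=4,  # 4-bit AWQ -> weight dtype "int4"
    group_size=group_size,
)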


@@ -470,6 +470,11 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
            }
        )
        intermediate_size_full = extra_weight_attrs.pop(
            "intermediate_size_full", intermediate_size_per_partition
        )
        self.is_k_full = intermediate_size_per_partition == intermediate_size_full
        w13_qweight = Parameter(
            torch.empty(
                num_experts,
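The is_k_full flag computed in the hunk above records whether this rank holds the full reduction (K) dimension of the expert weights; under tensor parallelism the intermediate size is sharded, so the Marlin kernels must be told that K is only partially present. A small numeric sketch with hypothetical sizes:

# Hypothetical sizes, for illustration only (not taken from the diff).
intermediate_size_full = 14336
tensor_parallel_size = 2
intermediate_size_per_partition = intermediate_size_full // tensor_parallel_size  # 7168

# Mirrors the check above: with TP > 1 each rank only sees a slice of K,
# so is_k_full is False.
is_k_full = intermediate_size_per_partition == intermediate_size_full  # False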
@@ -597,6 +602,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
        # The modular kernel expects w13_weight and w2_weight, but AWQ stores
        # the packed weights as w13_qweight and w2_qweight, so alias them.
        layer.w13_weight = layer.w13_qweight
        layer.w2_weight = layer.w2_qweight
        # Why does this take the intermediate size for size_k?
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
@@ -661,7 +673,88 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        from vllm.model_executor.layers.fused_moe.config import (
            awq_marlin_moe_quant_config,
        )

        return awq_marlin_moe_quant_config(
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
            weight_bits=self.quant_config.weight_bits,
            group_size=self.quant_config.group_size,
            w1_zp=getattr(layer, "w13_qzeros", None)
            if self.quant_config.zero_point
            else None,
            w2_zp=getattr(layer, "w2_qzeros", None)
            if self.quant_config.zero_point
            else None,
            w1_bias=getattr(layer, "w13_bias", None),
            w2_bias=getattr(layer, "w2_bias", None),
        )
    def select_gemm_impl(
        self,
        prepare_finalize,
        layer: torch.nn.Module,
    ):
        """
        Select the GEMM implementation for AWQ-Marlin MoE.

        Returns a MarlinExperts variant configured for AWQ quantization.
        This path is ONLY used when LoRA is enabled; without LoRA, AWQ-Marlin
        uses its own apply() method.
        """
        # Only use modular kernels when LoRA is enabled. Without LoRA,
        # AWQ-Marlin's own apply() method works fine and is more efficient.
        if not self.moe.is_lora_enabled:
            raise NotImplementedError(
                "AWQ-Marlin uses its own apply() method when LoRA is not enabled. "
                "Modular kernels are only used for LoRA support."
            )

        from vllm.model_executor.layers.fused_moe import modular_kernel as mk
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            BatchedMarlinExperts,
            MarlinExperts,
        )

        # Ensure the quant config has been initialized.
        assert self.moe_quant_config is not None, (
            "moe_quant_config must be initialized before select_gemm_impl"
        )

        w13_g_idx = getattr(layer, "w13_g_idx", None)
        w2_g_idx = getattr(layer, "w2_g_idx", None)
        w13_g_idx_sort_indices = getattr(layer, "w13_g_idx_sort_indices", None)
        w2_g_idx_sort_indices = getattr(layer, "w2_g_idx_sort_indices", None)

        # Check whether the batched expert format is used (expert parallelism).
        if (
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        ):
            # Batched format: use BatchedMarlinExperts.
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None
            return BatchedMarlinExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                w13_g_idx=w13_g_idx,
                w2_g_idx=w2_g_idx,
                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
                is_k_full=self.is_k_full,
            )
        else:
            # Standard Marlin experts for AWQ.
            return MarlinExperts(
                quant_config=self.moe_quant_config,
                w13_g_idx=w13_g_idx,
                w2_g_idx=w2_g_idx,
                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
                is_k_full=self.is_k_full,
            )
    def apply(
        self,
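Taken together, get_fused_moe_quant_config and select_gemm_impl are the two hooks the LoRA path consumes: the layer machinery asks the quant method for its quant config, then for a GEMM implementation, and composes the result with a prepare/finalize stage into a modular kernel. A rough, hypothetical sketch of that wiring (the helper name is made up, the real call site lives in the FusedMoE layer / modular-kernel setup, and composing via mk.FusedMoEModularKernel is an assumption):

from vllm.model_executor.layers.fused_moe import modular_kernel as mk


def build_awq_lora_moe_kernel(method, layer, prepare_finalize):
    # Hypothetical helper, for illustration only: mirrors how the two hooks
    # above are expected to be used when LoRA is enabled.
    method.moe_quant_config = method.get_fused_moe_quant_config(layer)
    experts = method.select_gemm_impl(prepare_finalize, layer)
    # Compose token dispatch/combine (prepare_finalize) with the Marlin
    # expert GEMMs selected above.
    return mk.FusedMoEModularKernel(prepare_finalize, experts)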