mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 18:35:01 +08:00
[Feature] AWQ marlin quantization support for fused moe with lora (#30442)
Signed-off-by: princepride <wangzhipeng628@gmail.com>
This commit is contained in:
parent
8781cd6b88
commit
0e71eaa644
@ -700,6 +700,42 @@ def int4_w4afp8_moe_quant_config(
|
||||
)
|
||||
|
||||
|
||||
def awq_marlin_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w1_zp: torch.Tensor | None,
|
||||
w2_zp: torch.Tensor | None,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for awq marlin quantization.
|
||||
"""
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
|
||||
w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)
|
||||
|
||||
# Activations are NOT quantized for AWQ (fp16/bf16)
|
||||
a_shape = w_shape # Same as weight shape for alignment
|
||||
|
||||
# Determine weight dtype
|
||||
if weight_bits == 4:
|
||||
weight_dtype = "int4"
|
||||
elif weight_bits == 8:
|
||||
weight_dtype = torch.int8
|
||||
else:
|
||||
raise ValueError(f"Unsupported weight_bits: {weight_bits}")
|
||||
|
||||
return FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
|
||||
_a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
|
||||
_w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
|
||||
_w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
|
||||
)
|
||||
|
||||
|
||||
def biased_moe_quant_config(
|
||||
w1_bias: torch.Tensor | None,
|
||||
w2_bias: torch.Tensor | None,
|
||||
|
||||
@ -470,6 +470,11 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
}
|
||||
)
|
||||
|
||||
intermediate_size_full = extra_weight_attrs.pop(
|
||||
"intermediate_size_full", intermediate_size_per_partition
|
||||
)
|
||||
self.is_k_full = intermediate_size_per_partition == intermediate_size_full
|
||||
|
||||
w13_qweight = Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
@ -597,6 +602,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
|
||||
# The modular kernel expects w13_weight and w2_weight,
|
||||
# but AWQ uses w13_qweight and w2_qweight
|
||||
# Alias for modular kernel
|
||||
layer.w13_weight = layer.w13_qweight
|
||||
# Alias for modular kernel
|
||||
layer.w2_weight = layer.w2_qweight
|
||||
|
||||
# Why does this take the intermediate size for size_k?
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_scales,
|
||||
@ -661,7 +673,88 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return None
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
awq_marlin_moe_quant_config,
|
||||
)
|
||||
|
||||
return awq_marlin_moe_quant_config(
|
||||
w1_scale=layer.w13_scales,
|
||||
w2_scale=layer.w2_scales,
|
||||
weight_bits=self.quant_config.weight_bits,
|
||||
group_size=self.quant_config.group_size,
|
||||
w1_zp=getattr(layer, "w13_qzeros", None)
|
||||
if self.quant_config.zero_point
|
||||
else None,
|
||||
w2_zp=getattr(layer, "w2_qzeros", None)
|
||||
if self.quant_config.zero_point
|
||||
else None,
|
||||
w1_bias=getattr(layer, "w13_bias", None),
|
||||
w2_bias=getattr(layer, "w2_bias", None),
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize,
|
||||
layer: torch.nn.Module,
|
||||
):
|
||||
"""
|
||||
Select the GEMM implementation for AWQ-Marlin MoE.
|
||||
Returns MarlinExperts configured for AWQ quantization.
|
||||
This is ONLY used when LoRA is enabled.
|
||||
Without LoRA, AWQ uses its own apply() method.
|
||||
"""
|
||||
# Only use modular kernels when LoRA is enabled
|
||||
# Without LoRA, AWQ's own apply() method works fine and is more efficient
|
||||
if not self.moe.is_lora_enabled:
|
||||
raise NotImplementedError(
|
||||
"AWQ-Marlin uses its own apply() method when LoRA is not enabled. "
|
||||
"Modular kernels are only used for LoRA support."
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
BatchedMarlinExperts,
|
||||
MarlinExperts,
|
||||
)
|
||||
|
||||
# Ensure quant config is initialized
|
||||
assert self.moe_quant_config is not None, (
|
||||
"moe_quant_config must be initialized before select_gemm_impl"
|
||||
)
|
||||
|
||||
w13_g_idx = getattr(layer, "w13_g_idx", None)
|
||||
w2_g_idx = getattr(layer, "w2_g_idx", None)
|
||||
w13_g_idx_sort_indices = getattr(layer, "w13_g_idx_sort_indices", None)
|
||||
w2_g_idx_sort_indices = getattr(layer, "w2_g_idx_sort_indices", None)
|
||||
|
||||
# Check if using batched expert format (for Expert Parallelism)
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
== mk.FusedMoEActivationFormat.BatchedExperts
|
||||
):
|
||||
# For batched format, use BatchedMarlinExperts
|
||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
|
||||
assert max_num_tokens_per_rank is not None
|
||||
return BatchedMarlinExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=self.moe_quant_config,
|
||||
w13_g_idx=w13_g_idx,
|
||||
w2_g_idx=w2_g_idx,
|
||||
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
|
||||
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
|
||||
is_k_full=self.is_k_full,
|
||||
)
|
||||
else:
|
||||
# Standard Marlin experts for AWQ
|
||||
return MarlinExperts(
|
||||
quant_config=self.moe_quant_config,
|
||||
w13_g_idx=w13_g_idx,
|
||||
w2_g_idx=w2_g_idx,
|
||||
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
|
||||
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
|
||||
is_k_full=self.is_k_full,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user