gptq marlin quantization support for fused moe with lora (#30254)

Signed-off-by: Bhanu068 <voutharoja.bhanu06@gmail.com>
Bhanu Prakash Voutharoja authored 2025-12-12 13:27:22 +11:00, committed by GitHub
parent f355ad5412
commit 6a6fc41c79
3 changed files with 146 additions and 2 deletions

View File

@@ -860,4 +860,4 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
}
}

View File

@@ -543,6 +543,42 @@ def int8_w8a8_moe_quant_config(
)
def gptq_marlin_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
weight_bits: int,
group_size: int,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
):
"""
Construct a quant config for gptq marlin quantization.
"""
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)
# Activations are NOT quantized for GPTQ (fp16/bf16)
a_shape = w_shape # Same as weight shape for alignment
# Determine weight dtype. Torch has no 4-bit dtype, so 4-bit weights are
# tracked with the string "int4"; 8-bit weights use torch.int8 directly.
if weight_bits == 4:
weight_dtype = "int4"
elif weight_bits == 8:
weight_dtype = torch.int8
else:
raise ValueError(f"Unsupported weight_bits: {weight_bits}")
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
_a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
_w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
_w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
)
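
For context, a minimal usage sketch of the new helper for a 4-bit, symmetric GPTQ checkpoint with group size 128. The tensor shapes are illustrative placeholders, not the exact layouts produced by Marlin repacking:

import torch

from vllm.model_executor.layers.fused_moe.config import gptq_marlin_moe_quant_config

# Placeholder GPTQ-style scales: [num_experts, in_features // group_size, out_features].
num_experts, hidden, intermediate, group = 8, 4096, 14336, 128
w13_scales = torch.empty(num_experts, hidden // group, 2 * intermediate, dtype=torch.half)
w2_scales = torch.empty(num_experts, intermediate // group, hidden, dtype=torch.half)

cfg = gptq_marlin_moe_quant_config(
    w1_scale=w13_scales,
    w2_scale=w2_scales,
    weight_bits=4,
    group_size=128,
    # symmetric checkpoint: no zero points, no per-expert bias
)
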
def mxfp4_w4a16_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],

View File

@@ -732,6 +732,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# The modular kernel expects w13_weight and w2_weight, but GPTQ registers its
# packed weights as w13_qweight and w2_qweight, so alias them here.
layer.w13_weight = layer.w13_qweight
layer.w2_weight = layer.w2_qweight
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
@@ -782,7 +790,107 @@
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return None
from vllm.model_executor.layers.fused_moe.config import (
gptq_marlin_moe_quant_config,
)
return gptq_marlin_moe_quant_config(
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
weight_bits=self.quant_config.weight_bits,
group_size=self.quant_config.group_size,
w1_zp=getattr(layer, "w13_qzeros", None)
if not self.quant_config.is_sym
else None,
w2_zp=getattr(layer, "w2_qzeros", None)
if not self.quant_config.is_sym
else None,
w1_bias=getattr(layer, "w13_bias", None),
w2_bias=getattr(layer, "w2_bias", None),
)
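
A brief call sketch for reference; `gptq_method` and `moe_layer` are hypothetical names, not objects defined in this diff. After weight processing, the method now returns a populated config instead of None, and symmetric checkpoints carry no zero points:

# Hypothetical: `gptq_method` is a GPTQMarlinMoEMethod, `moe_layer` a FusedMoE
# layer whose weights have already been repacked for Marlin.
cfg = gptq_method.get_fused_moe_quant_config(moe_layer)
assert cfg is not None
# With is_sym=True the getattr() lookups above are skipped, so the config's
# zero-point slots stay None.
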
def select_gemm_impl(
self,
prepare_finalize,
layer: torch.nn.Module,
):
"""
Select the GEMM implementation for GPTQ-Marlin MoE.
Returns MarlinExperts configured for GPTQ quantization.
This is ONLY used when LoRA is enabled.
Without LoRA, GPTQ uses its own apply() method.
"""
# Only use modular kernels when LoRA is enabled
# Without LoRA, GPTQ's own apply() method works fine and is more efficient
if not self.moe.is_lora_enabled:
raise NotImplementedError(
"GPTQ-Marlin uses its own apply() method when LoRA is not enabled. "
"Modular kernels are only used for LoRA support."
)
# The modular marlin kernels do not support 8-bit weights.
if self.quant_config.weight_bits == 8:
raise NotImplementedError(
"GPTQ-Marlin kernel does not support 8-bit weights."
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
)
# Ensure quant config is initialized
assert self.moe_quant_config is not None, (
"moe_quant_config must be initialized before select_gemm_impl"
)
w13_g_idx = (
getattr(layer, "w13_g_idx", None) if self.quant_config.desc_act else None
)
w2_g_idx = (
getattr(layer, "w2_g_idx", None) if self.quant_config.desc_act else None
)
w13_g_idx_sort_indices = (
getattr(layer, "w13_g_idx_sort_indices", None)
if self.quant_config.desc_act
else None
)
w2_g_idx_sort_indices = (
getattr(layer, "w2_g_idx_sort_indices", None)
if self.quant_config.desc_act
else None
)
# Check if using batched expert format (for Expert Parallelism)
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
):
# For batched format, use BatchedMarlinExperts
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
else:
# Standard Marlin experts for GPTQ
return MarlinExperts(
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
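
As a rough sketch of how the returned experts object is expected to be consumed downstream; the exact wiring is an assumption based on the modular_kernel module imported above, not something added by this change:

# Hypothetical wiring: `method` is this GPTQMarlinMoEMethod, `layer` a FusedMoE
# layer with LoRA enabled, `prepare_finalize` a prepare/finalize implementation.
from vllm.model_executor.layers.fused_moe import modular_kernel as mk

experts = method.select_gemm_impl(prepare_finalize, layer)
kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)
# The kernel then runs prepare -> (Batched)MarlinExperts GEMMs -> finalize.
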
def apply(
self,