diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu
index 27b6ffaa67176..4fd8fc5c54202 100644
--- a/csrc/moe/marlin_moe_wna16/ops.cu
+++ b/csrc/moe/marlin_moe_wna16/ops.cu
@@ -860,4 +860,4 @@ torch::Tensor moe_wna16_marlin_gemm(
 
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
-}
+}
\ No newline at end of file
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 5eb6bc4829adf..a9a2990ca2b53 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -543,6 +543,42 @@ def int8_w8a8_moe_quant_config(
     )
 
 
+def gptq_marlin_moe_quant_config(
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    weight_bits: int,
+    group_size: int,
+    w1_zp: torch.Tensor | None = None,
+    w2_zp: torch.Tensor | None = None,
+    w1_bias: torch.Tensor | None = None,
+    w2_bias: torch.Tensor | None = None,
+) -> FusedMoEQuantConfig:
+    """
+    Construct a quant config for GPTQ-Marlin quantization.
+    """
+    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+
+    w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)
+
+    # Activations are NOT quantized for GPTQ (they stay fp16/bf16)
+    a_shape = w_shape  # keep the group shape aligned with the weights
+
+    # Determine weight dtype
+    if weight_bits == 4:
+        weight_dtype = "int4"
+    elif weight_bits == 8:
+        weight_dtype = torch.int8
+    else:
+        raise ValueError(f"Unsupported weight_bits: {weight_bits}")
+
+    return FusedMoEQuantConfig(
+        _a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
+        _a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
+        _w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
+        _w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
+    )
+
+
 def mxfp4_w4a16_moe_quant_config(
     w1_scale: Union[torch.Tensor, "PrecisionConfig"],
     w2_scale: Union[torch.Tensor, "PrecisionConfig"],
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 8d1715f52f097..6e5dcfe59b2f9 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -732,6 +732,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             is_a_8bit=is_a_8bit,
         )
         replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+
+        # The modular kernel expects parameters named w13_weight and
+        # w2_weight, but GPTQ registers the repacked weights as
+        # w13_qweight and w2_qweight, so alias them here under the
+        # names the modular kernel expects.
+        layer.w13_weight = layer.w13_qweight
+        layer.w2_weight = layer.w2_qweight
+
         # Repack scales
         marlin_w13_scales = marlin_moe_permute_scales(
             s=layer.w13_scales,
@@ -782,7 +790,107 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return None
+        from vllm.model_executor.layers.fused_moe.config import (
+            gptq_marlin_moe_quant_config,
+        )
+
+        return gptq_marlin_moe_quant_config(
+            w1_scale=layer.w13_scales,
+            w2_scale=layer.w2_scales,
+            weight_bits=self.quant_config.weight_bits,
+            group_size=self.quant_config.group_size,
+            w1_zp=getattr(layer, "w13_qzeros", None)
+            if not self.quant_config.is_sym
+            else None,
+            w2_zp=getattr(layer, "w2_qzeros", None)
+            if not self.quant_config.is_sym
+            else None,
+            w1_bias=getattr(layer, "w13_bias", None),
+            w2_bias=getattr(layer, "w2_bias", None),
+        )
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize,
+        layer: torch.nn.Module,
+    ):
+        """
+        Select the GEMM implementation for GPTQ-Marlin MoE.
+
+        Returns MarlinExperts (or BatchedMarlinExperts for the batched
+        expert format), configured for GPTQ quantization. This is ONLY
+        used when LoRA is enabled; without it, GPTQ uses its own apply().
+        """
+        # Only use modular kernels when LoRA is enabled
+        # Without LoRA, GPTQ's own apply() method works fine and is more efficient
+        if not self.moe.is_lora_enabled:
+            raise NotImplementedError(
+                "GPTQ-Marlin uses its own apply() method when LoRA is not enabled. "
+                "Modular kernels are only used for LoRA support."
+            )
+
+        # The modular marlin kernels do not support 8-bit weights.
+        if self.quant_config.weight_bits == 8:
+            raise NotImplementedError(
+                "GPTQ-Marlin kernel does not support 8-bit weights."
+            )
+
+        from vllm.model_executor.layers.fused_moe import modular_kernel as mk
+        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+            BatchedMarlinExperts,
+            MarlinExperts,
+        )
+
+        # Ensure quant config is initialized
+        assert self.moe_quant_config is not None, (
+            "moe_quant_config must be initialized before select_gemm_impl"
+        )
+
+        w13_g_idx = (
+            getattr(layer, "w13_g_idx", None) if self.quant_config.desc_act else None
+        )
+        w2_g_idx = (
+            getattr(layer, "w2_g_idx", None) if self.quant_config.desc_act else None
+        )
+        w13_g_idx_sort_indices = (
+            getattr(layer, "w13_g_idx_sort_indices", None)
+            if self.quant_config.desc_act
+            else None
+        )
+        w2_g_idx_sort_indices = (
+            getattr(layer, "w2_g_idx_sort_indices", None)
+            if self.quant_config.desc_act
+            else None
+        )
+
+        # Check if using batched expert format (for Expert Parallelism)
+        if (
+            prepare_finalize.activation_format
+            == mk.FusedMoEActivationFormat.BatchedExperts
+        ):
+            # For batched format, use BatchedMarlinExperts
+            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
+            assert max_num_tokens_per_rank is not None
+            return BatchedMarlinExperts(
+                max_num_tokens=max_num_tokens_per_rank,
+                num_dispatchers=prepare_finalize.num_dispatchers(),
+                quant_config=self.moe_quant_config,
+                w13_g_idx=w13_g_idx,
+                w2_g_idx=w2_g_idx,
+                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
+        else:
+            # Standard Marlin experts for GPTQ
+            return MarlinExperts(
+                quant_config=self.moe_quant_config,
+                w13_g_idx=w13_g_idx,
+                w2_g_idx=w2_g_idx,
+                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
 
     def apply(
         self,