diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index f35cafa0f77dc..5eb6bc4829adf 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -700,6 +700,42 @@ def int4_w4afp8_moe_quant_config(
     )
 
 
+def awq_marlin_moe_quant_config(
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    w1_zp: torch.Tensor | None,
+    w2_zp: torch.Tensor | None,
+    weight_bits: int,
+    group_size: int,
+    w1_bias: torch.Tensor | None = None,
+    w2_bias: torch.Tensor | None = None,
+) -> FusedMoEQuantConfig:
+    """
+    Construct a quant config for AWQ-Marlin MoE quantization.
+    """
+    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+
+    w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)
+
+    # Activations are NOT quantized for AWQ (fp16/bf16)
+    a_shape = w_shape  # Same as weight shape for alignment
+
+    # Determine weight dtype
+    if weight_bits == 4:
+        weight_dtype = "int4"
+    elif weight_bits == 8:
+        weight_dtype = torch.int8
+    else:
+        raise ValueError(f"Unsupported weight_bits: {weight_bits}")
+
+    return FusedMoEQuantConfig(
+        _a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
+        _a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
+        _w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
+        _w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
+    )
+
+
 def biased_moe_quant_config(
     w1_bias: torch.Tensor | None,
     w2_bias: torch.Tensor | None,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 16aa4f1e22698..3ed15ed7dd422 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -470,6 +470,11 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
             }
         )
 
+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full", intermediate_size_per_partition
+        )
+        self.is_k_full = intermediate_size_per_partition == intermediate_size_full
+
         w13_qweight = Parameter(
             torch.empty(
                 num_experts,
@@ -597,6 +602,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
 
+        # The modular kernel expects w13_weight and w2_weight,
+        # but AWQ registers them as w13_qweight and w2_qweight.
+        # Alias w13 for the modular kernel.
+        layer.w13_weight = layer.w13_qweight
+        # Alias w2 for the modular kernel.
+        layer.w2_weight = layer.w2_qweight
+
         # Why does this take the intermediate size for size_k?
         marlin_w13_scales = marlin_moe_permute_scales(
             s=layer.w13_scales,
@@ -661,7 +673,88 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return None
+        from vllm.model_executor.layers.fused_moe.config import (
+            awq_marlin_moe_quant_config,
+        )
+
+        return awq_marlin_moe_quant_config(
+            w1_scale=layer.w13_scales,
+            w2_scale=layer.w2_scales,
+            weight_bits=self.quant_config.weight_bits,
+            group_size=self.quant_config.group_size,
+            w1_zp=getattr(layer, "w13_qzeros", None)
+            if self.quant_config.zero_point
+            else None,
+            w2_zp=getattr(layer, "w2_qzeros", None)
+            if self.quant_config.zero_point
+            else None,
+            w1_bias=getattr(layer, "w13_bias", None),
+            w2_bias=getattr(layer, "w2_bias", None),
+        )
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize,
+        layer: torch.nn.Module,
+    ):
+        """
+        Select the GEMM implementation for AWQ-Marlin MoE.
+        Returns MarlinExperts or BatchedMarlinExperts configured for AWQ.
+        This is ONLY used when LoRA is enabled.
+        Without LoRA, AWQ uses its own apply() method.
+        """
+        # Only use modular kernels when LoRA is enabled
+        # Without LoRA, AWQ's own apply() method works fine and is more efficient
+        if not self.moe.is_lora_enabled:
+            raise NotImplementedError(
+                "AWQ-Marlin uses its own apply() method when LoRA is not enabled. "
+                "Modular kernels are only used for LoRA support."
+            )
+
+        from vllm.model_executor.layers.fused_moe import modular_kernel as mk
+        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+            BatchedMarlinExperts,
+            MarlinExperts,
+        )
+
+        # Ensure quant config is initialized
+        assert self.moe_quant_config is not None, (
+            "moe_quant_config must be initialized before select_gemm_impl"
+        )
+
+        w13_g_idx = getattr(layer, "w13_g_idx", None)
+        w2_g_idx = getattr(layer, "w2_g_idx", None)
+        w13_g_idx_sort_indices = getattr(layer, "w13_g_idx_sort_indices", None)
+        w2_g_idx_sort_indices = getattr(layer, "w2_g_idx_sort_indices", None)
+
+        # Check if using batched expert format (for Expert Parallelism)
+        if (
+            prepare_finalize.activation_format
+            == mk.FusedMoEActivationFormat.BatchedExperts
+        ):
+            # For batched format, use BatchedMarlinExperts
+            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
+            assert max_num_tokens_per_rank is not None
+            return BatchedMarlinExperts(
+                max_num_tokens=max_num_tokens_per_rank,
+                num_dispatchers=prepare_finalize.num_dispatchers(),
+                quant_config=self.moe_quant_config,
+                w13_g_idx=w13_g_idx,
+                w2_g_idx=w2_g_idx,
+                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
+        else:
+            # Standard Marlin experts for AWQ
+            return MarlinExperts(
+                quant_config=self.moe_quant_config,
+                w13_g_idx=w13_g_idx,
+                w2_g_idx=w2_g_idx,
+                w13_g_idx_sort_indices=w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
 
     def apply(
         self,