From dd2ccf8ddead5cdb46613b2adeb2edcc77f426c0 Mon Sep 17 00:00:00 2001 From: Jun-Howie <62869005+Jun-Howie@users.noreply.github.com> Date: Tue, 24 Jun 2025 06:23:28 +0800 Subject: [PATCH] Feat Dynamic Quantization for MoE Layers in GPTQ Marlin Backend (#19395) --- .../layers/quantization/gptq_marlin.py | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index f92ebdea986da..e9b8dc3266b4a 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from copy import deepcopy from typing import Any, Callable, Optional, Union import torch @@ -9,7 +10,8 @@ import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -19,7 +21,7 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.gptq_utils import ( - get_linear_quant_method) + get_dynamic_override, get_linear_quant_method, override_config) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supported, check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales, @@ -35,6 +37,29 @@ from vllm.scalar_type import scalar_types logger = init_logger(__name__) +def get_moe_quant_method( + config: QuantizationConfig, + layer: torch.nn.Module, + prefix: str, + moe_method_cls: type, +): + cloned_config = deepcopy(config) + + if isinstance(layer, FusedMoE): + # False = skip module, None = no override, else = Positive match + if get_dynamic_override( # noqa: E712 + cloned_config, # noqa: E712 + layer_name=prefix) == False: # noqa: E712 + return UnquantizedFusedMoEMethod(layer.moe_config) + + if prefix: + # Dynamic per module/layer rules may override base config + override_config(cloned_config, prefix=prefix) + + return moe_method_cls(cloned_config) + return None + + class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" @@ -163,7 +188,8 @@ class GPTQMarlinConfig(QuantizationConfig): "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) - return GPTQMarlinMoEMethod(self) + return get_moe_quant_method(self, layer, prefix, + GPTQMarlinMoEMethod) return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)