Feat: Dynamic Quantization for MoE Layers in GPTQ Marlin Backend (#19395)

Jun-Howie 2025-06-24 06:23:28 +08:00 committed by GitHub
parent a3bc76e4b5
commit dd2ccf8dde

vllm/model_executor/layers/quantization/gptq_marlin.py

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from copy import deepcopy
 from typing import Any, Callable, Optional, Union
 
 import torch
@@ -9,7 +10,8 @@ import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
+    UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -19,7 +21,7 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
-    get_linear_quant_method)
+    get_dynamic_override, get_linear_quant_method, override_config)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported, check_moe_marlin_supports_layer,
     marlin_make_workspace_new, marlin_moe_permute_scales,
@@ -35,6 +37,29 @@ from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
 
 
+def get_moe_quant_method(
+    config: QuantizationConfig,
+    layer: torch.nn.Module,
+    prefix: str,
+    moe_method_cls: type,
+):
+    cloned_config = deepcopy(config)
+    if isinstance(layer, FusedMoE):
+        # False = skip module, None = no override, else = Positive match
+        if get_dynamic_override(  # noqa: E712
+                cloned_config,  # noqa: E712
+                layer_name=prefix) == False:  # noqa: E712
+            return UnquantizedFusedMoEMethod(layer.moe_config)
+
+        if prefix:
+            # Dynamic per module/layer rules may override base config
+            override_config(cloned_config, prefix=prefix)
+
+        return moe_method_cls(cloned_config)
+    return None
+
+
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
@@ -163,7 +188,8 @@ class GPTQMarlinConfig(QuantizationConfig):
                     "Falling back to Moe WNA16 kernels.")
                 return MoeWNA16Config.from_config(
                     self.full_config).get_quant_method(layer, prefix)
-            return GPTQMarlinMoEMethod(self)
+            return get_moe_quant_method(self, layer, prefix,
+                                        GPTQMarlinMoEMethod)
         return get_linear_quant_method(self, layer, prefix,
                                        GPTQMarlinLinearMethod)
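
With this hunk, get_quant_method routes fused MoE layers through the same dynamic-rule machinery that get_linear_quant_method already applies to linear layers. As a hedged usage sketch, a checkpoint's quantize_config.json might carry a `dynamic` mapping like the one below; the module patterns and bit-width are hypothetical examples, not taken from this commit:

# Hypothetical per-module rules (GPTQModel `dynamic` convention).
dynamic = {
    # Negative rule: the experts in layer 0 stay unquantized, so
    # get_moe_quant_method returns UnquantizedFusedMoEMethod for them.
    "-:model\\.layers\\.0\\.mlp\\.experts.*": {},
    # Positive rule: every other expert module runs GPTQ at 8 bits.
    "+:model\\.layers\\..*\\.mlp\\.experts.*": {"bits": 8},
}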