mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:04:58 +08:00
Feat Dynamic Quantization for MoE Layers in GPTQ Marlin Backend (#19395)
This commit is contained in:
parent
a3bc76e4b5
commit
dd2ccf8dde
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -9,7 +10,8 @@ import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||
UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -19,7 +21,7 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
get_dynamic_override, get_linear_quant_method, override_config)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, check_moe_marlin_supports_layer,
|
||||
marlin_make_workspace_new, marlin_moe_permute_scales,
|
||||
@ -35,6 +37,29 @@ from vllm.scalar_type import scalar_types
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_moe_quant_method(
|
||||
config: QuantizationConfig,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
moe_method_cls: type,
|
||||
):
|
||||
cloned_config = deepcopy(config)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
# False = skip module, None = no override, else = Positive match
|
||||
if get_dynamic_override( # noqa: E712
|
||||
cloned_config, # noqa: E712
|
||||
layer_name=prefix) == False: # noqa: E712
|
||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||
|
||||
if prefix:
|
||||
# Dynamic per module/layer rules may override base config
|
||||
override_config(cloned_config, prefix=prefix)
|
||||
|
||||
return moe_method_cls(cloned_config)
|
||||
return None
|
||||
|
||||
|
||||
class GPTQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ Marlin"""
|
||||
|
||||
@ -163,7 +188,8 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
"Falling back to Moe WNA16 kernels.")
|
||||
return MoeWNA16Config.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
return get_moe_quant_method(self, layer, prefix,
|
||||
GPTQMarlinMoEMethod)
|
||||
return get_linear_quant_method(self, layer, prefix,
|
||||
GPTQMarlinLinearMethod)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user