feat: Add support for GPTQ-quantized MoE models on ROCm with vllm serve (#21733)

Author: JartX · 2025-08-02 03:12:19 +02:00 · committed by GitHub
parent eefbf4a68b
commit 3654847db5
2 changed files with 21 additions and 5 deletions
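
With this change, a GPTQ-quantized MoE checkpoint can be loaded on ROCm through the usual entry points. A minimal usage sketch, not part of the commit itself: the model name is a placeholder, and the quantization method is normally inferred from the checkpoint config, so passing it explicitly is optional.

# Illustrative usage only (not part of this diff). The offline API is shown;
# the equivalent online path is `vllm serve <model> --quantization gptq`.
from vllm import LLM, SamplingParams

# Placeholder model name -- substitute any GPTQ-quantized MoE checkpoint.
llm = LLM(model="<your-gptq-moe-checkpoint>", quantization="gptq")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)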


@@ -761,8 +761,8 @@ def get_moe_wna16_block_config(config: dict[str,
 
 def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                               num_experts: int, bit: int):
-    return bit == 4 and group_size in [32, 64, 128] and \
-        num_valid_tokens / num_experts <= 6
+    return current_platform.is_cuda() and bit == 4 and \
+        group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 6
 
 
 def get_default_config(
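
The new current_platform.is_cuda() guard means the predicate is always false on ROCm, so the moe_wna16 path falls back to the Triton fused-MoE kernels instead of the CUDA-only kernel. A standalone restatement of the decision, with the platform check reduced to an explicit boolean argument (illustrative only, not vLLM's actual signature):

# Standalone sketch (not vLLM code): the platform check is modelled as a
# plain `is_cuda` flag so the snippet runs anywhere.
def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                              num_experts: int, bit: int,
                              is_cuda: bool) -> bool:
    # Non-CUDA platforms (e.g. ROCm) now always get False, so the Triton
    # fused-MoE path is selected instead of the CUDA kernel.
    return is_cuda and bit == 4 and group_size in [32, 64, 128] \
        and num_valid_tokens / num_experts <= 6

# Same small-batch workload, CUDA vs. ROCm:
print(should_moe_wna16_use_cuda(64, 128, 16, 4, is_cuda=True))   # True
print(should_moe_wna16_use_cuda(64, 128, 16, 4, is_cuda=False))  # False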


@@ -10,10 +10,11 @@ import torch
 from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
     get_linear_quant_method)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
@@ -110,8 +111,23 @@ class GPTQConfig(QuantizationConfig):
         return cls(weight_bits, group_size, desc_act, lm_head_quantized,
                    dynamic)
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["GPTQLinearMethod"]:
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["GPTQLinearMethod", "QuantizeMethodBase"]]:
+        if isinstance(layer, FusedMoE):
+            # GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
+            from .moe_wna16 import MoeWNA16Config
+            config = {
+                "quant_method": "gptq",
+                "bits": self.weight_bits,
+                "group_size": self.group_size,
+                "sym": True,  # GPTQ typically uses symmetric quantization
+                "lm_head": False,
+            }
+            return MoeWNA16Config.from_config(config).get_quant_method(
+                layer, prefix)
         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
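
Rather than adding a GPTQ-specific MoE kernel, the change delegates FusedMoE layers to the existing MoeWNA16 method, whose Triton fallback runs on ROCm (as the first hunk ensures), while linear layers stay on the native GPTQ method. The dict built above just translates the checkpoint's GPTQ parameters into the form MoeWNA16Config.from_config expects; a worked instance for a typical 4-bit, group-size-128 GPTQ model (values illustrative):

# Worked example (values illustrative, not vLLM code): GPTQ parameters of a
# 4-bit, group-size-128 checkpoint mapped onto the MoeWNA16-compatible dict
# that get_quant_method builds above.
weight_bits = 4   # GPTQ weight precision
group_size = 128  # number of weights sharing one quantization scale

moe_wna16_config = {
    "quant_method": "gptq",  # tells MoeWNA16 to read GPTQ-format tensors
    "bits": weight_bits,
    "group_size": group_size,
    "sym": True,             # GPTQ checkpoints are typically symmetric
    "lm_head": False,        # lm_head quantization is handled by the GPTQ linear path
}
print(moe_wna16_config)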