Fix MoE issue in AutoRound format (#18586)

Signed-off-by: wenhuach21 <wenhua.cheng@intel.com>
Wenhua Cheng, 2025-05-24 13:01:40 +08:00, committed by GitHub
parent 45ab403a1f
commit ec82c3e388
2 changed files with 31 additions and 29 deletions

README.md

@@ -58,7 +58,7 @@ vLLM is fast with:
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
- Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph
-- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8.
+- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8.
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
- Speculative decoding
- Chunked prefill
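
With AutoRound added to the supported quantization list, an AutoRound-quantized checkpoint can be served directly. A minimal offline-inference sketch; the model id below is hypothetical, and any checkpoint whose `quantization_config` declares `auto-round` should be detected automatically:

```python
# Minimal sketch: offline inference with an AutoRound-quantized model.
# The repo id is hypothetical; substitute a real AutoRound checkpoint.
from vllm import LLM, SamplingParams

llm = LLM(model="Intel/Qwen2.5-7B-Instruct-int4-auto-round")
params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Explain AutoRound quantization in one sentence."],
                       params)
print(outputs[0].outputs[0].text)
```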

vllm/model_executor/layers/quantization/auto_round.py

@@ -8,6 +8,7 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -74,7 +75,7 @@ class AutoRoundConfig(QuantizationConfig):
f"group_size={self.group_size}, sym={self.sym})")
@classmethod
def get_name(cls): ## use str will trigger preci issue
def get_name(cls) -> QuantizationMethods:
return "auto-round"
@classmethod
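
Typing `get_name` as `QuantizationMethods` rather than plain `str` (which, per the removed comment, tripped pre-CI checks) lets a static checker verify the returned name against vLLM's registered methods. A simplified sketch of the pattern; the real alias in `vllm.model_executor.layers.quantization` enumerates every backend:

```python
# Simplified sketch of the Literal-typing pattern; the real
# QuantizationMethods alias lists all registered backends.
from typing import Literal

QuantizationMethods = Literal["awq", "gptq", "auto-round"]  # abridged

def get_name() -> QuantizationMethods:
    # A typo like "auto_round" now fails type checking instead of
    # surfacing at runtime as an unknown quantization method.
    return "auto-round"
```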
@@ -142,18 +143,18 @@ class AutoRoundConfig(QuantizationConfig):
prefix, layer.__class__.__name__, weight_bits, group_size,
sym)
        if backend == "auto" or "marlin" in backend:
-            AWQ_TYPE_MAP = {
-                4: scalar_types.uint4,
-                8: scalar_types.uint8,
-            }
-            use_marlin = (weight_bits
-                          in AWQ_TYPE_MAP) and check_marlin_supported(
-                              AWQ_TYPE_MAP[weight_bits], group_size, not sym)
-            if isinstance(layer, FusedMoE):
-                use_marlin = use_marlin and check_moe_marlin_supports_layer(
-                    layer, group_size)
+            if isinstance(layer, FusedMoE):
+                use_marlin = check_moe_marlin_supports_layer(layer, group_size)
+            else:
+                AWQ_TYPE_MAP = {
+                    4: scalar_types.uint4,
+                    8: scalar_types.uint8,
+                }
+                use_marlin = (weight_bits in AWQ_TYPE_MAP
+                              and check_marlin_supported(
+                                  AWQ_TYPE_MAP[weight_bits], group_size,
+                                  not sym))
        else:
            use_marlin = False
        if use_marlin:
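
The fix routes `FusedMoE` layers straight to the MoE-specific Marlin capability check instead of first requiring the dense-layer `check_marlin_supported` test. Rendered as a hypothetical standalone helper (the function name is illustrative, not part of the commit; the imports match the names used in this file):

```python
# Hypothetical standalone rendering of the corrected AWQ dispatch.
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supported, check_moe_marlin_supports_layer)
from vllm.scalar_type import scalar_types

def awq_use_marlin(layer, weight_bits: int, group_size: int,
                   sym: bool) -> bool:
    # MoE layers: only the fused-expert Marlin check applies.
    if isinstance(layer, FusedMoE):
        return check_moe_marlin_supports_layer(layer, group_size)
    # Dense layers: the AWQ map is keyed by bit width alone.
    AWQ_TYPE_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
    return (weight_bits in AWQ_TYPE_MAP
            and check_marlin_supported(AWQ_TYPE_MAP[weight_bits],
                                       group_size, not sym))
```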
@@ -180,10 +181,11 @@ class AutoRoundConfig(QuantizationConfig):
            from vllm.model_executor.layers.quantization.moe_wna16 import (
                MoeWNA16Config)
            config = {
-                "linear_quant_method": "awq",
-                "weight_bits": weight_bits,
+                "quant_method": "awq",
+                "bits": weight_bits,
                "group_size": group_size,
                "zero_point": not sym,
+                "lm_head": False,
            }
            return MoeWNA16Config.from_config(config).get_quant_method(
                layer, prefix)
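
The key renames track what `MoeWNA16Config.from_config` parses: the dict now mirrors the field names an AWQ checkpoint uses in its `quantization_config` (`quant_method`, `bits`), so the old internal names `linear_quant_method` and `weight_bits` no longer resolve. A sketch of the handoff with concrete example values; the same rename is applied on the GPTQ path below:

```python
# Sketch of the delegation with example values (4-bit, group size 128).
# The dict mimics an AWQ checkpoint's quantization_config so that
# MoeWNA16Config can parse it unchanged; get_quant_method(layer, prefix)
# is then called on the result, as in the diff above.
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config

config = {
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": True,   # zero point present when sym is False
    "lm_head": False,     # do not quantize the LM head
}
wna16 = MoeWNA16Config.from_config(config)
```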
@@ -213,18 +215,18 @@ class AutoRoundConfig(QuantizationConfig):
prefix, layer.__class__.__name__, weight_bits, group_size,
sym)
        if backend == "auto" or "marlin" in backend:
-            GPTQ_TYPE_MAP = {
-                (4, True): scalar_types.uint4b8,
-                (8, True): scalar_types.uint8b128,
-            }
-            use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
-                          and check_marlin_supported(
-                              GPTQ_TYPE_MAP[(weight_bits, sym)],
-                              group_size,
-                              has_zp=not sym))
-            if isinstance(layer, FusedMoE):
-                use_marlin = use_marlin and check_moe_marlin_supports_layer(
-                    layer, group_size)
+            if isinstance(layer, FusedMoE):
+                use_marlin = check_moe_marlin_supports_layer(layer, group_size)
+            else:
+                GPTQ_TYPE_MAP = {
+                    (4, True): scalar_types.uint4b8,
+                    (8, True): scalar_types.uint8b128,
+                }
+                use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
+                              and check_marlin_supported(
+                                  GPTQ_TYPE_MAP[(weight_bits, sym)],
+                                  group_size,
+                                  has_zp=not sym))
        else:
            use_marlin = False
        if use_marlin:
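
The GPTQ dispatch mirrors the AWQ one, but its type map is keyed by `(bits, sym)`: Marlin's `uint4b8`/`uint8b128` scalar types only cover symmetric GPTQ, so asymmetric checkpoints fall through to the non-Marlin path. A hypothetical standalone rendering (helper name illustrative):

```python
# Hypothetical standalone rendering of the corrected GPTQ dispatch.
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supported, check_moe_marlin_supports_layer)
from vllm.scalar_type import scalar_types

def gptq_use_marlin(layer, weight_bits: int, group_size: int,
                    sym: bool) -> bool:
    if isinstance(layer, FusedMoE):
        return check_moe_marlin_supports_layer(layer, group_size)
    # Keyed by (bits, sym): only symmetric 4/8-bit GPTQ has a Marlin type.
    GPTQ_TYPE_MAP = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }
    return ((weight_bits, sym) in GPTQ_TYPE_MAP
            and check_marlin_supported(GPTQ_TYPE_MAP[(weight_bits, sym)],
                                       group_size, has_zp=not sym))
```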
@@ -251,11 +253,11 @@ class AutoRoundConfig(QuantizationConfig):
            from vllm.model_executor.layers.quantization.moe_wna16 import (
                MoeWNA16Config)
            config = {
-                "linear_quant_method": "gptq",
-                "weight_bits": weight_bits,
+                "quant_method": "gptq",
+                "bits": weight_bits,
                "group_size": group_size,
                "sym": sym,
-                "lm_head_quantized": False,
+                "lm_head": False,
            }
            return MoeWNA16Config.from_config(config).get_quant_method(
                layer, prefix)