[ROCm][Quantization] extend AMD Quark to support mixed-precision quantized model (#24239)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Co-authored-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
xuebwang-amd 2025-11-12 01:05:22 +08:00 committed by GitHub
parent 68c09efc37
commit 05576df85c
3 changed files with 127 additions and 8 deletions

View File

@@ -281,4 +281,36 @@ python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
--group_size 32
```
The current integration supports [all combinations of FP4, FP6_E3M2, and FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations.

## Using Quark Quantized Layerwise Auto Mixed Precision (AMP) Models

vLLM also supports loading layerwise mixed-precision models quantized using AMD Quark. Currently, the mixed scheme {MXFP4, FP8} is supported, where FP8 denotes the FP8 per-tensor scheme. More mixed-precision schemes are planned for the near future, including:

- Unquantized Linear and/or MoE layer(s) as an option for each layer, i.e., a mix of {MXFP4, FP8, BF16/FP16}
- MXFP6 quantization extension, i.e., {MXFP4, MXFP6, FP8, BF16/FP16}

Although one can maximize serving throughput by using the lowest precision supported on a given device (e.g., MXFP4 on AMD Instinct MI355, FP8 on AMD Instinct MI300), such aggressive schemes can make it harder to recover accuracy on target tasks after quantization. Mixed precision strikes a balance between maximizing accuracy and throughput.
There are two steps to generate and deploy a mixed-precision model quantized with AMD Quark, as shown below.

### 1. Quantize a model using mixed precision in AMD Quark

First, a layerwise mixed-precision configuration is searched for the given LLM, and the model is then quantized accordingly using AMD Quark. A detailed tutorial covering the Quark APIs will be provided later.

As examples, we provide some ready-to-use quantized mixed-precision models that show the usage in vLLM and the accuracy benefits:

- amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
- amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
- amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8

### 2. Run inference on the quantized mixed-precision model in vLLM

Models quantized with AMD Quark using mixed precision can be loaded natively in vLLM and, for example, evaluated with lm-evaluation-harness as follows:

```bash
lm_eval --model vllm \
--model_args pretrained=amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8,tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False \
--tasks mmlu \
--batch_size auto
```
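
Beyond evaluation with lm-evaluation-harness, the checkpoints above can also be loaded directly for offline generation. The snippet below is a minimal sketch using vLLM's offline `LLM` API; the prompt, sampling parameters, and `tensor_parallel_size` value are illustrative rather than part of this change.

```python
from vllm import LLM, SamplingParams

# The quantization scheme is read from the checkpoint's embedded Quark config,
# so no explicit quantization argument is required.
llm = LLM(
    model="amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.8,
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["What is mixed-precision quantization?"], sampling_params)
print(outputs[0].outputs[0].text)
```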

View File

@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test quark-quantized {MXFP4, FP8} mixed precision models.
Run `pytest tests/quantization/test_mixed_precision.py`.
"""
import importlib.metadata
import importlib.util
from dataclasses import dataclass

import lm_eval
import pytest
from packaging import version

QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
    importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")


@dataclass
class ModelCase:
    model_id: str
    tp: int


@dataclass
class EvaluationConfig:
    model_name: str

    def get_model_args(self) -> str:
        return (
            f"pretrained={self.model_name},"
            "tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False"
        )


TEST_CONFIGS = {
    # Mixed-precision (AMP) model
    # - Demonstrates end-to-end pipeline functionality
    "amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8": {"arc_challenge": 0.52, "mmlu": 0.72},
    # Non-mixed-precision (PTQ) model
    # - Reference for pipeline compatibility verification -> no conflicts or breakages
    "amd/Llama-2-70b-chat-hf-FP8-MLPerf-fp8_attn_quark_format": {
        "arc_challenge": 0.53,
        "mmlu": 0.61,
    },
}

@pytest.mark.parametrize("model_name, accuracy_numbers", TEST_CONFIGS.items())
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
def test_mixed_precision_model_accuracies(model_name: str, accuracy_numbers: dict):
results = lm_eval.simple_evaluate(
model="vllm",
model_args=EvaluationConfig(model_name).get_model_args(),
tasks=list(accuracy_numbers.keys()),
batch_size=8,
)
    # Accept the measured accuracy within an absolute tolerance band
    # around the reference value.
    rtol = 0.05
    for task, expect_accuracy in accuracy_numbers.items():
        measured_accuracy = results["results"][task]["acc,none"]
        assert (
            measured_accuracy - rtol < expect_accuracy
            and measured_accuracy + rtol > expect_accuracy
        ), f"Expected: {expect_accuracy} | Measured: {measured_accuracy}"

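The `quark.py` changes in the next file resolve layer names against the name patterns used as keys in `layer_quant_config` and `kv_cache_group`. Below is a small, self-contained sketch of that lookup idea with hypothetical pattern keys and scheme labels (not taken from a real checkpoint); note that the actual change uses a containment check (`layer_name in pattern`) for patterns without a wildcard, while this sketch uses a plain equality comparison.

```python
import fnmatch

# Hypothetical per-layer quantization config keyed by name patterns.
layer_quant_config = {
    "*.self_attn.*_proj": {"scheme": "mxfp4"},  # glob-style pattern
    "model.layers.0.mlp.gate_proj": {"scheme": "fp8"},  # exact layer name
}


def find_layer_config(layer_name: str):
    """Return the first config whose pattern matches the layer name."""
    for pattern, config in layer_quant_config.items():
        if "*" not in pattern:
            # Plain names are compared directly here.
            if layer_name == pattern:
                return config
        elif fnmatch.fnmatch(layer_name, pattern):
            return config
    return None


print(find_layer_config("model.layers.3.self_attn.q_proj"))  # {'scheme': 'mxfp4'}
print(find_layer_config("model.layers.0.mlp.gate_proj"))  # {'scheme': 'fp8'}
```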
View File

@@ -114,7 +114,14 @@ class QuarkConfig(QuantizationConfig):
layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names)
if not (
kv_cache_set.issubset(layer_quant_set)
or any(
fnmatch.fnmatchcase(layer_quant, pat)
for layer_quant in list(layer_quant_set)
for pat in list(kv_cache_set)
)
):
raise ValueError(
"The Quark quantized model has the "
"kv_cache_group parameter setting, "
@@ -124,10 +131,15 @@ class QuarkConfig(QuantizationConfig):
)
q_configs = [
quant_cfg
for name, quant_cfg in layer_quant_config.items()
if any(fnmatch.fnmatchcase(name, pattern) for pattern in kv_cache_group)
]
if not all(
deep_compare(q_config["output_tensors"], q_configs[0]["output_tensors"])
for q_config in q_configs
):
raise ValueError(
"The quantization method used for kv_cache should "
"be the same, but the quantization method for the "
@@ -312,9 +324,15 @@ class QuarkConfig(QuantizationConfig):
layer_quant_config = cast(
dict[str, Any], self.quant_config.get("layer_quant_config")
)
def _matches_pattern(layer_name, pattern):
if "*" not in pattern:
return layer_name in pattern
return fnmatch.fnmatch(layer_name, pattern)
for name_pattern, config in layer_quant_config.items():
if _matches_pattern(layer_name, name_pattern):
return config
layer_type = cast(str, type(module))
layer_type_quant_config = cast(