[Quantization] fix attention quantization of gpt_oss model (#27334)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
xuebwang-amd 2025-11-12 01:06:00 +08:00 committed by GitHub
parent 05576df85c
commit 5a1271d83a
3 changed files with 101 additions and 4 deletions

tests/models/quantization/test_gpt_oss_attn_quantization.py

@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test attention quantization of gpt-oss model.
The qkv_proj and o_proj in self_attention can be either quantized or excluded.
Run `pytest tests/models/quantization/test_gpt_oss_attn_quantization.py`.
"""
import importlib.metadata
import importlib.util
from dataclasses import dataclass

import huggingface_hub
import lm_eval
import pytest
from packaging import version

MODEL_NAMES = ["amd/gpt-oss-20b-customized-attention-quantization"]

QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
    importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")

def has_huggingface_access(repo):
    try:
        huggingface_hub.list_repo_refs(repo)
        return True
    except huggingface_hub.errors.RepositoryNotFoundError:
        return False


HF_HUB_AMD_ORG_ACCESS = all(
    has_huggingface_access(model_name) for model_name in MODEL_NAMES
)

@dataclass
class ModelCase:
    model_id: str
    tp: int


@dataclass
class EvaluationConfig:
    model_name: str

    def get_model_args(self) -> str:
        return (
            f"pretrained={self.model_name},"
            "tensor_parallel_size=4,dtype=auto,"
            "gpu_memory_utilization=0.9,trust_remote_code=False"
        )


EXPECTED_ACCURACIES = {"arc_challenge": 0.20}

@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available"
)
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
@pytest.mark.parametrize("model_name", MODEL_NAMES)
@pytest.mark.parametrize("task_name, expected_accuracy", EXPECTED_ACCURACIES.items())
def test_gpt_oss_attention_quantization(
    model_name: str, task_name: str, expected_accuracy: float
):
    measured_accuracy = lm_eval.simple_evaluate(
        model="vllm",
        model_args=EvaluationConfig(model_name).get_model_args(),
        tasks=task_name,
        batch_size="auto",
    )["results"][task_name]["acc,none"]

    # Accept any measurement within an absolute tolerance of the reference.
    rtol = 0.05
    assert (
        measured_accuracy - rtol < expected_accuracy
        and measured_accuracy + rtol > expected_accuracy
    ), f"Expected: {expected_accuracy} | Measured: {measured_accuracy}"

vllm/model_executor/layers/quantization/mxfp4.py

@@ -190,14 +190,25 @@ class Mxfp4Config(QuantizationConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return UnquantizedLinearMethod()
-            raise NotImplementedError("Mxfp4 linear layer is not implemented")
+            # TODO: Add support for an MXFP4 linear method.
+            # An MXFP4 LinearMethod is available in AMD-Quark; refer to that
+            # implementation if you are interested in enabling MXFP4 here.
+            logger.warning_once(
+                "MXFP4 linear layer is not implemented - falling back to "
+                "UnquantizedLinearMethod."
+            )
+            return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
             if current_platform.is_xpu():
                 return IpexMxfp4MoEMethod(layer.moe_config)
             else:
                 return Mxfp4MoEMethod(layer.moe_config)
         elif isinstance(layer, Attention):
-            raise NotImplementedError("Mxfp4 attention layer is not implemented")
+            # TODO: Add support for MXFP4 attention.
+            logger.warning_once(
+                "MXFP4 attention layer is not implemented. "
+                "Skipping quantization for this layer."
+            )
         return None
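
The substance of this hunk is the get_quant_method contract: returning None
(or an UnquantizedLinearMethod) leaves the layer on its default, unquantized
path, while the previous NotImplementedError aborted model load even for
checkpoints that deliberately exclude attention from quantization. A toy
contrast of the two behaviors (the Attention class below is a stand-in, not
vLLM's):

# Toy stand-in for vLLM's Attention layer, for illustration only.
class Attention:
    pass


def old_get_quant_method(layer):
    if isinstance(layer, Attention):
        raise NotImplementedError("Mxfp4 attention layer is not implemented")


def new_get_quant_method(layer):
    if isinstance(layer, Attention):
        return None  # layer keeps its default, unquantized implementation


layer = Attention()
try:
    old_get_quant_method(layer)
except NotImplementedError:
    print("old behavior: model load aborts")
print("new behavior:", new_get_quant_method(layer))  # None -> load continues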

vllm/model_executor/models/gpt_oss.py

@@ -198,6 +198,7 @@ class TransformerBlock(torch.nn.Module):
     def __init__(
         self,
         vllm_config: VllmConfig,
+        quant_config: QuantizationConfig,
         prefix: str = "",
     ):
         super().__init__()
@@ -207,7 +208,10 @@
         self.layer_idx = extract_layer_index(prefix)
         self.attn = OAIAttention(
-            config, prefix=f"{prefix}.attn", cache_config=cache_config
+            config,
+            prefix=f"{prefix}.attn",
+            quant_config=quant_config,
+            cache_config=cache_config,
         )
         self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp")
         self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -243,6 +247,7 @@ class GptOssModel(nn.Module):
     ):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
+        self.quant_config = vllm_config.quant_config
         self.parallel_config = vllm_config.parallel_config
         self.config.hidden_size = self.config.hidden_size
         self.embedding = VocabParallelEmbedding(
@@ -254,6 +259,7 @@
             lambda prefix: TransformerBlock(
                 vllm_config,
                 prefix=prefix,
+                quant_config=self.quant_config,
             ),
             prefix=f"{prefix}.layers",
         )
@@ -645,7 +651,7 @@ class GptOssModel(nn.Module):
 class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
-    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
+    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
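
The one-line packed_modules_mapping fix above is what makes the attention
exclusion resolvable in the first place: the skip check keys on the fused
module's runtime name, which in OAIAttention is qkv_proj, not qkv. A quick
illustration of why the old key silently missed:

fused_name = "qkv_proj"  # runtime name of the fused QKV linear in OAIAttention

old_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
new_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# With the old key the lookup misses, so the checkpoint's per-projection
# exclusion entries are never matched against the fused module.
print(old_mapping.get(fused_name))  # None
print(new_mapping.get(fused_name))  # ['q_proj', 'k_proj', 'v_proj']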