[Quantization] fix attention quantization of gpt_oss model (#27334)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>

parent 05576df85c
commit 5a1271d83a

tests/models/quantization/test_gpt_oss_attn_quantization.py (new file, 80 lines)
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test attention quantization of gpt-oss model.

The qkv_proj and o_proj in self_attention can be either quantized or excluded.

Run `pytest tests/models/quantization/test_gpt_oss_attn_quantization.py`.
"""

import importlib
import importlib.metadata
from dataclasses import dataclass

import huggingface_hub
import lm_eval
import pytest
from packaging import version

MODEL_NAMES = ["amd/gpt-oss-20b-customized-attention-quantization"]

QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
    importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")


def has_huggingface_access(repo):
    try:
        huggingface_hub.list_repo_refs(repo)
        return True
    except huggingface_hub.errors.RepositoryNotFoundError:
        return False


HF_HUB_AMD_ORG_ACCESS = all(
    [has_huggingface_access(model_name) for model_name in MODEL_NAMES]
)


@dataclass
class ModelCase:
    model_id: str
    tp: int


@dataclass
class EvaluationConfig:
    model_name: str

    def get_model_args(self) -> str:
        return (
            f"pretrained={self.model_name},"
            "tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=False"
        )


EXPECTED_ACCURACIES = {"arc_challenge": 0.20}


@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
@pytest.mark.parametrize("model_name", MODEL_NAMES)
@pytest.mark.parametrize("task_name, expected_accuracy", EXPECTED_ACCURACIES.items())
def test_gpt_oss_attention_quantization(
    model_name: str, task_name: str, expected_accuracy: float
):
    measured_accuracy = lm_eval.simple_evaluate(
        model="vllm",
        model_args=EvaluationConfig(model_name).get_model_args(),
        tasks=task_name,
        batch_size="auto",
    )["results"][task_name]["acc,none"]

    rtol = 0.05
    assert (
        measured_accuracy - rtol < expected_accuracy
        and measured_accuracy + rtol > expected_accuracy
    ), f"Expected: {expected_accuracy} | Measured: {measured_accuracy}"
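Note that the final assertion accepts any measured accuracy within an absolute window of 0.05 around the expected value; the variable is named `rtol`, but it is applied as an absolute tolerance. For reference, a minimal equivalent check using pytest's own helper (the numbers are illustrative, taken from the arc_challenge entry above):

import pytest

expected_accuracy = 0.20   # EXPECTED_ACCURACIES["arc_challenge"]
measured_accuracy = 0.22   # hypothetical lm_eval result, for illustration only

# Same +/-0.05 absolute window as the explicit comparison in the test.
assert measured_accuracy == pytest.approx(expected_accuracy, abs=0.05)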
@@ -190,14 +190,25 @@ class Mxfp4Config(QuantizationConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return UnquantizedLinearMethod()
-            raise NotImplementedError("Mxfp4 linear layer is not implemented")
+            # TODO: Add support for MXFP4 Linear Method.
+            # MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation
+            # if you are interested in enabling MXFP4 here.
+            logger.warning_once(
+                "MXFP4 linear layer is not implemented - falling back to "
+                "UnquantizedLinearMethod."
+            )
+            return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
             if current_platform.is_xpu():
                 return IpexMxfp4MoEMethod(layer.moe_config)
             else:
                 return Mxfp4MoEMethod(layer.moe_config)
         elif isinstance(layer, Attention):
-            raise NotImplementedError("Mxfp4 attention layer is not implemented")
+            # TODO: Add support for MXFP4 Attention.
+            logger.warning_once(
+                "MXFP4 attention layer is not implemented. "
+                "Skipping quantization for this layer."
+            )
         return None
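The behavioural change in this hunk is that unsupported layer kinds no longer raise NotImplementedError: linear layers that are not explicitly excluded still end up unquantized, but with a one-time warning instead of an exception, and attention layers are skipped (the lookup returns None). That keeps checkpoints loadable when only the MoE experts are MXFP4-quantized. A minimal, self-contained sketch of this dispatch pattern, using stand-in classes rather than the actual vLLM LinearBase/FusedMoE/Attention types:

import logging

logger = logging.getLogger("mxfp4_sketch")
_warned: set[str] = set()


def warning_once(msg: str) -> None:
    # Rough stand-in for vLLM's logger.warning_once: log each distinct message once.
    if msg not in _warned:
        _warned.add(msg)
        logger.warning(msg)


# Stand-in layer and method classes; the real code dispatches on LinearBase,
# FusedMoE and Attention and returns vLLM quantization method objects.
class LinearLayer: ...
class MoELayer: ...
class AttentionLayer: ...
class UnquantizedLinearMethod: ...
class Mxfp4MoEMethod: ...


def get_quant_method(layer):
    if isinstance(layer, LinearLayer):
        # MXFP4 linear kernels are not implemented: warn once, then fall back.
        warning_once(
            "MXFP4 linear layer is not implemented - falling back to "
            "UnquantizedLinearMethod."
        )
        return UnquantizedLinearMethod()
    if isinstance(layer, MoELayer):
        return Mxfp4MoEMethod()  # the MoE path is the one MXFP4 actually supports
    if isinstance(layer, AttentionLayer):
        # Returning None tells the caller to leave the layer unquantized.
        warning_once(
            "MXFP4 attention layer is not implemented. "
            "Skipping quantization for this layer."
        )
    return None


print(type(get_quant_method(LinearLayer())).__name__)  # UnquantizedLinearMethod
print(get_quant_method(AttentionLayer()))               # None (with a single warning)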
@@ -198,6 +198,7 @@ class TransformerBlock(torch.nn.Module):
     def __init__(
         self,
         vllm_config: VllmConfig,
+        quant_config: QuantizationConfig,
         prefix: str = "",
     ):
         super().__init__()
@@ -207,7 +208,10 @@ class TransformerBlock(torch.nn.Module):
 
         self.layer_idx = extract_layer_index(prefix)
         self.attn = OAIAttention(
-            config, prefix=f"{prefix}.attn", cache_config=cache_config
+            config,
+            prefix=f"{prefix}.attn",
+            quant_config=quant_config,
+            cache_config=cache_config,
         )
         self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp")
         self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -243,6 +247,7 @@ class GptOssModel(nn.Module):
     ):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
+        self.quant_config = vllm_config.quant_config
         self.parallel_config = vllm_config.parallel_config
         self.config.hidden_size = self.config.hidden_size
         self.embedding = VocabParallelEmbedding(
@@ -254,6 +259,7 @@ class GptOssModel(nn.Module):
             lambda prefix: TransformerBlock(
                 vllm_config,
                 prefix=prefix,
+                quant_config=self.quant_config,
             ),
             prefix=f"{prefix}.layers",
         )
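Taken together, these hunks thread the quantization config from GptOssModel down through each TransformerBlock into OAIAttention, so the attention projections can be constructed quantization-aware (or excluded). A toy, self-contained sketch of that wiring; the class and attribute names are simplified stand-ins, not the vLLM classes:

from dataclasses import dataclass


@dataclass
class QuantConfig:
    # Stand-in for vLLM's QuantizationConfig.
    method: str = "mxfp4"


class Attention:
    def __init__(self, prefix: str, quant_config: QuantConfig | None):
        # A real implementation would hand quant_config to its qkv_proj/o_proj
        # linear layers so each one picks a quantized or unquantized method.
        self.prefix = prefix
        self.quant_config = quant_config


class TransformerBlock:
    def __init__(self, prefix: str, quant_config: QuantConfig | None):
        self.attn = Attention(prefix=f"{prefix}.attn", quant_config=quant_config)


class Model:
    def __init__(self, quant_config: QuantConfig | None):
        # The model owns the config once and passes it down explicitly,
        # mirroring the quant_config=self.quant_config argument added above.
        self.quant_config = quant_config
        self.layers = [
            TransformerBlock(prefix=f"layers.{i}", quant_config=self.quant_config)
            for i in range(2)
        ]


model = Model(QuantConfig())
print(model.layers[0].attn.quant_config)  # QuantConfig(method='mxfp4')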
@@ -645,7 +651,7 @@ class GptOssModel(nn.Module):
 
 
 class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
-    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
+    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
 
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
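The last hunk renames the packed_modules_mapping key from "qkv" to "qkv_proj" so that it matches the fused module name used by the attention layer (the qkv_proj mentioned in the test docstring). This matters because exclusion rules in a checkpoint's quantization config are typically written against the per-shard names (q_proj/k_proj/v_proj) and have to be resolved through the fused mapping; a key that never matches a real module name means the exclusion is silently lost. A minimal, self-contained illustration of that lookup, where the helper below is an illustrative assumption and not vLLM's actual is_layer_skipped:

def is_fused_module_excluded(
    module_name: str,
    excluded_suffixes: set[str],
    fused_mapping: dict[str, list[str]],
) -> bool:
    """Return True if a (possibly fused) module should skip quantization."""
    suffix = module_name.rsplit(".", 1)[-1]
    if suffix in fused_mapping:
        # A fused module is excluded only when every shard it fuses is excluded.
        return all(shard in excluded_suffixes for shard in fused_mapping[suffix])
    return suffix in excluded_suffixes


excluded = {"q_proj", "k_proj", "v_proj", "o_proj"}   # attention kept unquantized
module = "model.layers.0.attn.qkv_proj"

fixed_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}  # key after this commit
old_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}         # key before the fix

print(is_fused_module_excluded(module, excluded, fixed_mapping))  # True
print(is_fused_module_excluded(module, excluded, old_mapping))    # False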