mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:05:01 +08:00
[Bugfix] Fix quark fp8 format loading on AMD GPUs (#12612)
Signed-off-by: Felix Marty <felmarty@amd.com> Signed-off-by: kewang2 <kewang2@amd.com> Co-authored-by: kewang2 <kewang2@amd.com>
This commit is contained in:
parent
a463555dee
commit
bb239a730f
@ -5,6 +5,7 @@ Run `pytest tests/quantization/test_quark.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
|
||||
QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
|
||||
@ -63,3 +64,28 @@ def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=20)
|
||||
assert output
|
||||
|
||||
|
||||
def test_quark_fp8_parity(vllm_runner):
|
||||
quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
|
||||
fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"
|
||||
|
||||
llm_kwargs = {
|
||||
"tensor_parallel_size": 1,
|
||||
"enforce_eager": True,
|
||||
"gpu_memory_utilization": 0.1
|
||||
}
|
||||
with (vllm_runner(quark_model_id, **llm_kwargs) as
|
||||
quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
|
||||
quark_model = (quark_handle.model.llm_engine.model_executor.
|
||||
driver_worker.model_runner.model)
|
||||
quark_state_dict = quark_model.state_dict()
|
||||
|
||||
fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
fp8_state_dict = fp8_model.state_dict()
|
||||
|
||||
assert fp8_state_dict.keys() == quark_state_dict.keys()
|
||||
|
||||
for key in fp8_state_dict:
|
||||
assert torch.equal(fp8_state_dict[key], quark_state_dict[key])
|
||||
|
||||
@ -34,21 +34,24 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
# tensor scales (thus N scales being passed to the kernel),
|
||||
# requantize so we can always run per tensor
|
||||
if self.qscheme == "per_tensor":
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
if current_platform.is_rocm():
|
||||
input_scale = getattr(layer, 'input_scale', None)
|
||||
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=max_w_scale,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
else:
|
||||
max_w_scale = layer.weight_scale
|
||||
weight = layer.weight
|
||||
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=weight,
|
||||
weight_scale=max_w_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user