[Bugfix] Fix quark fp8 format loading on AMD GPUs (#12612)
Signed-off-by: Felix Marty <felmarty@amd.com>
Signed-off-by: kewang2 <kewang2@amd.com>
Co-authored-by: kewang2 <kewang2@amd.com>
commit bb239a730f
parent a463555dee
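Background for the change below: on ROCm the scheme now converts Quark's serialized e4m3fn weights and scales to the e4m3fnuz encoding (via normalize_e4m3fn_to_e4m3fnuz) before per-tensor requantization. A minimal numeric sketch of why that conversion doubles the scales, assuming a PyTorch build with float8 dtypes; this is an illustration of the bit/scale relationship, not vLLM's helper:

import torch

# Values exactly representable in e4m3fn; the 0x80 bit pattern (-0 in e4m3fn,
# but NaN in e4m3fnuz) is deliberately avoided here.
w_fn = torch.tensor([0.5, 1.0, 2.0]).to(torch.float8_e4m3fn)
scale_fn = torch.tensor(0.1)

# Reinterpreting the same bits as e4m3fnuz halves every value (the fnuz
# exponent bias is one larger), so the dequantization scale is doubled to
# keep weight * scale unchanged.
w_fnuz = w_fn.view(torch.int8).view(torch.float8_e4m3fnuz)
scale_fnuz = scale_fn * 2.0

assert torch.allclose(w_fn.float() * scale_fn, w_fnuz.float() * scale_fnuz)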
tests/quantization/test_quark.py

@@ -5,6 +5,7 @@ Run `pytest tests/quantization/test_quark.py`.
 """
 
 import pytest
+import torch
 
 from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
     QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
@@ -63,3 +64,28 @@ def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
+
+
+def test_quark_fp8_parity(vllm_runner):
+    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
+    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"
+
+    llm_kwargs = {
+        "tensor_parallel_size": 1,
+        "enforce_eager": True,
+        "gpu_memory_utilization": 0.1
+    }
+    with (vllm_runner(quark_model_id, **llm_kwargs) as
+          quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
+        quark_model = (quark_handle.model.llm_engine.model_executor.
+                       driver_worker.model_runner.model)
+        quark_state_dict = quark_model.state_dict()
+
+        fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker.
+                     model_runner.model)
+        fp8_state_dict = fp8_model.state_dict()
+
+    assert fp8_state_dict.keys() == quark_state_dict.keys()
+
+    for key in fp8_state_dict:
+        assert torch.equal(fp8_state_dict[key], quark_state_dict[key])
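The new parity test loads what is presumably the same tiny Llama checkpoint exported twice, once with Quark's native quantization config and once with the plain fp8 config, and asserts that the loaded state dicts match tensor by tensor. Outside the test harness, a rough way to exercise the same load path is sketched below; the model ID comes from the test and the engine arguments simply mirror its llm_kwargs, so treat this as an illustration rather than a recommended configuration:

from vllm import LLM, SamplingParams

# Load the Quark-serialized FP8 checkpoint used by the parity test.
llm = LLM(model="amd-quark/llama-tiny-fp8-quark-quant-method",
          tensor_parallel_size=1,
          enforce_eager=True,
          gpu_memory_utilization=0.1)

outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=20))
print(outputs[0].outputs[0].text)

The hunk that follows contains the corresponding loader change in the QuarkW8A8Fp8 scheme.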
@@ -34,21 +34,24 @@ class QuarkW8A8Fp8(QuarkScheme):
         # tensor scales (thus N scales being passed to the kernel),
         # requantize so we can always run per tensor
         if self.qscheme == "per_tensor":
-            max_w_scale, weight = requantize_with_max_scale(
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                logical_widths=layer.logical_widths,
-            )
-
-            if current_platform.is_fp8_fnuz():
+            if current_platform.is_rocm():
                 input_scale = getattr(layer, 'input_scale', None)
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=weight,
-                    weight_scale=max_w_scale,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
                     input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
+            else:
+                max_w_scale = layer.weight_scale
+                weight = layer.weight
+
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=weight,
+                weight_scale=max_w_scale,
+                logical_widths=layer.logical_widths,
+            )
 
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
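As the retained comment notes, a fused module (for example a merged QKV projection) can arrive with one scale per shard, and requantize_with_max_scale folds those into a single per-tensor scale so the kernel only ever sees one. The reordering in this hunk means that on ROCm the fnuz normalization happens first, so requantization operates on the already-converted weights and doubled scales. A simplified sketch of the max-scale idea, using a list of shards for clarity; this is a standalone illustration, not vLLM's requantize_with_max_scale, which works on one fused weight plus logical_widths:

import torch

def requantize_to_max_scale(shards, scales):
    """Rescale per-shard fp8 weights so they all share the largest scale."""
    max_scale = max(scales)
    requantized = []
    for shard, scale in zip(shards, scales):
        real = shard.float() * scale  # dequantize with the shard's own scale
        requantized.append((real / max_scale).to(torch.float8_e4m3fn))
    return requantized, max_scale

# Example: three shards of a fused projection, each with its own scale.
q, k, v = (torch.randn(4, 8).to(torch.float8_e4m3fn) for _ in range(3))
qkv_shards, shared_scale = requantize_to_max_scale([q, k, v], [0.02, 0.05, 0.01])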