# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
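"""Tests that the CUDA and Triton paths of the per-token-group quantization
helpers produce matching results.

Each test quantizes the same input on the CUDA path, then patches
``current_platform.is_cuda`` to force the Triton fallback, and checks that
the quantized tensors and per-group scales agree within tolerance.
"""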

from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils


@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
|
|
@pytest.mark.parametrize("column_major", [False, True])
|
|
@pytest.mark.parametrize("scale_ue8m0", [False, True])
|
|
@pytest.mark.parametrize("group_size", [64, 128])
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
|
def test_per_token_group_quant_fp8(
|
|
shape, column_major: bool, scale_ue8m0: bool, group_size: int
|
|
):
    device = "cuda"

    torch.manual_seed(42)
    num_tokens, hidden_dim = shape
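
    # Scale the random bf16 input by 8 so values span a wider range than a
    # unit normal and the per-group scale factors are non-trivial.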
    x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8

    # CUDA path
    out_q, scale = fp8_utils.per_token_group_quant_fp8(
        x,
        group_size,
        column_major_scales=column_major,
        use_ue8m0=scale_ue8m0,
    )

    # Triton reference: with is_cuda patched to return False, the helper
    # falls back to its Triton implementation.
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
            x,
            group_size,
            column_major_scales=column_major,
            use_ue8m0=scale_ue8m0,
        )
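
    # Loose tolerances: both paths round to fp8, so small per-element
    # differences between the kernels are tolerated.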
    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)


@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
|
|
@pytest.mark.parametrize("group_size", [64, 128])
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
|
def test_per_token_group_quant_int8(shape, group_size: int):
|
|
device = "cuda"
|
|
|
|
torch.manual_seed(42)
|
|
num_tokens, hidden_dim = shape
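
    # Same setup as the FP8 test above, minus the layout/ue8m0 variations.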
    x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8

    # CUDA path
    out_q, scale = int8_utils.per_token_group_quant_int8(
        x,
        group_size,
    )

    # Triton reference: again force the Triton fallback by patching is_cuda.
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        ref_q, ref_s = int8_utils.per_token_group_quant_int8(
            x,
            group_size,
        )

    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)