# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils


def _seeded_bf16_input(shape):
    """Return a seeded random bfloat16 CUDA tensor of ``shape``, scaled by 8.

    The RNG is re-seeded with 42 on every call so each test case sees the
    same input values.
    """
    torch.manual_seed(42)
    rows, cols = shape
    return torch.randn((rows, cols), device="cuda", dtype=torch.bfloat16) * 8


@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
@pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8(
    shape, column_major: bool, scale_ue8m0: bool, group_size: int
):
    """Check fp8 per-token-group quantization against a reference run.

    ``fp8_utils.per_token_group_quant_fp8`` is called twice on the same
    input: once as-is, and once with ``current_platform.is_cuda`` patched to
    return ``False``. The quantized values and the scales of both runs must
    agree within the tolerances below.
    """
    x = _seeded_bf16_input(shape)

    # Both invocations must receive exactly the same options.
    quant_kwargs = {
        "column_major_scales": column_major,
        "use_ue8m0": scale_ue8m0,
    }

    # cuda path
    out_q, scale = fp8_utils.per_token_group_quant_fp8(
        x, group_size, **quant_kwargs
    )

    # triton ref (selected by making the platform report "not CUDA")
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
            x, group_size, **quant_kwargs
        )

    # Quantized outputs are compared after upcasting to float32.
    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)


@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_int8(shape, group_size: int):
    """Check int8 per-token-group quantization against a reference run.

    ``int8_utils.per_token_group_quant_int8`` is called twice on the same
    input: once as-is, and once with ``current_platform.is_cuda`` patched to
    return ``False``. The quantized values and the scales of both runs must
    agree within the tolerances below.
    """
    x = _seeded_bf16_input(shape)

    # cuda path
    out_q, scale = int8_utils.per_token_group_quant_int8(x, group_size)

    # triton ref (selected by making the platform report "not CUDA")
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        ref_q, ref_s = int8_utils.per_token_group_quant_int8(x, group_size)

    # Quantized outputs are compared after upcasting to float32.
    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)