# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from tests.models.utils import check_embeddings_close
from vllm.utils.serial_utils import (
    EMBED_DTYPE_TO_TORCH_DTYPE,
    ENDIANNESS,
    binary2tensor,
    tensor2binary,
)


@pytest.mark.parametrize("endianness", ENDIANNESS)
@pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys())
@torch.inference_mode()
def test_encode_and_decode(embed_dtype: str, endianness: str):
    """Round-trip random tensors through tensor2binary/binary2tensor and
    check that the decoded values match the originals within a tolerance
    appropriate for the target dtype."""
    for i in range(10):
        tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32)
        shape = tensor.shape
        binary = tensor2binary(tensor, embed_dtype, endianness)
        new_tensor = binary2tensor(binary, shape, embed_dtype, endianness).to(
            torch.float32
        )

        # Lower-precision dtypes lose more information, so the allowed
        # element-wise error grows from float32/float16 to bfloat16 to fp8.
        if embed_dtype in ["float32", "float16"]:
            torch.testing.assert_close(tensor, new_tensor, atol=0.001, rtol=0.001)
        elif embed_dtype == "bfloat16":
            torch.testing.assert_close(tensor, new_tensor, atol=0.01, rtol=0.01)
        else:  # for fp8
            torch.testing.assert_close(tensor, new_tensor, atol=0.1, rtol=0.1)

        # Also compare the flattened embeddings with the shared test helper.
        check_embeddings_close(
            embeddings_0_lst=tensor.view(1, -1),
            embeddings_1_lst=new_tensor.view(1, -1),
            name_0="gt",
            name_1="new",
            tol=1e-2,
        )
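

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the test suite): how calling
# code might round-trip a single embedding through the same helpers. It relies
# only on the signatures exercised above -- tensor2binary(tensor, embed_dtype,
# endianness) and binary2tensor(binary, shape, embed_dtype, endianness) -- and
# picks the first configured dtype/endianness rather than assuming specific
# values.
if __name__ == "__main__":
    embed_dtype = next(iter(EMBED_DTYPE_TO_TORCH_DTYPE))
    endianness = next(iter(ENDIANNESS))

    embedding = torch.rand(1, 1024, dtype=torch.float32)
    payload = tensor2binary(embedding, embed_dtype, endianness)
    restored = binary2tensor(payload, embedding.shape, embed_dtype, endianness)

    # The decoded tensor comes back in the requested dtype; cast to float32
    # before comparing against the original.
    diff = (embedding - restored.to(torch.float32)).abs().max()
    print(f"{embed_dtype}/{endianness}: {len(payload)} bytes, max abs diff {diff:.4f}")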