# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
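
# Sanity tests for the hadacore (Hadamard) transform exposed via vllm._custom_ops.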

import math

import pytest
import torch
from compressed_tensors.transform import deterministic_hadamard_matrix

from vllm import _custom_ops as ops


@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("hidden_dim", [2**n for n in range(10)])
def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"):
    x = torch.eye(hidden_dim, dtype=dtype, device=device)
    hadamard = deterministic_hadamard_matrix(
        hidden_dim, dtype=torch.float64, device="cuda"
    ) / math.sqrt(hidden_dim)
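
    # Applying the custom op to the identity matrix should match an explicit
    # multiplication by the normalized Hadamard matrix H / sqrt(n).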
    y = ops.hadacore_transform(x.clone())
    y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype)
    assert torch.allclose(y, y_true)
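
    # A Sylvester-style Hadamard matrix is symmetric and, once normalized,
    # orthogonal, so the transform is its own inverse; a second application
    # should recover the original input.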
    y = ops.hadacore_transform(y)
    assert torch.allclose(y, x)