mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:25:01 +08:00
67 lines
1.8 KiB
Python
67 lines
1.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from collections.abc import Sequence
|
|
from typing import NamedTuple, Optional
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def check_embeddings_close(
    *,
    embeddings_0_lst: Sequence[list[float]],
    embeddings_1_lst: Sequence[list[float]],
    name_0: str,
    name_1: str,
    tol: float = 1e-3,
) -> None:
    """Assert that two lists of embeddings are pairwise close.

    Each pair of embeddings (same index in both lists) is compared by cosine
    similarity, which must be at least ``1 - tol``.

    Args:
        embeddings_0_lst: Embeddings from the first implementation, one
            entry per prompt.
        embeddings_1_lst: Embeddings from the second implementation, in the
            same prompt order.
        name_0: Label for the first implementation in failure messages.
        name_1: Label for the second implementation in failure messages.
        tol: Maximum allowed shortfall of the cosine similarity from 1.

    Raises:
        AssertionError: If the lists differ in length, a pair of embeddings
            differs in length, or a pair's cosine similarity is below
            ``1 - tol``.
    """
    assert len(embeddings_0_lst) == len(embeddings_1_lst), (
        f"Number of embeddings mismatch: {name_0} has "
        f"{len(embeddings_0_lst)} vs. {name_1} has {len(embeddings_1_lst)}")

    for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
            zip(embeddings_0_lst, embeddings_1_lst)):
        assert len(embeddings_0) == len(embeddings_1), (
            f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")

        # as_tensor accepts lists, arrays and tensors alike, and does not
        # trigger the copy-construct warning that torch.tensor() emits when
        # it is handed an existing tensor.
        sim = F.cosine_similarity(torch.as_tensor(embeddings_0),
                                  torch.as_tensor(embeddings_1),
                                  dim=0)

        # Report the measured similarity and the threshold so that a failure
        # can be diagnosed without re-running the comparison by hand.
        fail_msg = (f"Test{prompt_idx}:"
                    f"\nCosine similarity: \t{sim.item():.4f} "
                    f"(required >= {1 - tol})"
                    f"\n{name_0}:\t{embeddings_0[:16]!r}"
                    f"\n{name_1}:\t{embeddings_1[:16]!r}")

        assert sim >= 1 - tol, fail_msg
|
|
|
|
|
|
def matryoshka_fy(tensor, dimensions):
    """Truncate embeddings to their leading components and re-normalize.

    Args:
        tensor: Embeddings as a tensor or anything ``torch.tensor`` accepts
            (for example nested lists). The embedding axis is the last one,
            so both a single 1-D embedding and a batch are supported.
        dimensions: Number of leading components to keep along the last
            axis.

    Returns:
        A new tensor with the last axis truncated to ``dimensions`` entries
        and L2-normalized along that axis. The input is not modified.
    """
    if isinstance(tensor, torch.Tensor):
        # torch.tensor() on an existing tensor emits a copy-construct
        # warning; detaching keeps the result free of autograd history
        # without that warning.
        tensor = tensor.detach()
    else:
        tensor = torch.tensor(tensor)
    tensor = tensor[..., :dimensions]
    # Normalize along the same (last) axis that was just truncated. For 2-D
    # input this equals dim=1, and it also works for 1-D input.
    return F.normalize(tensor, p=2, dim=-1)
|
|
|
|
|
|
class EmbedModelInfo(NamedTuple):
    """Immutable record describing an embedding model.

    Only ``name`` and ``is_matryoshka`` are required; the remaining fields
    have defaults.
    """

    # Model identifier.
    name: str
    # Flag marking the model as a Matryoshka embedding model.
    is_matryoshka: bool
    # Optional list of Matryoshka dimensions; None when not specified.
    matryoshka_dimensions: Optional[list[int]] = None
    # Architecture name; empty string when not specified.
    architecture: str = ""
    # Enable flag for the model's test; on by default.
    enable_test: bool = True
|
|
|
|
|
|
def correctness_test(hf_model,
                     inputs,
                     vllm_outputs: Sequence[list[float]],
                     dimensions: Optional[int] = None):
    """Compare vLLM embeddings against those produced by ``hf_model``.

    Args:
        hf_model: Reference model; its ``encode(inputs)`` result is used as
            the expected embeddings.
        inputs: Inputs passed to ``hf_model.encode``.
        vllm_outputs: Embeddings to check, one per input.
        dimensions: When truthy, the reference embeddings are passed through
            ``matryoshka_fy`` with this value before the comparison.

    Raises:
        AssertionError: Propagated from ``check_embeddings_close`` when the
            embeddings do not match within a tolerance of ``1e-2``.
    """
    reference = hf_model.encode(inputs)

    # A falsy value (None or 0) leaves the reference embeddings untouched.
    if dimensions:
        reference = matryoshka_fy(reference, dimensions)

    comparison = dict(
        embeddings_0_lst=reference,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )
    check_embeddings_close(**comparison)
|