# SPDX-License-Identifier: Apache-2.0
"""Check that vLLM reproduces HF embeddings for Snowflake Arctic Embed models.

Each enabled checkpoint is encoded with the HuggingFace sentence-transformers
runner and with the vLLM runner on the same prompts; the two sets of
embeddings are then compared within a fixed tolerance.

Run `pytest tests/models/embedding/language/test_snowflake_arctic_embed.py`.
"""
import pytest

from tests.models.embedding.utils import EmbedModelInfo

from ..utils import check_embeddings_close

# Extra prompts appended to the shared `example_prompts` fixture input.
EMBEDDING_PROMPTS = [
    'what is snowflake?',
    'Where can I get the best tacos?',
    'The Data Cloud!',
    'Mexico City of Course!',
]

# Snowflake Arctic Embed checkpoints. For each one, this records:
# - whether it is Matryoshka;
# - the architecture name expected in vLLM's model config;
# - whether the comparison is actually run (`enable_test`).
MODELS = [
    EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
                   is_matryoshka=False,
                   architecture="BertModel",
                   enable_test=True),
    EmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
                   is_matryoshka=False,
                   architecture="BertModel",
                   enable_test=False),
    EmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
                   is_matryoshka=False,
                   architecture="BertModel",
                   enable_test=False),
    EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
                   is_matryoshka=False,
                   architecture="NomicBertModel",
                   enable_test=True),
    EmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
                   is_matryoshka=False,
                   architecture="BertModel",
                   enable_test=False),
    EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
                   is_matryoshka=True,
                   architecture="BertModel",
                   enable_test=True),
    EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
                   is_matryoshka=True,
                   architecture="XLMRobertaModel",
                   enable_test=True),
    EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                   is_matryoshka=True,
                   architecture="GteModel",
                   enable_test=True),
]


@pytest.mark.parametrize("model_info", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_info: EmbedModelInfo,
    dtype: str,
    monkeypatch,
) -> None:
    """Compare HF and vLLM embeddings for one model/dtype combination.

    The test is skipped when ``model_info.enable_test`` is falsy. Otherwise
    it runs the following steps:

    1. Encode ``example_prompts`` followed by ``EMBEDDING_PROMPTS`` with both
       runners.
    2. Assert that vLLM's model config reports the expected ``is_matryoshka``
       flag and, when one is given, the expected architecture.
    3. Require the two embedding lists to agree within ``tol=1e-2``.

    ``monkeypatch`` is accepted but not used in this test.
    """
    # Several checkpoints in the family share an architecture, so only a
    # representative subset is exercised.
    if not model_info.enable_test:
        pytest.skip("Skipping test.")

    prompts = example_prompts + EMBEDDING_PROMPTS

    with hf_runner(model_info.name, dtype=dtype,
                   is_sentence_transformer=True) as hf_model:
        reference_embeddings = hf_model.encode(prompts)

    # The Matryoshka flag is forced through `hf_overrides`. The first
    # assertion below confirms it is reflected in vLLM's model config.
    matryoshka_override = {"is_matryoshka": model_info.is_matryoshka}
    with vllm_runner(model_info.name,
                     task="embed",
                     dtype=dtype,
                     max_model_len=None,
                     hf_overrides=matryoshka_override) as vllm_model:
        config = vllm_model.model.llm_engine.model_config
        assert config.is_matryoshka == model_info.is_matryoshka
        if model_info.architecture:
            assert model_info.architecture in config.architectures
        candidate_embeddings = vllm_model.encode(prompts)

    check_embeddings_close(
        embeddings_0_lst=reference_embeddings,
        embeddings_1_lst=candidate_embeddings,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )