diff --git a/examples/offline_inference/embed_matryoshka_fy.py b/examples/offline_inference/embed_matryoshka_fy.py
new file mode 100644
index 0000000000000..ab71fbe73e6aa
--- /dev/null
+++ b/examples/offline_inference/embed_matryoshka_fy.py
@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from argparse import Namespace
+
+from vllm import LLM, EngineArgs, PoolingParams
+from vllm.utils import FlexibleArgumentParser
+
+
+def main(args: Namespace):
+    # Sample prompts.
+    prompts = [
+        "Follow the white rabbit.",  # English
+        "Sigue al conejo blanco.",  # Spanish
+        "Suis le lapin blanc.",  # French
+        "跟着白兔走。",  # Chinese
+        "اتبع الأرنب الأبيض.",  # Arabic
+        "Folge dem weißen Kaninchen.",  # German
+    ]
+
+    # Create an LLM.
+    # You should pass task="embed" for embedding models
+    model = LLM(**vars(args))
+
+    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
+    outputs = model.embed(prompts, pooling_params=PoolingParams(dimensions=32))
+
+    # Print the outputs.
+    print("\nGenerated Outputs:")
+    print("-" * 60)
+    for prompt, output in zip(prompts, outputs):
+        embeds = output.outputs.embedding
+        embeds_trimmed = ((str(embeds[:16])[:-1] +
+                           ", ...]") if len(embeds) > 16 else embeds)
+        print(f"Prompt: {prompt!r} \n"
+              f"Embeddings: {embeds_trimmed} "
+              f"(size={len(embeds)})")
+        print("-" * 60)
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(model="jinaai/jina-embeddings-v3",
+                        task="embed",
+                        trust_remote_code=True)
+    args = parser.parse_args()
+    main(args)
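The example above asks the pooler for 32-dimensional vectors via `PoolingParams(dimensions=32)`. For intuition, the operation this triggers server-side (see the `PoolerHead` change further down) is truncate-then-renormalize. A minimal standalone sketch, assuming L2-normalized 2-D float tensors and an arbitrary hidden size of 1024:

```python
import torch
import torch.nn.functional as F

# Full-size, L2-normalized embeddings (batch of 4, hidden size 1024).
full = F.normalize(torch.randn(4, 1024), p=2, dim=-1)

# Matryoshka truncation: keep only the leading `dimensions` components.
trimmed = full[..., :32]

# Truncation breaks the unit norm, so the result is renormalized,
# matching what PoolerHead does when `normalize` is enabled.
trimmed = F.normalize(trimmed, p=2, dim=-1)
assert trimmed.shape == (4, 32)
```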
+""" + +from typing import NamedTuple + +import openai +import pytest + +from vllm.entrypoints.openai.protocol import EmbeddingResponse + +from ...utils import RemoteOpenAIServer + + +class ModelInfo(NamedTuple): + name: str + is_matryoshka: bool + + +MODELS = [ + ModelInfo(name="BAAI/bge-m3", is_matryoshka=False), + ModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True), +] + +input_texts = [ + "The chef prepared a delicious meal.", +] * 3 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model", MODELS) +async def test_validating_dimensions(model: ModelInfo): + args = [ + "--task", + "embed", + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--enforce-eager", + "--max-model-len", + "512", + "--trust_remote_code" + ] + with RemoteOpenAIServer(model.name, args) as remote_server: + client = remote_server.get_async_client() + + async def make_request(dimensions): + embedding_response = await client.embeddings.create( + model=model.name, + input=input_texts, + dimensions=dimensions, + encoding_format="float", + ) + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 3 + assert len(embeddings.data[0].embedding) > 0 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens > 0 + assert embeddings.usage.total_tokens > 0 + + if dimensions is not None: + assert len(embeddings.data[0].embedding) == dimensions + + if model.is_matryoshka: + for dimensions in [None, 16]: + await make_request(dimensions) + + with pytest.raises(openai.BadRequestError): + for dimensions in [-1]: + await make_request(dimensions) + + else: + for dimensions in [None]: + await make_request(dimensions) + + with pytest.raises(openai.BadRequestError): + for dimensions in [-1, 16]: + await make_request(dimensions) diff --git a/tests/models/embedding/language/test_jina.py b/tests/models/embedding/language/test_jina.py index 2a3eab02ddd9e..881d0a75b1584 100644 --- a/tests/models/embedding/language/test_jina.py +++ b/tests/models/embedding/language/test_jina.py @@ -8,7 +8,8 @@ import math import pytest -from tests.models.embedding.utils import check_embeddings_close +from tests.models.embedding.utils import check_embeddings_close, matryoshka_fy +from vllm import PoolingParams SCORING_MODELS = [ "jinaai/jina-reranker-v2-base-multilingual", # Roberta @@ -126,3 +127,40 @@ def test_embeddings( name_1="vllm", tol=1e-2, ) + + +@pytest.mark.parametrize("model", EMBEDDING_MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dimensions", [16, 32]) +def test_matryoshka( + hf_runner, + vllm_runner, + model, + dtype: str, + dimensions: int, + monkeypatch, +) -> None: + + example_prompts = EMBEDDING_PROMPTS + + with hf_runner( + model, + dtype=dtype, + is_sentence_transformer=True, + ) as hf_model: + hf_outputs = hf_model.encode(example_prompts, task="text-matching") + hf_outputs = matryoshka_fy(hf_outputs, dimensions) + + with vllm_runner(model, task="embed", dtype=dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.encode( + example_prompts, + pooling_params=PoolingParams(dimensions=dimensions)) + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py index bef85eaf372f1..5aeeb51785402 100644 --- a/tests/models/embedding/utils.py +++ 
diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py
index bef85eaf372f1..5aeeb51785402 100644
--- a/tests/models/embedding/utils.py
+++ b/tests/models/embedding/utils.py
@@ -30,3 +30,10 @@ def check_embeddings_close(
                     f"\n{name_1}:\t{embeddings_1[:16]!r}")
 
         assert sim >= 1 - tol, fail_msg
+
+
+def matryoshka_fy(tensor, dimensions):
+    tensor = torch.tensor(tensor)
+    tensor = tensor[..., :dimensions]
+    tensor = F.normalize(tensor, p=2, dim=1)
+    return tensor
diff --git a/vllm/config.py b/vllm/config.py
index b466b765d7749..d3e224a6d8346 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -583,6 +583,15 @@ class ModelConfig:
                 if getattr(user_config, k) is None:
                     setattr(user_config, k, v)
 
+            if self.is_matryoshka:
+                if user_config.normalize is None:
+                    user_config.normalize = True
+                elif not user_config.normalize:
+                    raise ValueError(
+                        "`normalize` must be enabled (set to True) "
+                        "for models that are compatible with "
+                        "Matryoshka Representation.")
+
             return user_config
 
         return None
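The config guard exists because a truncated slice of a unit vector is no longer unit-length, so Matryoshka models must renormalize. A quick illustration of how the new check surfaces, assuming the existing `override_pooler_config` engine argument is used to force `normalize=False` (this call pattern is an assumption for illustration, not part of the diff):

```python
# Sketch: ModelConfig now rejects normalize=False for models that
# support Matryoshka representation.
from vllm import LLM
from vllm.config import PoolerConfig

try:
    LLM(model="jinaai/jina-embeddings-v3",
        task="embed",
        trust_remote_code=True,
        override_pooler_config=PoolerConfig(normalize=False))
except ValueError as e:
    print(e)  # "`normalize` must be enabled (set to True) ..."
```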
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 70bb73f482c86..a707087a2e286 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -921,6 +921,11 @@ class LLM:
         if pooling_params is None:
             # Use default pooling params.
             pooling_params = PoolingParams()
+        elif isinstance(pooling_params, PoolingParams):
+            pooling_params.verify(self.llm_engine.model_config)
+        else:
+            for pooling_param in pooling_params:
+                pooling_param.verify(self.llm_engine.model_config)
 
         self._validate_and_add_requests(
             prompts=parsed_prompts,
@@ -939,6 +944,8 @@ class LLM:
         /,
         *,
         use_tqdm: bool = True,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> list[EmbeddingRequestOutput]:
@@ -953,6 +960,8 @@ class LLM:
             prompts: The prompts to the LLM. You may pass a sequence of
                 prompts for batch inference. See :class:`~vllm.inputs.PromptType`
                 for more details about the format of each prompts.
+            pooling_params: The pooling parameters for embedding. If None,
+                the default pooling parameters are used.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
             prompt_adapter_request: Prompt Adapter request to use for
@@ -968,6 +977,7 @@ class LLM:
 
         items = self.encode(prompts,
                             use_tqdm=use_tqdm,
+                            pooling_params=pooling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index cbd5f6e566b30..4639b4cea06b7 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1006,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
     # doc: end-embedding-extra-params
 
     def to_pooling_params(self):
-        return PoolingParams(additional_data=self.additional_data)
+        return PoolingParams(dimensions=self.dimensions,
+                             additional_data=self.additional_data)
 
 
 class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1068,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
         return data
 
     def to_pooling_params(self):
-        return PoolingParams(dimensions=self.dimensions,
+                             additional_data=self.additional_data)
 
 
 EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
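With the protocol change, `dimensions` flows from the HTTP request body into `PoolingParams` and is validated before scheduling. A small sketch of that hand-off, constructing the request object directly (direct construction like this is an assumption for illustration; in practice the server parses the JSON body):

```python
from vllm.entrypoints.openai.protocol import EmbeddingCompletionRequest

req = EmbeddingCompletionRequest(model="jinaai/jina-embeddings-v3",
                                 input=["hello"],
                                 dimensions=32)
params = req.to_pooling_params()
print(params)  # PoolingParams(dimensions=32, additional_metadata=None)
```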
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 0ee58672631d0..ba960de17cab3 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
             return error_check_ret
 
         encoding_format = request.encoding_format
-        if request.dimensions is not None:
-            return self.create_error_response(
-                "dimensions is currently not supported")
 
         model_name = self._get_model_name(request.model)
         request_id = f"embd-{self._base_request_id(raw_request)}"
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
                     "greater than max_model_len."
                     " Please, select a smaller truncation size.")
 
+        pooling_params = request.to_pooling_params()
+
+        try:
+            pooling_params.verify(self.model_config)
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
         try:
             (
                 lora_request,
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
         try:
-            pooling_params = request.to_pooling_params()
-
             for i, engine_prompt in enumerate(engine_prompts):
                 request_id_item = f"{request_id}-{i}"
 
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 0012636ef9ffc..3f6ab64e4fa91 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -97,7 +97,7 @@ class SimplePooler(nn.Module):
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
         pooled_data = self.extract_states(hidden_states, pooling_metadata)
-        pooled_data = self.head(pooled_data)
+        pooled_data = self.head(pooled_data, pooling_metadata)
         pooled_outputs = [self.build_output(data) for data in pooled_data]
         return PoolerOutput(outputs=pooled_outputs)
 
@@ -217,14 +217,28 @@ class PoolerHead(nn.Module):
         self.normalize = normalize
         self.softmax = softmax
 
-    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]):
+    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
+                pooling_metadata: PoolingMetadata):
+
+        dimensions_list = [
+            pooling_param.dimensions
+            for _, pooling_param in pooling_metadata.seq_groups
+        ]
+        if any(d is not None for d in dimensions_list):
+            # Truncate each output to its per-request Matryoshka dimensions.
+            assert len(pooled_data) == len(dimensions_list)
+            pooled_data = [
+                vecs if d is None else vecs[..., :d]
+                for vecs, d in zip(pooled_data, dimensions_list)
+            ]
+
         if self.normalize:
             if isinstance(pooled_data, list):
                 pooled_data = [
-                    F.normalize(data, p=2, dim=1) for data in pooled_data
+                    F.normalize(data, p=2, dim=-1) for data in pooled_data
                 ]
             else:
-                pooled_data = F.normalize(pooled_data, p=2, dim=1)
+                pooled_data = F.normalize(pooled_data, p=2, dim=-1)
 
         if self.softmax:
             if isinstance(pooled_data, list):
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index 061232eb11830..f71daf0c19551 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -1,9 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import msgspec
 
+if TYPE_CHECKING:
+    from vllm.config import ModelConfig
+
 
 class PoolingParams(
         msgspec.Struct,
@@ -12,14 +15,30 @@ class PoolingParams(
     """API parameters for pooling models. This is currently a placeholder.
 
     Attributes:
+        dimensions: Reduce the dimensions of embeddings
+            if the model supports Matryoshka representation.
         additional_data: Any additional data needed for pooling.
     """
+
+    dimensions: Optional[int] = None
     additional_data: Optional[Any] = None
 
     def clone(self) -> "PoolingParams":
         """Returns a deep copy of the PoolingParams instance."""
-        return PoolingParams(additional_data=self.additional_data)
+        return PoolingParams(dimensions=self.dimensions,
+                             additional_data=self.additional_data)
+
+    def verify(self, model_config: "ModelConfig") -> None:
+        if self.dimensions is not None:
+            if not model_config.is_matryoshka:
+                raise ValueError(
+                    f'Model "{model_config.served_model_name}" does not '
+                    f'support Matryoshka representation; '
+                    f'changing output dimensions will lead to poor results.')
+            if self.dimensions < 1:
+                raise ValueError("Dimensions must be greater than 0")
 
     def __repr__(self) -> str:
         return (f"PoolingParams("
+                f"dimensions={self.dimensions}, "
                 f"additional_metadata={self.additional_data})")
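Taken together, the changes let each request choose its own output size on the offline path as well. A minimal end-to-end sketch (model as in the example script above; actual default output size depends on the model):

```python
from vllm import LLM, PoolingParams

llm = LLM(model="jinaai/jina-embeddings-v3",
          task="embed",
          trust_remote_code=True)

# Request truncated, renormalized 64-dimensional embeddings.
outputs = llm.embed(["Follow the white rabbit."],
                    pooling_params=PoolingParams(dimensions=64))
assert len(outputs[0].outputs.embedding) == 64

# Invalid or unsupported dimensions are rejected by PoolingParams.verify()
# before any request is scheduled.
try:
    llm.embed(["x"], pooling_params=PoolingParams(dimensions=-1))
except ValueError as e:
    print(e)  # Dimensions must be greater than 0
```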