[Frontend] Use matryoshka_dimensions to control the allowed output dimensions. (#16970)

wang.yuqi 2025-04-24 22:06:28 +08:00 committed by GitHub
parent b724afe343
commit 67309a1cb5
8 changed files with 172 additions and 76 deletions

View File

@@ -159,14 +159,14 @@ For example, setting `dimensions` parameter while using the `BAAI/bge-m3` model
 ### Manually enable Matryoshka Embeddings
-There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, we simply check the existence of the fields `is_matryoshka` or `matryoshka_dimensions` inside `config.json`.
+There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json`, the output can be changed to arbitrary dimensions. Setting `matryoshka_dimensions` controls the allowed output dimensions.
-For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}` (offline) or `--hf_overrides '{"is_matryoshka": true}'` (online).
+For models that support Matryoshka Embeddings but are not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}` or `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline), or `--hf_overrides '{"is_matryoshka": true}'` or `--hf_overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'` (online).
 Here is an example to serve a model with Matryoshka Embeddings enabled.
 ```text
-vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"is_matryoshka":true}'
+vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_dimensions":[256]}'
 ```

 ### Offline Inference
@@ -204,14 +204,14 @@ curl http://127.0.0.1:8000/v1/embeddings \
     "input": "Follow the white rabbit.",
     "model": "jinaai/jina-embeddings-v3",
     "encoding_format": "float",
-    "dimensions": 1
+    "dimensions": 32
   }'
 ```
 Expected output:
 ```json
-{"id":"embd-0aab28c384d348c3b8f0eb783109dc5f","object":"list","created":1744195454,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-1.0]}],"usage":{"prompt_tokens":10,"total_tokens":10,"completion_tokens":0,"prompt_tokens_details":null}}
+{"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}}
 ```
 An OpenAI client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py>
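
The hunk above only shows the online `vllm serve` path. For the offline path mentioned in the prose, a minimal sketch (assuming the `LLM.embed` entrypoint and the `PoolingParams(dimensions=...)` parameter that this commit's tests exercise) might look like:

```python
# Sketch: enabling Matryoshka Embeddings offline via hf_overrides.
# Model name and dimension list mirror the serve example above.
from vllm import LLM, PoolingParams

llm = LLM(model="Snowflake/snowflake-arctic-embed-m-v1.5",
          task="embed",
          trust_remote_code=True,
          hf_overrides={"matryoshka_dimensions": [256]})

outputs = llm.embed(["Follow the white rabbit."],
                    pooling_params=PoolingParams(dimensions=256))
print(len(outputs[0].outputs.embedding))  # 256
```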

View File

@@ -25,11 +25,11 @@ def main():
     responses = client.embeddings.create(
         input=["Follow the white rabbit."],
         model=model,
-        dimensions=1,
+        dimensions=32,
     )
     for data in responses.data:
-        print(data.embedding)  # List of float of len 1
+        print(data.embedding)  # List of float of len 32

 if __name__ == "__main__":
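
The hunk shows only the changed call site. For orientation, here is a self-contained sketch of such a client script; the `base_url`, `api_key`, and model lookup are assumptions for a default local deployment, not verbatim from the file:

```python
from openai import OpenAI


def main():
    # Assumed defaults for a local `vllm serve` deployment.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    model = client.models.list().data[0].id

    responses = client.embeddings.create(
        input=["Follow the white rabbit."],
        model=model,
        dimensions=32,
    )
    for data in responses.data:
        print(data.embedding)  # List of float of len 32


if __name__ == "__main__":
    main()
```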

View File

@@ -11,11 +11,12 @@ import requests
 from vllm.entrypoints.openai.protocol import EmbeddingResponse
 from vllm.transformers_utils.tokenizer import get_tokenizer

-from ...models.embedding.utils import check_embeddings_close
+from ...models.embedding.utils import correctness_test
 from ...utils import RemoteOpenAIServer

 MODEL_NAME = "intfloat/multilingual-e5-small"
 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501
+DTYPE = "bfloat16"

 @pytest.fixture(scope="module")
@@ -25,7 +26,7 @@ def server():
         "embed",
         # use half precision for speed and memory savings in CI environment
         "--dtype",
-        "bfloat16",
+        DTYPE,
         "--enforce-eager",
         "--max-model-len",
         "512",
@@ -43,9 +44,17 @@ async def client(server):
         yield async_client

+@pytest.fixture(scope="module")
+def hf_model(hf_runner):
+    with hf_runner(MODEL_NAME, dtype=DTYPE,
+                   is_sentence_transformer=True) as hf_model:
+        yield hf_model

 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
+async def test_single_embedding(hf_model, client: openai.AsyncOpenAI,
+                                model_name: str):
     input_texts = [
         "The chef prepared a delicious meal.",
     ]
@@ -66,6 +75,9 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
     assert embeddings.usage.prompt_tokens == 11
     assert embeddings.usage.total_tokens == 11

+    vllm_outputs = [d.embedding for d in embeddings.data]
+    correctness_test(hf_model, input_texts, vllm_outputs)
+
     # test using token IDs
     input_tokens = [1, 1, 1, 1, 1]
     embedding_response = await client.embeddings.create(
@@ -86,7 +98,8 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):

 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
+async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI,
+                               model_name: str):
     # test list[str]
     input_texts = [
         "The cat sat on the mat.", "A feline was resting on a rug.",
@@ -107,6 +120,9 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
     assert embeddings.usage.prompt_tokens == 33
     assert embeddings.usage.total_tokens == 33

+    vllm_outputs = [d.embedding for d in embeddings.data]
+    correctness_test(hf_model, input_texts, vllm_outputs)
+
     # test list[list[int]]
     input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
                     [25, 32, 64, 77]]
@@ -181,7 +197,7 @@ async def test_conversation_embedding(server: RemoteOpenAIServer,

 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
+async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
                                       model_name: str):
     input_texts = [
         "Hello my name is",
@@ -192,6 +208,7 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
                                                   model=model_name,
                                                   encoding_format="float")
     float_data = [d.embedding for d in responses_float.data]
+    correctness_test(hf_model, input_texts, float_data)

     responses_base64 = await client.embeddings.create(input=input_texts,
                                                       model=model_name,
@@ -202,24 +219,13 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
             np.frombuffer(base64.b64decode(data.embedding),
                           dtype="float32").tolist())

-    check_embeddings_close(
-        embeddings_0_lst=float_data,
-        embeddings_1_lst=base64_data,
-        name_0="float",
-        name_1="base64",
-    )
+    correctness_test(hf_model, input_texts, base64_data)

     # Default response is float32 decoded from base64 by OpenAI Client
     responses_default = await client.embeddings.create(input=input_texts,
                                                        model=model_name)
     default_data = [d.embedding for d in responses_default.data]
-
-    check_embeddings_close(
-        embeddings_0_lst=float_data,
-        embeddings_1_lst=default_data,
-        name_0="float",
-        name_1="default",
-    )
+    correctness_test(hf_model, input_texts, default_data)

 @pytest.mark.asyncio

View File

@@ -3,73 +3,121 @@
 Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
 """
+from typing import Optional

 import openai
 import pytest

 from vllm.entrypoints.openai.protocol import EmbeddingResponse
-from ...models.embedding.utils import EmbedModelInfo
+from ...conftest import HfRunner
+from ...models.embedding.utils import EmbedModelInfo, correctness_test
 from ...utils import RemoteOpenAIServer

 MODELS = [
-    EmbedModelInfo(name="BAAI/bge-m3", is_matryoshka=False),
-    EmbedModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True),
+    EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
+    EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
+                   is_matryoshka=True,
+                   matryoshka_dimensions=[256]),
 ]

 input_texts = [
     "The chef prepared a delicious meal.",
-] * 3
+]

-@pytest.mark.asyncio
-@pytest.mark.parametrize("model", MODELS)
-async def test_validating_dimensions(model: EmbedModelInfo):
+@pytest.fixture(scope="module", params=MODELS)
+def model_info(request):
+    return request.param

+@pytest.fixture(scope="module", params=["bfloat16"])
+def dtype(request):
+    return request.param

+@pytest.fixture(scope="module")
+def server(model_info, dtype: str):
     args = [
         "--task",
         "embed",
         # use half precision for speed and memory savings in CI environment
         "--dtype",
-        "bfloat16",
+        dtype,
         "--enforce-eager",
         "--max-model-len",
-        "512",
-        "--trust_remote_code"
+        "512"
     ]
-    with RemoteOpenAIServer(model.name, args) as remote_server:
-        client = remote_server.get_async_client()

-        async def make_request(dimensions):
-            embedding_response = await client.embeddings.create(
-                model=model.name,
-                input=input_texts,
-                dimensions=dimensions,
-                encoding_format="float",
-            )
-            embeddings = EmbeddingResponse.model_validate(
-                embedding_response.model_dump(mode="json"))
+    if model_info.name == "Snowflake/snowflake-arctic-embed-m-v1.5":
+        # Manually enable Matryoshka Embeddings
+        args.extend([
+            "--trust_remote_code", "--hf_overrides",
+            '{"matryoshka_dimensions":[256]}'
+        ])

-            assert embeddings.id is not None
-            assert len(embeddings.data) == 3
-            assert len(embeddings.data[0].embedding) > 0
-            assert embeddings.usage.completion_tokens == 0
-            assert embeddings.usage.prompt_tokens > 0
-            assert embeddings.usage.total_tokens > 0
-            if dimensions is not None:
-                assert len(embeddings.data[0].embedding) == dimensions
+    with RemoteOpenAIServer(model_info.name, args) as remote_server:
+        yield remote_server

-        if model.is_matryoshka:
-            for dimensions in [None, 16]:
-                await make_request(dimensions)

-            with pytest.raises(openai.BadRequestError):
-                for dimensions in [-1]:
-                    await make_request(dimensions)
+@pytest.fixture(scope="module")
+def hf_model(hf_runner, model_info, dtype: str):
+    with hf_runner(model_info.name, dtype=dtype,
+                   is_sentence_transformer=True) as hf_model:
+        yield hf_model

-        else:
-            for dimensions in [None]:
-                await make_request(dimensions)

-            with pytest.raises(openai.BadRequestError):
-                for dimensions in [-1, 16]:
-                    await make_request(dimensions)
+@pytest.mark.asyncio
+async def test_matryoshka(model_info: EmbedModelInfo,
+                          server: RemoteOpenAIServer, hf_model: HfRunner):
+    client = server.get_async_client()

+    async def make_request_and_correctness_test(dimensions):
+        prompts = input_texts * 3
+        embedding_response = await client.embeddings.create(
+            model=model_info.name,
+            input=prompts,
+            dimensions=dimensions,
+            encoding_format="float",
+        )
+        embeddings = EmbeddingResponse.model_validate(
+            embedding_response.model_dump(mode="json"))

+        assert embeddings.id is not None
+        assert len(embeddings.data) == 3
+        assert len(embeddings.data[0].embedding) > 0
+        assert embeddings.usage.completion_tokens == 0
+        assert embeddings.usage.prompt_tokens > 0
+        assert embeddings.usage.total_tokens > 0
+        if dimensions is not None:
+            assert len(embeddings.data[0].embedding) == dimensions

+        vllm_outputs = [d.embedding for d in embeddings.data]
+        correctness_test(hf_model, prompts, vllm_outputs, dimensions)

+    if model_info.is_matryoshka:
+        valid_dimensions: list[Optional[int]] = [None]
+        if model_info.matryoshka_dimensions is not None:
+            valid_dimensions += model_info.matryoshka_dimensions[:2]

+        for dimensions in valid_dimensions:
+            await make_request_and_correctness_test(dimensions)

+        invalid_dimensions: list[Optional[int]] = [-1]
+        if model_info.matryoshka_dimensions is not None:
+            assert 5 not in model_info.matryoshka_dimensions
+            invalid_dimensions.append(5)

+        for dimensions in invalid_dimensions:
+            with pytest.raises(openai.BadRequestError):
+                await make_request_and_correctness_test(dimensions)

+    else:
+        for dimensions in [None]:
+            await make_request_and_correctness_test(dimensions)

+        for dimensions in [-1, 16]:
+            with pytest.raises(openai.BadRequestError):
+                await make_request_and_correctness_test(dimensions)
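
Outside of pytest, the behavior this test pins down can be reproduced with a plain client session against `vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --trust_remote_code --hf_overrides '{"matryoshka_dimensions":[256]}'`. A sketch, where the base URL and API key are assumptions for a default local deployment:

```python
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
model = "Snowflake/snowflake-arctic-embed-m-v1.5"

# A dimension listed in matryoshka_dimensions is accepted.
ok = client.embeddings.create(model=model,
                              input="Follow the white rabbit.",
                              dimensions=256)
assert len(ok.data[0].embedding) == 256

# Any other dimension is now rejected by server-side validation.
try:
    client.embeddings.create(model=model,
                             input="Follow the white rabbit.",
                             dimensions=5)
except openai.BadRequestError as e:
    print(e)  # mentions the allowed [256] matryoshka dimensions
```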

View File

@@ -153,14 +153,24 @@ def test_matryoshka(
     with vllm_runner(model, task="embed", dtype=dtype,
                      max_model_len=None) as vllm_model:
-        vllm_outputs = vllm_model.encode(
-            example_prompts,
-            pooling_params=PoolingParams(dimensions=dimensions))
-
-        check_embeddings_close(
-            embeddings_0_lst=hf_outputs,
-            embeddings_1_lst=vllm_outputs,
-            name_0="hf",
-            name_1="vllm",
-            tol=1e-2,
-        )
+        matryoshka_dimensions = (
+            vllm_model.model.llm_engine.model_config.matryoshka_dimensions)
+        assert matryoshka_dimensions is not None
+
+        if dimensions not in matryoshka_dimensions:
+            with pytest.raises(ValueError):
+                vllm_model.encode(
+                    example_prompts,
+                    pooling_params=PoolingParams(dimensions=dimensions))
+        else:
+            vllm_outputs = vllm_model.encode(
+                example_prompts,
+                pooling_params=PoolingParams(dimensions=dimensions))
+
+            check_embeddings_close(
+                embeddings_0_lst=hf_outputs,
+                embeddings_1_lst=vllm_outputs,
+                name_0="hf",
+                name_1="vllm",
+                tol=1e-2,
+            )

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from collections.abc import Sequence

-from typing import NamedTuple
+from typing import NamedTuple, Optional

 import torch
 import torch.nn.functional as F
@@ -43,5 +43,24 @@ def matryoshka_fy(tensor, dimensions):
 class EmbedModelInfo(NamedTuple):
     name: str
     is_matryoshka: bool
+    matryoshka_dimensions: Optional[list[int]] = None
     architecture: str = ""
     enable_test: bool = True

+def correctness_test(hf_model,
+                     inputs,
+                     vllm_outputs: Sequence[list[float]],
+                     dimensions: Optional[int] = None):
+    hf_outputs = hf_model.encode(inputs)
+    if dimensions:
+        hf_outputs = matryoshka_fy(hf_outputs, dimensions)
+
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )
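
`correctness_test` reduces the HF reference embeddings with `matryoshka_fy` before comparing, so truncated vLLM outputs are checked against equally truncated references. `matryoshka_fy` itself is only visible here as a hunk-header context line; a plausible implementation consistent with the `torch`/`F` imports above (a sketch, not the verbatim source) is:

```python
def matryoshka_fy(tensor, dimensions):
    # Keep the first `dimensions` components of each embedding and
    # L2-renormalize them: the defining Matryoshka truncation.
    tensor = torch.tensor(tensor)
    tensor = tensor[..., :dimensions]
    tensor = F.normalize(tensor, p=2, dim=-1)
    return tensor
```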

View File

@@ -1248,6 +1248,10 @@ class ModelConfig:
         return (hasattr(self.hf_config, "matryoshka_dimensions")
                 or getattr(self.hf_config, "is_matryoshka", False))

+    @property
+    def matryoshka_dimensions(self):
+        return getattr(self.hf_config, "matryoshka_dimensions", None)

 BlockSize = Literal[1, 8, 16, 32, 64, 128]
 CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
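
The new property simply surfaces `matryoshka_dimensions` from the HF config, so an `hf_overrides` value becomes visible on the engine's model config, which is what the offline test above reads. A minimal check, reusing this commit's model and override, might be:

```python
from vllm import LLM

# hf_overrides injects matryoshka_dimensions into the HF config; the new
# ModelConfig.matryoshka_dimensions property then exposes it.
llm = LLM(model="Snowflake/snowflake-arctic-embed-m-v1.5",
          task="embed",
          trust_remote_code=True,
          hf_overrides={"matryoshka_dimensions": [256]})
print(llm.llm_engine.model_config.matryoshka_dimensions)  # [256]
```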

View File

@@ -35,7 +35,16 @@ class PoolingParams(
                     f'Model "{model_config.served_model_name}" does not '
                     f'support matryoshka representation, '
                     f'changing output dimensions will lead to poor results.')
-            if self.dimensions < 1:
+            mds = model_config.matryoshka_dimensions
+            if mds is not None:
+                if self.dimensions not in mds:
+                    raise ValueError(
+                        f'Model "{model_config.served_model_name}" '
+                        f'only supports {str(mds)} matryoshka dimensions, '
+                        f'using other output dimensions will '
+                        f'lead to poor results.')
+            elif self.dimensions < 1:
                 raise ValueError("Dimensions must be greater than 0")

     def __repr__(self) -> str:
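
To illustrate the two branches: when `matryoshka_dimensions` is set, membership in the list is the only check; otherwise any positive dimension passes. A hypothetical sketch, assuming the enclosing method is `PoolingParams.verify(model_config)` (the method name is not visible in this hunk) and a `model_config` whose HF config carries `matryoshka_dimensions=[256]`:

```python
from vllm import PoolingParams

# model_config is assumed to be a vllm ModelConfig for a model served with
# hf_overrides={"matryoshka_dimensions": [256]}.
PoolingParams(dimensions=256).verify(model_config)  # accepted
PoolingParams(dimensions=128).verify(model_config)  # raises ValueError
PoolingParams(dimensions=-1).verify(model_config)   # also raises ValueError
```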