Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-05-18 00:56:58 +08:00).
Adds truncate_prompt_tokens param for embeddings creation (#8999)
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
parent
26aa325f4f
commit
0dcc8cbe5a
@ -144,3 +144,64 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
|
|||||||
0].embedding
|
0].embedding
|
||||||
assert responses_float.data[1].embedding == responses_default.data[
|
assert responses_float.data[1].embedding == responses_default.data[
|
||||||
1].embedding
|
1].embedding
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_single_embedding_truncation(
        embedding_client: openai.AsyncOpenAI, model_name: str):
    """The ``truncate_prompt_tokens`` option must cap reported prompt usage
    at the requested size for both text and pre-tokenized inputs."""

    def _assert_truncated_response(response) -> None:
        # One embedding of the model's hidden size, with usage reflecting
        # truncation to exactly 10 prompt tokens.
        assert response.id is not None
        assert len(response.data) == 1
        assert len(response.data[0].embedding) == 4096
        assert response.usage.completion_tokens == 0
        assert response.usage.prompt_tokens == 10
        assert response.usage.total_tokens == 10

    # Plain-text input (longer than 10 tokens once tokenized).
    text_input = [
        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
    ]
    response = await embedding_client.embeddings.create(
        model=model_name,
        input=text_input,
        extra_body={"truncate_prompt_tokens": 10})
    _assert_truncated_response(response)

    # Pre-tokenized input: 21 token ids, also truncated down to 10.
    token_input = [
        1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
        9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
    ]
    response = await embedding_client.embeddings.create(
        model=model_name,
        input=token_input,
        extra_body={"truncate_prompt_tokens": 10})
    _assert_truncated_response(response)
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_single_embedding_truncation_invalid(
        embedding_client: openai.AsyncOpenAI, model_name: str):
    """A ``truncate_prompt_tokens`` value larger than ``max_model_len``
    must be rejected by the server with a BadRequestError."""
    input_texts = [
        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
    ]

    # The create() call itself raises, so any asserts placed after it
    # *inside* this block would be unreachable dead code. Capture the
    # exception and check its message after the block instead.
    with pytest.raises(openai.BadRequestError) as exc_info:
        await embedding_client.embeddings.create(
            model=model_name,
            input=input_texts,
            extra_body={"truncate_prompt_tokens": 8193})
    assert "truncate_prompt_tokens value is greater than max_model_len" \
        in str(exc_info.value)
|
||||||
|
|||||||
@ -671,6 +671,7 @@ class EmbeddingRequest(OpenAIBaseModel):
|
|||||||
encoding_format: Literal["float", "base64"] = "float"
|
encoding_format: Literal["float", "base64"] = "float"
|
||||||
dimensions: Optional[int] = None
|
dimensions: Optional[int] = None
|
||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
|
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||||
|
|
||||||
# doc: begin-embedding-pooling-params
|
# doc: begin-embedding-pooling-params
|
||||||
additional_data: Optional[Any] = None
|
additional_data: Optional[Any] = None
|
||||||
|
|||||||
@ -110,6 +110,17 @@ class OpenAIServingEmbedding(OpenAIServing):
|
|||||||
request_id = f"embd-{random_uuid()}"
|
request_id = f"embd-{random_uuid()}"
|
||||||
created_time = int(time.monotonic())
|
created_time = int(time.monotonic())
|
||||||
|
|
||||||
|
truncate_prompt_tokens = None
|
||||||
|
|
||||||
|
if request.truncate_prompt_tokens is not None:
|
||||||
|
if request.truncate_prompt_tokens <= self.max_model_len:
|
||||||
|
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||||
|
else:
|
||||||
|
return self.create_error_response(
|
||||||
|
"truncate_prompt_tokens value is "
|
||||||
|
"greater than max_model_len."
|
||||||
|
" Please, select a smaller truncation size.")
|
||||||
|
|
||||||
# Schedule the request and get the result generator.
|
# Schedule the request and get the result generator.
|
||||||
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
|
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
|
||||||
try:
|
try:
|
||||||
@ -123,11 +134,9 @@ class OpenAIServingEmbedding(OpenAIServing):
|
|||||||
pooling_params = request.to_pooling_params()
|
pooling_params = request.to_pooling_params()
|
||||||
|
|
||||||
prompts = list(
|
prompts = list(
|
||||||
self._tokenize_prompt_input_or_inputs(
|
self._tokenize_prompt_input_or_inputs(request, tokenizer,
|
||||||
request,
|
request.input,
|
||||||
tokenizer,
|
truncate_prompt_tokens))
|
||||||
request.input,
|
|
||||||
))
|
|
||||||
|
|
||||||
for i, prompt_inputs in enumerate(prompts):
|
for i, prompt_inputs in enumerate(prompts):
|
||||||
request_id_item = f"{request_id}-{i}"
|
request_id_item = f"{request_id}-{i}"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user