mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 18:17:55 +08:00
[Frontend]: Support base64 embedding (#5935)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
2be6955a3f
commit
c6c240aa0a
@ -1,3 +1,6 @@
|
|||||||
|
import base64
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
import ray
|
import ray
|
||||||
@ -109,3 +112,33 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
|
|||||||
assert embeddings.usage.completion_tokens == 0
|
assert embeddings.usage.completion_tokens == 0
|
||||||
assert embeddings.usage.prompt_tokens == 17
|
assert embeddings.usage.prompt_tokens == 17
|
||||||
assert embeddings.usage.total_tokens == 17
|
assert embeddings.usage.total_tokens == 17
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_name",
|
||||||
|
[EMBEDDING_MODEL_NAME],
|
||||||
|
)
|
||||||
|
async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
|
||||||
|
model_name: str):
|
||||||
|
input_texts = [
|
||||||
|
"Hello my name is",
|
||||||
|
"The best thing about vLLM is that it supports many different models"
|
||||||
|
]
|
||||||
|
|
||||||
|
responses_float = await embedding_client.embeddings.create(
|
||||||
|
input=input_texts, model=model_name, encoding_format="float")
|
||||||
|
|
||||||
|
responses_base64 = await embedding_client.embeddings.create(
|
||||||
|
input=input_texts, model=model_name, encoding_format="base64")
|
||||||
|
|
||||||
|
decoded_responses_base64_data = []
|
||||||
|
for data in responses_base64.data:
|
||||||
|
decoded_responses_base64_data.append(
|
||||||
|
np.frombuffer(base64.b64decode(data.embedding),
|
||||||
|
dtype="float").tolist())
|
||||||
|
|
||||||
|
assert responses_float.data[0].embedding == decoded_responses_base64_data[
|
||||||
|
0]
|
||||||
|
assert responses_float.data[1].embedding == decoded_responses_base64_data[
|
||||||
|
1]
|
||||||
|
|||||||
@ -580,7 +580,7 @@ class CompletionStreamResponse(OpenAIBaseModel):
|
|||||||
class EmbeddingResponseData(BaseModel):
|
class EmbeddingResponseData(BaseModel):
|
||||||
index: int
|
index: int
|
||||||
object: str = "embedding"
|
object: str = "embedding"
|
||||||
embedding: List[float]
|
embedding: Union[List[float], str]
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingResponse(BaseModel):
|
class EmbeddingResponse(BaseModel):
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
|
import base64
|
||||||
import time
|
import time
|
||||||
from typing import AsyncIterator, List, Optional, Tuple
|
from typing import AsyncIterator, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
@ -20,19 +22,18 @@ TypeTokenIDs = List[int]
|
|||||||
|
|
||||||
|
|
||||||
def request_output_to_embedding_response(
|
def request_output_to_embedding_response(
|
||||||
final_res_batch: List[EmbeddingRequestOutput],
|
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
|
||||||
request_id: str,
|
created_time: int, model_name: str,
|
||||||
created_time: int,
|
encoding_format: str) -> EmbeddingResponse:
|
||||||
model_name: str,
|
|
||||||
) -> EmbeddingResponse:
|
|
||||||
data: List[EmbeddingResponseData] = []
|
data: List[EmbeddingResponseData] = []
|
||||||
num_prompt_tokens = 0
|
num_prompt_tokens = 0
|
||||||
for idx, final_res in enumerate(final_res_batch):
|
for idx, final_res in enumerate(final_res_batch):
|
||||||
assert final_res is not None
|
assert final_res is not None
|
||||||
prompt_token_ids = final_res.prompt_token_ids
|
prompt_token_ids = final_res.prompt_token_ids
|
||||||
|
embedding = final_res.outputs.embedding
|
||||||
embedding_data = EmbeddingResponseData(
|
if encoding_format == "base64":
|
||||||
index=idx, embedding=final_res.outputs.embedding)
|
embedding = base64.b64encode(np.array(embedding))
|
||||||
|
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
|
||||||
data.append(embedding_data)
|
data.append(embedding_data)
|
||||||
|
|
||||||
num_prompt_tokens += len(prompt_token_ids)
|
num_prompt_tokens += len(prompt_token_ids)
|
||||||
@ -72,10 +73,8 @@ class OpenAIServingEmbedding(OpenAIServing):
|
|||||||
if error_check_ret is not None:
|
if error_check_ret is not None:
|
||||||
return error_check_ret
|
return error_check_ret
|
||||||
|
|
||||||
# Return error for unsupported features.
|
encoding_format = (request.encoding_format
|
||||||
if request.encoding_format == "base64":
|
if request.encoding_format else "float")
|
||||||
return self.create_error_response(
|
|
||||||
"base64 encoding is not currently supported")
|
|
||||||
if request.dimensions is not None:
|
if request.dimensions is not None:
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
"dimensions is currently not supported")
|
"dimensions is currently not supported")
|
||||||
@ -129,7 +128,8 @@ class OpenAIServingEmbedding(OpenAIServing):
|
|||||||
return self.create_error_response("Client disconnected")
|
return self.create_error_response("Client disconnected")
|
||||||
final_res_batch[i] = res
|
final_res_batch[i] = res
|
||||||
response = request_output_to_embedding_response(
|
response = request_output_to_embedding_response(
|
||||||
final_res_batch, request_id, created_time, model_name)
|
final_res_batch, request_id, created_time, model_name,
|
||||||
|
encoding_format)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# TODO: Use a vllm-specific Validation Error
|
# TODO: Use a vllm-specific Validation Error
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user