[Frontend]: Support base64 embedding (#5935)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
llmpros 2024-06-30 08:53:00 -07:00 committed by GitHub
parent 2be6955a3f
commit c6c240aa0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 14 deletions

View File

@ -1,3 +1,6 @@
import base64
import numpy as np
import openai import openai
import pytest import pytest
import ray import ray
@ -109,3 +112,33 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 17 assert embeddings.usage.prompt_tokens == 17
assert embeddings.usage.total_tokens == 17 assert embeddings.usage.total_tokens == 17
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
model_name: str):
input_texts = [
"Hello my name is",
"The best thing about vLLM is that it supports many different models"
]
responses_float = await embedding_client.embeddings.create(
input=input_texts, model=model_name, encoding_format="float")
responses_base64 = await embedding_client.embeddings.create(
input=input_texts, model=model_name, encoding_format="base64")
decoded_responses_base64_data = []
for data in responses_base64.data:
decoded_responses_base64_data.append(
np.frombuffer(base64.b64decode(data.embedding),
dtype="float").tolist())
assert responses_float.data[0].embedding == decoded_responses_base64_data[
0]
assert responses_float.data[1].embedding == decoded_responses_base64_data[
1]

View File

@ -580,7 +580,7 @@ class CompletionStreamResponse(OpenAIBaseModel):
class EmbeddingResponseData(BaseModel): class EmbeddingResponseData(BaseModel):
index: int index: int
object: str = "embedding" object: str = "embedding"
embedding: List[float] embedding: Union[List[float], str]
class EmbeddingResponse(BaseModel): class EmbeddingResponse(BaseModel):

View File

@ -1,6 +1,8 @@
import base64
import time import time
from typing import AsyncIterator, List, Optional, Tuple from typing import AsyncIterator, List, Optional, Tuple
import numpy as np
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
@ -20,19 +22,18 @@ TypeTokenIDs = List[int]
def request_output_to_embedding_response( def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput], final_res_batch: List[EmbeddingRequestOutput], request_id: str,
request_id: str, created_time: int, model_name: str,
created_time: int, encoding_format: str) -> EmbeddingResponse:
model_name: str,
) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = [] data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch): for idx, final_res in enumerate(final_res_batch):
assert final_res is not None assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
embedding = final_res.outputs.embedding
embedding_data = EmbeddingResponseData( if encoding_format == "base64":
index=idx, embedding=final_res.outputs.embedding) embedding = base64.b64encode(np.array(embedding))
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data) data.append(embedding_data)
num_prompt_tokens += len(prompt_token_ids) num_prompt_tokens += len(prompt_token_ids)
@ -72,10 +73,8 @@ class OpenAIServingEmbedding(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
# Return error for unsupported features. encoding_format = (request.encoding_format
if request.encoding_format == "base64": if request.encoding_format else "float")
return self.create_error_response(
"base64 encoding is not currently supported")
if request.dimensions is not None: if request.dimensions is not None:
return self.create_error_response( return self.create_error_response(
"dimensions is currently not supported") "dimensions is currently not supported")
@ -129,7 +128,8 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
final_res_batch[i] = res final_res_batch[i] = res
response = request_output_to_embedding_response( response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name) final_res_batch, request_id, created_time, model_name,
encoding_format)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))