# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""The request function for API endpoints."""

import io
import json
import os
import sys
import time
import traceback
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol

import aiohttp
import regex as re
from tqdm.asyncio import tqdm

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


class StreamedResponseHandler:
    """Handles streaming HTTP responses by accumulating chunks until complete
    messages are available."""

    def __init__(self):
        self.buffer = ""

    def add_chunk(self, chunk_bytes: bytes) -> list[str]:
        """Add a chunk of bytes to the buffer and return any complete
        messages."""
        chunk_str = chunk_bytes.decode("utf-8")
        self.buffer += chunk_str

        messages = []

        # Split by double newlines (SSE message separator)
        while "\n\n" in self.buffer:
            message, self.buffer = self.buffer.split("\n\n", 1)
            message = message.strip()
            if message:
                messages.append(message)

        # If the buffer is not empty, check whether it already holds a
        # complete message by removing the "data: " prefix and testing
        # whether the remainder is valid JSON.
        if self.buffer.startswith("data: "):
            message_content = self.buffer.removeprefix("data: ").strip()
            if message_content == "[DONE]":
                messages.append(self.buffer.strip())
                self.buffer = ""
            elif message_content:
                try:
                    json.loads(message_content)
                    messages.append(self.buffer.strip())
                    self.buffer = ""
                except json.JSONDecodeError:
                    # Incomplete JSON, wait for more chunks.
                    pass

        return messages
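
# A minimal illustrative sketch (not executed; the byte strings are made up)
# of how StreamedResponseHandler reassembles SSE messages that arrive split
# across arbitrary network chunks:
#
#   handler = StreamedResponseHandler()
#   handler.add_chunk(b'data: {"choices": [')    # -> [] (incomplete JSON)
#   handler.add_chunk(b'{"text": "hi"}]}\n\n')   # -> ['data: {"choices": [{"text": "hi"}]}']
#   handler.add_chunk(b"data: [DONE]\n\n")       # -> ['data: [DONE]']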


@dataclass
class RequestFuncInput:
    """The input for the request function."""

    prompt: str | list[str]
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    model_name: str | None = None
    logprobs: int | None = None
    extra_headers: dict | None = None
    extra_body: dict | None = None
    multi_modal_content: dict | list[dict] | None = None
    ignore_eos: bool = False
    language: str | None = None
    request_id: str | None = None


@dataclass
class RequestFuncOutput:
    """The output of the request function, including metrics."""

    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    output_tokens: int = 0
    ttft: float = 0.0  # Time to first token
    itl: list[float] = field(default_factory=list)  # List of inter-token latencies
    tpot: float = 0.0  # Average time per output token
    prompt_len: int = 0
    error: str = ""
    start_time: float = 0.0


class RequestFunc(Protocol):
    def __call__(
        self,
        request_func_input: RequestFuncInput,
        session: aiohttp.ClientSession,
        pbar: tqdm | None = None,
    ) -> Awaitable[RequestFuncOutput]: ...
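
# Each async_request_* coroutine below matches RequestFunc structurally, so the
# ASYNC_REQUEST_FUNCS registry at the bottom of this module can expose them all
# behind a single call signature.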


def _validate_api_url(
    api_url: str,
    api_name: str,
    expected_suffixes: str | set[str],
) -> None:
    if isinstance(expected_suffixes, str):
        expected_suffixes = {expected_suffixes}

    # Profiling endpoints are always allowed in addition to the
    # API-specific suffixes.
    expected_suffixes = {*expected_suffixes, "profile"}

    if not api_url.endswith(tuple(expected_suffixes)):
        raise ValueError(f"{api_name} URL must end with one of: {expected_suffixes}.")
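
# For example (hypothetical URLs), _validate_api_url accepts both
# "http://localhost:8000/v1/completions" and "http://localhost:8000/start_profile"
# when expected_suffixes="completions", but rejects
# "http://localhost:8000/v1/embeddings".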


def _update_payload_common(
    payload: dict[str, Any],
    request_func_input: RequestFuncInput,
) -> None:
    if request_func_input.ignore_eos:
        payload["ignore_eos"] = request_func_input.ignore_eos
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)


def _update_headers_common(
    headers: dict[str, Any],
    request_func_input: RequestFuncInput,
) -> None:
    if request_func_input.extra_headers:
        headers |= request_func_input.extra_headers
    if request_func_input.request_id:
        headers["x-request-id"] = request_func_input.request_id


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Completions API.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp client session to send the request with.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function.
    """
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Completions API", "completions")

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "prompt": request_func_input.prompt,
        "temperature": 0.0,
        "repetition_penalty": 1.0,
        "max_tokens": request_func_input.output_len,
        "logprobs": request_func_input.logprobs,
        "stream": True,
        "stream_options": {
            "include_usage": True,
        },
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    st = time.perf_counter()
    output.start_time = st
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload, headers=headers) as response:
            if response.status == 200:
                first_chunk_received = False
                handler = StreamedResponseHandler()

                async for chunk_bytes in response.content.iter_any():
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue

                    messages = handler.add_chunk(chunk_bytes)
                    for message in messages:
                        # NOTE: SSE comments (often used as pings) start with
                        # a colon. These are not JSON data payloads and should
                        # be skipped.
                        if message.startswith(":"):
                            continue

                        chunk = message.removeprefix("data: ")

                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some Completions APIs might send a final
                            # usage summary response without a token, so we
                            # want to check that a token was generated.
                            if choices := data.get("choices"):
                                # Note that text could be empty here,
                                # e.g. for special tokens.
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get("completion_tokens")
                if first_chunk_received:
                    output.success = True
                else:
                    output.success = False
                    output.error = (
                        "Never received a valid chunk to calculate TTFT. "
                        "This response will be marked as failed!"
                    )
                output.generated_text = generated_text
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
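
# A worked example of the timing bookkeeping above (illustrative numbers):
# if the stream delivers token chunks at 0.50s, 0.60s, and 0.75s after the
# request is sent, then ttft == 0.50, itl == [0.10, 0.15], and
# latency == 0.75.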


def _get_chat_content(
    request_func_input: RequestFuncInput,
    mm_position: Literal["first", "last"] = "last",
) -> list[dict[str, Any]]:
    text_contents = [{"type": "text", "text": request_func_input.prompt}]

    mm_contents = []
    if request_func_input.multi_modal_content:
        mm_content = request_func_input.multi_modal_content
        if isinstance(mm_content, list):
            mm_contents.extend(mm_content)
        elif isinstance(mm_content, dict):
            mm_contents.append(mm_content)
        else:
            raise TypeError(
                "multi_modal_content must be a dict or list[dict] for openai-chat"
            )

    if mm_position == "first":
        return mm_contents + text_contents

    return text_contents + mm_contents
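
# For example (hypothetical parts), with mm_position="first" an image part such
# as {"type": "image_url", "image_url": {"url": "..."}} is placed before the
# text part; the VLM2Vec embedding path below relies on that ordering.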


async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
    mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
    """The async request function for the OpenAI Chat Completions API."""
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")

    content = _get_chat_content(request_func_input, mm_position=mm_position)

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "messages": [
            {"role": "user", "content": content},
        ],
        "temperature": 0.0,
        "max_completion_tokens": request_func_input.output_len,
        "stream": True,
        "stream_options": {
            "include_usage": True,
        },
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    ttft = 0.0
    st = time.perf_counter()
    output.start_time = st
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload, headers=headers) as response:
            if response.status == 200:
                handler = StreamedResponseHandler()
                async for chunk_bytes in response.content.iter_any():
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue

                    messages = handler.add_chunk(chunk_bytes)
                    for message in messages:
                        # NOTE: SSE comments (often used as pings) start with
                        # a colon. These are not JSON data payloads and should
                        # be skipped.
                        if message.startswith(":"):
                            continue

                        chunk = message.removeprefix("data: ")

                        if chunk != "[DONE]":
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            if choices := data.get("choices"):
                                content = choices[0]["delta"].get("content")
                                # First token
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                generated_text += content or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get("completion_tokens")

                            most_recent_timestamp = timestamp

                output.generated_text = generated_text
                output.success = True
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Audio API."""
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile

    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Audio API", {"transcriptions", "translations"})

    content = [{"type": "text", "text": request_func_input.prompt}]
    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "temperature": 0.0,
        "max_completion_tokens": request_func_input.output_len,
        "stream": True,
        "language": "en",
        # Flattened due to multipart/form-data
        "stream_include_usage": True,
        "stream_continuous_usage_stats": True,
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    # Send audio file
    def to_bytes(y, sr):
        buffer = io.BytesIO()
        soundfile.write(buffer, y, sr, format="WAV")
        buffer.seek(0)
        return buffer

    mm_audio = request_func_input.multi_modal_content
    if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
        raise TypeError("multi_modal_content must be a dict containing 'audio'")
    with to_bytes(*mm_audio["audio"]) as f:
        form = aiohttp.FormData()
        form.add_field("file", f, content_type="audio/wav")
        for key, value in payload.items():
            form.add_field(key, str(value))

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        output.start_time = st
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, data=form, headers=headers
            ) as response:
                if response.status == 200:
                    handler = StreamedResponseHandler()

                    async for chunk_bytes in response.content.iter_any():
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        messages = handler.add_chunk(chunk_bytes)
                        for message in messages:
                            # add_chunk already decoded the bytes, so strip
                            # the SSE prefix from the string directly.
                            chunk = message.removeprefix("data: ")
                            if chunk != "[DONE]":
                                timestamp = time.perf_counter()
                                data = json.loads(chunk)

                                if choices := data.get("choices"):
                                    content = choices[0]["delta"].get("content")
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft

                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp
                                        )

                                    generated_text += content or ""
                                elif usage := data.get("usage"):
                                    output.output_tokens = usage.get(
                                        "completion_tokens"
                                    )

                                most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def _run_pooling_request(
    session: aiohttp.ClientSession,
    api_url: str,
    payload: dict[str, Any],
    headers: dict[str, Any],
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """Send a single non-streaming pooling request (embeddings, rerank, etc.)
    and record its latency; TTFT and latency coincide because the whole
    response arrives at once."""
    output = RequestFuncOutput()
    st = time.perf_counter()
    output.start_time = st
    try:
        async with session.post(url=api_url, headers=headers, json=payload) as response:
            if response.status == 200:
                output.ttft = output.latency = time.perf_counter() - st

                if payload.get("encoding_format", "float") == "bytes":
                    metadata = json.loads(response.headers["metadata"])
                    usage = metadata.get("usage", {})
                else:
                    data = await response.json()
                    usage = data.get("usage", {})

                output.success = True
                output.generated_text = ""
                output.prompt_len = usage.get("prompt_tokens", 0)
            else:
                output.success = False
                output.error = response.reason or ""
    except Exception as e:
        output.success = False
        output.error = str(e)

    if pbar:
        pbar.update(1)
    return output


async def async_request_openai_embeddings(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "input": request_func_input.prompt,
        # Many embedding models have a short context length;
        # this avoids dropping some of the requests.
        "truncate_prompt_tokens": -1,
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    return await _run_pooling_request(
        session,
        api_url,
        payload=payload,
        headers=headers,
        pbar=pbar,
    )


async def async_request_vllm_rerank(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "vLLM score API", "rerank")

    # The first prompt is the query; the remaining prompts are the documents
    # to rerank against it.
    assert (
        isinstance(request_func_input.prompt, list)
        and len(request_func_input.prompt) > 1
    )

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "query": request_func_input.prompt[0],
        "documents": request_func_input.prompt[1:],
        # Many reranker models have a short context length;
        # this avoids dropping some of the requests.
        "truncate_prompt_tokens": -1,
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    return await _run_pooling_request(
        session,
        api_url,
        payload=payload,
        headers=headers,
        pbar=pbar,
    )


async def async_request_openai_embeddings_chat(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
    mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")

    content = _get_chat_content(request_func_input, mm_position=mm_position)

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "messages": [
            {"role": "user", "content": content},
        ],
        # Many embedding models have a short context length;
        # this avoids dropping some of the requests.
        "truncate_prompt_tokens": -1,
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    return await _run_pooling_request(
        session,
        api_url,
        payload=payload,
        headers=headers,
        pbar=pbar,
    )


def _try_extract_request_idx(request_func_input: RequestFuncInput):
    if request_func_input.request_id:
        match = re.search(r"(\d+)$", request_func_input.request_id)
        if match:
            try:
                return int(match.group(1))
            except ValueError:
                pass

    return None
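
# For example, a request_id of "benchmark-request-42" (a hypothetical id)
# yields 42, while an id without trailing digits yields None.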


def _preprocess_clip(request_func_input: RequestFuncInput):
    if request_func_input.multi_modal_content:
        # Image input
        request_func_input.prompt = ""


def _preprocess_vlm2vec(request_func_input: RequestFuncInput):
    if request_func_input.multi_modal_content:
        request_idx = _try_extract_request_idx(request_func_input)

        # Adjust the ratio manually if needed.
        use_image_only_prompt = request_idx is None or request_idx % 2 == 0

        if use_image_only_prompt:
            # Image input
            request_func_input.prompt = "Represent the given image."
        else:
            # Text+Image input
            request_func_input.prompt = (
                "Represent the given image with the following question: "
                f"{request_func_input.prompt}"
            )


async def async_request_openai_embeddings_clip(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    _preprocess_clip(request_func_input)

    return await async_request_openai_embeddings_chat(
        request_func_input,
        session,
        pbar=pbar,
    )


async def async_request_openai_embeddings_vlm2vec(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    _preprocess_vlm2vec(request_func_input)

    return await async_request_openai_embeddings_chat(
        request_func_input,
        session,
        pbar=pbar,
        mm_position="first",
    )


async def async_request_infinity_embeddings(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "Infinity Embeddings API", "embeddings")

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
    }

    if request_func_input.prompt:
        payload["input"] = request_func_input.prompt
    else:
        mm_content = request_func_input.multi_modal_content
        assert isinstance(mm_content, dict)

        # e.g. a content type of "image_url" yields modality "image"
        mm_type = mm_content["type"]
        payload["input"] = mm_content[mm_type]["url"]
        payload["modality"] = mm_type.split("_", 1)[0]

    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    return await _run_pooling_request(
        session,
        api_url,
        payload=payload,
        headers=headers,
        pbar=pbar,
    )


async def async_request_infinity_embeddings_clip(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    _preprocess_clip(request_func_input)

    return await async_request_infinity_embeddings(
        request_func_input,
        session,
        pbar=pbar,
    )


# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
    "vllm": async_request_openai_completions,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "openai-audio": async_request_openai_audio,
    "openai-embeddings": async_request_openai_embeddings,
    "openai-embeddings-chat": async_request_openai_embeddings_chat,
    "openai-embeddings-clip": async_request_openai_embeddings_clip,
    "openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec,
    # Infinity embedding server: https://github.com/michaelfeil/infinity
    "infinity-embeddings": async_request_infinity_embeddings,
    "infinity-embeddings-clip": async_request_infinity_embeddings_clip,
    # (Infinity embedding server does not support vlm2vec)
    "vllm-rerank": async_request_vllm_rerank,
}

OPENAI_COMPATIBLE_BACKENDS = [
    k
    for k, v in ASYNC_REQUEST_FUNCS.items()
    if v in (async_request_openai_completions, async_request_openai_chat_completions)
]
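

# A minimal usage sketch (illustrative only, not executed on import). It
# assumes a vLLM OpenAI-compatible server is already listening at
# http://localhost:8000 and serving a model named "my-model"; both values are
# placeholders rather than anything this module provides.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        request_func = ASYNC_REQUEST_FUNCS["openai"]
        request_func_input = RequestFuncInput(
            prompt="Hello, world!",
            api_url="http://localhost:8000/v1/completions",
            prompt_len=4,
            output_len=16,
            model="my-model",
        )
        async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
            result = await request_func(request_func_input, session)
        if result.success:
            print(f"TTFT: {result.ttft:.3f}s, latency: {result.latency:.3f}s")
            print(result.generated_text)
        else:
            print(f"Request failed: {result.error}")

    asyncio.run(_demo())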