[Benchmark] Enable MM Embedding benchmarks (#26310)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
7cd95dc8a3
commit
44b9af5bb2
@ -67,13 +67,13 @@ Legend:
<details class="admonition abstract" markdown="1">
<summary>Show more</summary>

First start serving your model
First start serving your model:

```bash
vllm serve NousResearch/Hermes-3-Llama-3.1-8B
```

Then run the benchmarking script
Then run the benchmarking script:

```bash
# download dataset
@ -87,7 +87,7 @@ vllm bench serve \
--num-prompts 10
```

If successful, you will see the following output
If successful, you will see the following output:

```text
============ Serving Benchmark Result ============
@ -125,7 +125,7 @@ If the dataset you want to benchmark is not supported yet in vLLM, even then you

```bash
# start server
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct
vllm serve meta-llama/Llama-3.1-8B-Instruct
```

```bash
@ -167,7 +167,7 @@ vllm bench serve \
##### InstructCoder Benchmark with Speculative Decoding

``` bash
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
--speculative-config $'{"method": "ngram",
"num_speculative_tokens": 5, "prompt_lookup_max": 5,
"prompt_lookup_min": 2}'
@ -184,7 +184,7 @@ vllm bench serve \
##### Spec Bench Benchmark with Speculative Decoding

``` bash
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
--speculative-config $'{"method": "ngram",
"num_speculative_tokens": 5, "prompt_lookup_max": 5,
"prompt_lookup_min": 2}'
@ -366,7 +366,6 @@ Total num output tokens: 1280

``` bash
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_USE_V1=1 \
vllm bench throughput \
--dataset-name=hf \
--dataset-path=likaixin/InstructCoder \
@ -781,6 +780,104 @@ This should be seen as an edge case, and if this behavior can be avoided by sett

</details>

#### Embedding Benchmark

Benchmark the performance of embedding requests in vLLM.

<details class="admonition abstract" markdown="1">
<summary>Show more</summary>

##### Text Embeddings

Unlike generative models, which use the Completions API or Chat Completions API,
embedding models use the Embeddings API, so set `--backend openai-embeddings` and `--endpoint /v1/embeddings`.

You can use any text dataset to benchmark the model, such as ShareGPT.

Start the server:

```bash
vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
```

Run the benchmark:

```bash
# download dataset
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
vllm bench serve \
--model jinaai/jina-embeddings-v3 \
--backend openai-embeddings \
--endpoint /v1/embeddings \
--dataset-name sharegpt \
--dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json
```
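
For reference, `--backend openai-embeddings` builds a plain OpenAI-style embeddings payload (`model` plus `input`) for each sampled prompt. A minimal sketch of an equivalent standalone request, assuming the server above is listening on the default `http://localhost:8000` and needs no API key:

```python
# Minimal sketch of a single Embeddings API request. The host, port, and
# absence of an API key are assumptions based on the default `vllm serve` setup.
import requests

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "jinaai/jina-embeddings-v3",
        "input": "vLLM is a fast and easy-to-use library for LLM inference.",
    },
)
response.raise_for_status()

# The response follows the OpenAI embeddings schema.
embedding = response.json()["data"][0]["embedding"]
print(f"Embedding dimension: {len(embedding)}")
```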

##### Multi-modal Embeddings

Unlike generative models, which use the Completions API or Chat Completions API,
multi-modal embedding models use the Embeddings API, so set `--endpoint /v1/embeddings`. The backend to use depends on the model (a sketch of the request each backend builds follows the list below):

- CLIP: `--backend openai-embeddings-clip`
- VLM2Vec: `--backend openai-embeddings-vlm2vec`
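
Both backends wrap the prompt and any image into a chat-style message before posting it to `/v1/embeddings`. As a rough illustration (the image URL is a hypothetical placeholder), the VLM2Vec backend builds a body shaped like this for an image-only sample, placing the image before its instruction text:

```python
# Approximate request body built by the openai-embeddings-vlm2vec backend;
# the image URL below is a placeholder, not part of any dataset.
payload = {
    "model": "TIGER-Lab/VLM2Vec-Full",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/sample.jpg"}},
                {"type": "text", "text": "Represent the given image."},
            ],
        }
    ],
}
```

The CLIP backend uses the same structure with an empty text entry, and adds `truncate_prompt_tokens: -1` so prompts are truncated to the model's maximum length, since CLIP's 77-token window is too short for most dataset prompts.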

For other models, please add your own implementation inside <gh-file:vllm/benchmarks/lib/endpoint_request_func.py> to match the expected instruction format.
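
As a rough sketch (not part of this commit), such an implementation can usually be a thin wrapper around `async_request_openai_embeddings_chat`, mirroring the CLIP and VLM2Vec request functions; the backend name and instruction string below are hypothetical:

```python
# Hypothetical model-specific embeddings backend for endpoint_request_func.py.
# RequestFuncInput, RequestFuncOutput, async_request_openai_embeddings_chat and
# ASYNC_REQUEST_FUNCS are already defined in that module.
async def async_request_openai_embeddings_mymodel(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    if request_func_input.multi_modal_content:
        # Prepend whatever instruction your model expects for image inputs.
        request_func_input.prompt = (
            f"Represent the given image: {request_func_input.prompt}"
        )

    return await async_request_openai_embeddings_chat(
        request_func_input,
        session,
        pbar=pbar,
        mm_position="first",
    )

# Register the backend (by adding an entry to ASYNC_REQUEST_FUNCS) so it can be
# selected with `--backend openai-embeddings-mymodel`.
ASYNC_REQUEST_FUNCS["openai-embeddings-mymodel"] = async_request_openai_embeddings_mymodel
```

Keeping the `openai-embeddings-` prefix in the backend name also lets multi-modal datasets pass the backend check in `vllm bench serve`.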

You can use any text or multi-modal dataset to benchmark the model, as long as the model supports it.
For example, you can use ShareGPT and VisionArena to benchmark vision-language embeddings.

Serve and benchmark CLIP:

```bash
# Run this in another process
vllm serve openai/clip-vit-base-patch32

# Run these one by one after the server is up
# download dataset
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
vllm bench serve \
--model openai/clip-vit-base-patch32 \
--backend openai-embeddings-clip \
--endpoint /v1/embeddings \
--dataset-name sharegpt \
--dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json

vllm bench serve \
--model openai/clip-vit-base-patch32 \
--backend openai-embeddings-clip \
--endpoint /v1/embeddings \
--dataset-name hf \
--dataset-path lmarena-ai/VisionArena-Chat
```

Serve and benchmark VLM2Vec:

```bash
# Run this in another process
vllm serve TIGER-Lab/VLM2Vec-Full --runner pooling \
--trust-remote-code \
--chat-template examples/template_vlm2vec_phi3v.jinja

# Run these one by one after the server is up
# download dataset
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
vllm bench serve \
--model TIGER-Lab/VLM2Vec-Full \
--backend openai-embeddings-vlm2vec \
--endpoint /v1/embeddings \
--dataset-name sharegpt \
--dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json

vllm bench serve \
--model TIGER-Lab/VLM2Vec-Full \
--backend openai-embeddings-vlm2vec \
--endpoint /v1/embeddings \
--dataset-name hf \
--dataset-path lmarena-ai/VisionArena-Chat
```

</details>

[](){ #performance-benchmarks }

## Performance Benchmarks

@ -1582,10 +1582,10 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
"like to add support for additional dataset formats."
)

if dataset_class.IS_MULTIMODAL and args.backend not in [
"openai-chat",
"openai-audio",
]:
if dataset_class.IS_MULTIMODAL and not (
args.backend in ("openai-chat", "openai-audio")
or "openai-embeddings-" in args.backend
):
# multi-modal benchmark is only available on OpenAI Chat
# endpoint-type.
raise ValueError(

@ -10,9 +10,10 @@ import time
import traceback
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import Optional, Protocol, Union
from typing import Any, Literal, Optional, Protocol, Union

import aiohttp
import regex as re
from tqdm.asyncio import tqdm

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@ -103,6 +104,40 @@ class RequestFunc(Protocol):
) -> Awaitable[RequestFuncOutput]: ...


def _validate_api_url(
api_url: str,
api_name: str,
expected_suffixes: Union[str, set[str]],
) -> None:
if isinstance(expected_suffixes, str):
expected_suffixes = {expected_suffixes}

expected_suffixes = {*expected_suffixes, "profile"}

if not api_url.endswith(tuple(expected_suffixes)):
raise ValueError(f"{api_name} URL must end with one of: {expected_suffixes}.")


def _update_payload_common(
payload: dict[str, Any],
request_func_input: RequestFuncInput,
) -> None:
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)


def _update_headers_common(
headers: dict[str, Any],
request_func_input: RequestFuncInput,
) -> None:
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id


async def async_request_openai_completions(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
@ -118,9 +153,7 @@ async def async_request_openai_completions(
The output of the request function.
"""
api_url = request_func_input.api_url
assert api_url.endswith(("completions", "profile")), (
"OpenAI Completions API URL must end with 'completions' or 'profile'."
)
_validate_api_url(api_url, "OpenAI Completions API", "completions")

payload = {
"model": request_func_input.model_name
@ -136,15 +169,12 @@ async def async_request_openai_completions(
"include_usage": True,
},
}
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
_update_payload_common(payload, request_func_input)

headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
_update_headers_common(headers, request_func_input)

output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@ -222,27 +252,41 @@ async def async_request_openai_completions(
return output


async def async_request_openai_chat_completions(
def _get_chat_content(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(("chat/completions", "profile")), (
"OpenAI Chat Completions API URL must end with 'chat/completions'."
)
mm_position: Literal["first", "last"] = "last",
) -> list[dict[str, Any]]:
text_contents = [{"type": "text", "text": request_func_input.prompt}]

content = [{"type": "text", "text": request_func_input.prompt}]
mm_contents = []
if request_func_input.multi_modal_content:
mm_content = request_func_input.multi_modal_content
if isinstance(mm_content, list):
content.extend(mm_content)
mm_contents.extend(request_func_input.multi_modal_content)
elif isinstance(mm_content, dict):
content.append(mm_content)
mm_contents.append(request_func_input.multi_modal_content)
else:
raise TypeError(
"multi_modal_content must be a dict or list[dict] for openai-chat"
)

if mm_position == "first":
return mm_contents + text_contents

return text_contents + mm_contents


async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")

content = _get_chat_content(request_func_input, mm_position=mm_position)

payload = {
"model": request_func_input.model_name
if request_func_input.model_name
@ -257,18 +301,13 @@ async def async_request_openai_chat_completions(
"include_usage": True,
},
}
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
_update_payload_common(payload, request_func_input)

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
_update_headers_common(headers, request_func_input)

output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@ -343,10 +382,7 @@ async def async_request_openai_audio(
import soundfile

api_url = request_func_input.api_url
assert api_url.endswith(("transcriptions", "translations")), (
"OpenAI Chat Completions API URL must end with 'transcriptions' "
)
"or `translations`."
_validate_api_url(api_url, "OpenAI Audio API", {"transcriptions", "translations"})

content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
@ -361,15 +397,12 @@ async def async_request_openai_audio(
"stream_include_usage": True,
"stream_continuous_usage_stats": True,
}
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
_update_payload_common(payload, request_func_input)

headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
_update_headers_common(headers, request_func_input)

# Send audio file
def to_bytes(y, sr):
@ -451,26 +484,13 @@ async def async_request_openai_audio(
return output


async def async_request_openai_embeddings(
request_func_input: RequestFuncInput,
async def _run_openai_embeddings(
session: aiohttp.ClientSession,
api_url: str,
payload: dict[str, Any],
headers: dict[str, Any],
pbar: Optional[tqdm] = None,
):
api_url = request_func_input.api_url
assert api_url.endswith("embeddings"), (
"OpenAI Embeddings API URL must end with 'embeddings'."
)

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}

payload = {
"model": request_func_input.model,
"input": request_func_input.prompt,
}

) -> RequestFuncOutput:
output = RequestFuncOutput()
st = time.perf_counter()
output.start_time = st
@ -494,6 +514,137 @@ async def async_request_openai_embeddings(
return output


async def async_request_openai_embeddings(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")

payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"input": request_func_input.prompt,
}
_update_payload_common(payload, request_func_input)

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
_update_headers_common(headers, request_func_input)

return await _run_openai_embeddings(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)


async def async_request_openai_embeddings_chat(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")

content = _get_chat_content(request_func_input, mm_position=mm_position)

payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"messages": [
{"role": "user", "content": content},
],
}
_update_payload_common(payload, request_func_input)

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
_update_headers_common(headers, request_func_input)

return await _run_openai_embeddings(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)


async def async_request_openai_embeddings_clip(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
if request_func_input.multi_modal_content:
# Image input
request_func_input.prompt = ""

# max_model_len=77 is too short for most datasets,
# so by default we truncate the prompt to max_model_len
if request_func_input.extra_body is None:
request_func_input.extra_body = {}
if "truncate_prompt_tokens" not in request_func_input.extra_body:
request_func_input.extra_body["truncate_prompt_tokens"] = -1

return await async_request_openai_embeddings_chat(
request_func_input,
session,
pbar=pbar,
)


def _try_extract_request_idx(request_func_input: RequestFuncInput):
if request_func_input.request_id:
match = re.search(r"(\d+)$", request_func_input.request_id)
if match:
try:
return int(match.group(1))
except ValueError:
pass

return None


async def async_request_openai_embeddings_vlm2vec(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
if request_func_input.multi_modal_content:
request_idx = _try_extract_request_idx(request_func_input)

# Adjust the ratio manually if needed.
use_image_only_prompt = request_idx is None or request_idx % 2 == 0

if use_image_only_prompt:
# Image input
request_func_input.prompt = "Represent the given image."
else:
# Text+Image input
request_func_input.prompt = (
f"Represent the given image with the following question: "
f"{request_func_input.prompt}"
)

return await async_request_openai_embeddings_chat(
request_func_input,
session,
pbar=pbar,
mm_position="first",
)


# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
"vllm": async_request_openai_completions,
@ -501,6 +652,9 @@ ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
"openai-chat": async_request_openai_chat_completions,
"openai-audio": async_request_openai_audio,
"openai-embeddings": async_request_openai_embeddings,
"openai-embeddings-chat": async_request_openai_embeddings_chat,
"openai-embeddings-clip": async_request_openai_embeddings_clip,
"openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec,
}

OPENAI_COMPATIBLE_BACKENDS = [

@ -465,6 +465,7 @@ def calculate_metrics(


async def benchmark(
task_type: TaskType,
endpoint_type: str,
api_url: str,
base_url: str,
@ -490,18 +491,10 @@ async def benchmark(
ramp_up_end_rps: Optional[int] = None,
ready_check_timeout_sec: int = 600,
):
task_type = (
TaskType.EMBEDDING
if api_url.endswith("/v1/embeddings")
else TaskType.GENERATION
)
if endpoint_type in ASYNC_REQUEST_FUNCS:
if task_type == TaskType.EMBEDDING:
request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"]
else:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
else:
raise ValueError(f"Unknown backend: {endpoint_type}")
try:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
except KeyError:
raise ValueError(f"Unknown backend: {endpoint_type}") from None

# Reuses connections across requests to reduce TLS handshake overhead.
connector = aiohttp.TCPConnector(
@ -1310,36 +1303,43 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
input_requests = get_samples(args, tokenizer)
goodput_config_dict = check_goodput_args(args)

backend = args.backend
task_type = TaskType.EMBEDDING if "embeddings" in backend else TaskType.GENERATION

# Collect the sampling parameters.
sampling_params = {
k: v
for k, v in {
"top_p": args.top_p,
"top_k": args.top_k,
"min_p": args.min_p,
"temperature": args.temperature,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"repetition_penalty": args.repetition_penalty,
}.items()
if v is not None
}
if task_type == TaskType.GENERATION:
sampling_params = {
k: v
for k, v in {
"top_p": args.top_p,
"top_k": args.top_k,
"min_p": args.min_p,
"temperature": args.temperature,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"repetition_penalty": args.repetition_penalty,
}.items()
if v is not None
}

# Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError(
"Sampling parameters are only supported by openai-compatible backends."
)
# Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError(
"Sampling parameters are only supported by openai-compatible backends."
)

if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
else:
sampling_params = {}

# Avoid GC processing "static" data - reduce pause times.
gc.collect()
gc.freeze()

benchmark_result = await benchmark(
endpoint_type=args.backend,
task_type=task_type,
endpoint_type=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,

@ -498,14 +498,14 @@ def resolve_hf_chat_template(
tokenizer_name_or_path=model_config.tokenizer,
)
if path is not None:
logger.info(
logger.info_once(
"Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.",
tokenizer.name_or_path,
)
chat_template = load_chat_template(path)
else:
logger.debug(
logger.debug_once(
"There is no chat template fallback for %s", tokenizer.name_or_path
)
