[Misc] Add request_id into benchmark_serve.py (#23065)

Signed-off-by: yangxia <yangxiast@gmail.com>
This commit is contained in:
hustxiayang 2025-08-19 04:32:18 -04:00 committed by GitHub
parent 4efd43e9b4
commit 31436e8b4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 243 additions and 46 deletions

View File

@ -34,6 +34,7 @@ class RequestFuncInput:
multi_modal_content: Optional[dict | list[dict]] = None multi_modal_content: Optional[dict | list[dict]] = None
ignore_eos: bool = False ignore_eos: bool = False
language: Optional[str] = None language: Optional[str] = None
request_id: Optional[str] = None
@dataclass @dataclass
@ -71,6 +72,9 @@ async def async_request_tgi(
"inputs": request_func_input.prompt, "inputs": request_func_input.prompt,
"parameters": params, "parameters": params,
} }
headers = None
if request_func_input.request_id:
headers = {"x-request-id": request_func_input.request_id}
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
if request_func_input.ignore_eos: if request_func_input.ignore_eos:
@ -82,7 +86,9 @@ async def async_request_tgi(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
@ -145,6 +151,9 @@ async def async_request_trt_llm(
} }
if request_func_input.ignore_eos: if request_func_input.ignore_eos:
payload["min_length"] = request_func_input.output_len payload["min_length"] = request_func_input.output_len
headers = None
if request_func_input.request_id:
headers = {"x-request-id": request_func_input.request_id}
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
@ -152,7 +161,9 @@ async def async_request_trt_llm(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
@ -211,6 +222,8 @@ async def async_request_deepspeed_mii(
"top_p": 1.0, "top_p": 1.0,
} }
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
@ -283,6 +296,8 @@ async def async_request_openai_completions(
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
@ -395,6 +410,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
@ -491,6 +508,8 @@ async def async_request_openai_audio(
headers = { headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
# Send audio file # Send audio file
def to_bytes(y, sr): def to_bytes(y, sr):

View File

@ -19,6 +19,7 @@ import logging
import random import random
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from functools import cache from functools import cache
from io import BytesIO from io import BytesIO
@ -54,6 +55,7 @@ class SampleRequest:
expected_output_len: int expected_output_len: int
multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
request_id: Optional[str] = None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ -155,7 +157,10 @@ class BenchmarkDataset(ABC):
@abstractmethod @abstractmethod
def sample( def sample(
self, tokenizer: PreTrainedTokenizerBase, num_requests: int self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
) -> list[SampleRequest]: ) -> list[SampleRequest]:
""" """
Abstract method to generate sample requests from the dataset. Abstract method to generate sample requests from the dataset.
@ -167,6 +172,7 @@ class BenchmarkDataset(ABC):
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text. for processing the dataset's text.
num_requests (int): The number of sample requests to generate. num_requests (int): The number of sample requests to generate.
request_id_prefix (str): The prefix of request_id. request_id_prefix (str): The prefix of request_id.
Returns: Returns:
list[SampleRequest]: A list of sample requests generated from the list[SampleRequest]: A list of sample requests generated from the
@ -175,7 +181,10 @@ class BenchmarkDataset(ABC):
raise NotImplementedError("sample must be implemented in subclasses.") raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests( def maybe_oversample_requests(
self, requests: list[SampleRequest], num_requests: int self,
requests: list[SampleRequest],
num_requests: int,
request_id_prefix: str = "",
) -> None: ) -> None:
""" """
Oversamples the list of requests if its size is less than the desired Oversamples the list of requests if its size is less than the desired
@ -183,11 +192,18 @@ class BenchmarkDataset(ABC):
Args: Args:
requests (List[SampleRequest]): The current list of sampled requests (List[SampleRequest]): The current list of sampled
requests. num_requests (int): The target number of requests. requests.
num_requests (int): The target number of requests.
request_id_prefix (str): The prefix of the request ids.
""" """
if len(requests) < num_requests: if len(requests) < num_requests:
random.seed(self.random_seed) random.seed(self.random_seed)
additional = random.choices(requests, k=num_requests - len(requests)) additional = deepcopy(
random.choices(requests, k=num_requests - len(requests))
)
for i in range(len(additional)):
req = additional[i]
req.request_id = request_id_prefix + str(len(requests) + i)
requests.extend(additional) requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.", num_requests) logger.info("Oversampled requests to reach %d total samples.", num_requests)
@ -303,6 +319,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio: float = DEFAULT_RANGE_RATIO, range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN, input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
# Enforce range_ratio < 1 # Enforce range_ratio < 1
@ -363,8 +380,10 @@ class RandomDataset(BenchmarkDataset):
prompt=prompt, prompt=prompt,
prompt_len=total_input_len, prompt_len=total_input_len,
expected_output_len=int(output_lens[i]), expected_output_len=int(output_lens[i]),
request_id=request_id_prefix + str(i),
) )
) )
return requests return requests
@ -406,9 +425,11 @@ class ShareGPTDataset(BenchmarkDataset):
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
samples: list = [] samples: list = []
ind = 0
for entry in self.data: for entry in self.data:
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
@ -444,9 +465,11 @@ class ShareGPTDataset(BenchmarkDataset):
expected_output_len=new_output_len, expected_output_len=new_output_len,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
) )
) )
self.maybe_oversample_requests(samples, num_requests) ind += 1
self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples return samples
@ -512,10 +535,11 @@ class CustomDataset(BenchmarkDataset):
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
skip_chat_template: bool = False, skip_chat_template: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item["prompt"] prompt = item["prompt"]
@ -534,9 +558,12 @@ class CustomDataset(BenchmarkDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
@ -578,6 +605,7 @@ class SonnetDataset(BenchmarkDataset):
input_len: int = DEFAULT_INPUT_LEN, input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
return_prompt_formatted: bool = False, return_prompt_formatted: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
# Calculate average token length for a poem line. # Calculate average token length for a poem line.
@ -603,6 +631,7 @@ class SonnetDataset(BenchmarkDataset):
prefix_lines = self.data[:num_prefix_lines] prefix_lines = self.data[:num_prefix_lines]
samples = [] samples = []
ind = 0
while len(samples) < num_requests: while len(samples) < num_requests:
extra_lines = random.choices( extra_lines = random.choices(
self.data, k=num_input_lines - num_prefix_lines self.data, k=num_input_lines - num_prefix_lines
@ -613,14 +642,17 @@ class SonnetDataset(BenchmarkDataset):
msg, add_generation_prompt=True, tokenize=False msg, add_generation_prompt=True, tokenize=False
) )
prompt_len = len(tokenizer(prompt_formatted).input_ids) prompt_len = len(tokenizer(prompt_formatted).input_ids)
if prompt_len <= input_len: if prompt_len <= input_len:
samples.append( samples.append(
SampleRequest( SampleRequest(
prompt=prompt_formatted if return_prompt_formatted else prompt, prompt=prompt_formatted if return_prompt_formatted else prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(ind),
) )
) )
ind += 1
return samples return samples
@ -672,6 +704,7 @@ class BurstGPTDataset(BenchmarkDataset):
num_requests: int, num_requests: int,
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
samples = [] samples = []
@ -693,6 +726,7 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len, prompt_len=input_len,
expected_output_len=output_len, expected_output_len=output_len,
lora_request=lora_req, lora_request=lora_req,
request_id=request_id_prefix + str(i),
) )
) )
return samples return samples
@ -752,12 +786,14 @@ class ConversationDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
# Filter examples with at least 2 conversations # Filter examples with at least 2 conversations
filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
sampled_requests = [] sampled_requests = []
dynamic_output = output_len is None dynamic_output = output_len is None
ind = 0
for item in filtered_data: for item in filtered_data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
@ -785,9 +821,13 @@ class ConversationDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) ind += 1
self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
@ -814,11 +854,12 @@ class VisionArenaDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
@ -838,9 +879,12 @@ class VisionArenaDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
@ -870,11 +914,12 @@ class InstructCoderDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = f"{item['input']}\n\n{item['instruction']} Just output \ prompt = f"{item['input']}\n\n{item['instruction']} Just output \
@ -892,9 +937,12 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
@ -924,12 +972,13 @@ class MTBenchDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item["turns"][0] prompt = item["turns"][0]
@ -947,9 +996,12 @@ class MTBenchDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
@ -974,10 +1026,12 @@ class AIMODataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
sampled_requests = [] sampled_requests = []
dynamic_output = output_len is None dynamic_output = output_len is None
ind = 0
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
@ -1000,9 +1054,13 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=None, multi_modal_data=None,
request_id=request_id_prefix + str(ind),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) ind += 1
self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
@ -1072,12 +1130,18 @@ class NextEditPredictionDataset(HuggingFaceDataset):
"zed-industries/zeta": _format_zeta_prompt, "zed-industries/zeta": _format_zeta_prompt,
} }
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs): def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
**kwargs,
):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
if formatting_prompt_func is None: if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}") raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = [] samples = []
for sample in self.data: for i, sample in enumerate(self.data):
sample = formatting_prompt_func(sample) sample = formatting_prompt_func(sample)
samples.append( samples.append(
SampleRequest( SampleRequest(
@ -1086,11 +1150,12 @@ class NextEditPredictionDataset(HuggingFaceDataset):
expected_output_len=len( expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids tokenizer(sample["expected_output"]).input_ids
), ),
request_id=request_id_prefix + str(i),
) )
) )
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
self.maybe_oversample_requests(samples, num_requests) self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples return samples
@ -1139,6 +1204,7 @@ class ASRDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
import librosa import librosa
@ -1148,6 +1214,7 @@ class ASRDataset(HuggingFaceDataset):
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = [] sampled_requests = []
skipped = 0 skipped = 0
ind = 0
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
@ -1166,8 +1233,10 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
) )
) )
ind += 1
if skipped: if skipped:
logger.warning( logger.warning(
"%d samples discarded from dataset due to" "%d samples discarded from dataset due to"
@ -1175,5 +1244,7 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.", " what Whisper supports.",
skipped, skipped,
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests

View File

@ -375,11 +375,12 @@ async def benchmark(
rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
last_int_rps = current_int_rps last_int_rps = current_int_rps
prompt, prompt_len, output_len, mm_content = ( prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt, request.prompt,
request.prompt_len, request.prompt_len,
request.expected_output_len, request.expected_output_len,
request.multi_modal_data, request.multi_modal_data,
request.request_id,
) )
req_model_id, req_model_name = model_id, model_name req_model_id, req_model_name = model_id, model_name
if lora_modules: if lora_modules:
@ -397,6 +398,7 @@ async def benchmark(
multi_modal_content=mm_content, multi_modal_content=mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=extra_body, extra_body=extra_body,
request_id=request_id,
) )
task = limited_request_func(request_func_input=request_func_input, pbar=pbar) task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
tasks.append(asyncio.create_task(task)) tasks.append(asyncio.create_task(task))
@ -665,6 +667,7 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.custom_output_len, output_len=args.custom_output_len,
skip_chat_template=args.custom_skip_chat_template, skip_chat_template=args.custom_skip_chat_template,
request_id_prefix=args.request_id_prefix,
) )
elif args.dataset_name == "sonnet": elif args.dataset_name == "sonnet":
@ -678,6 +681,7 @@ def main(args: argparse.Namespace):
prefix_len=args.sonnet_prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
return_prompt_formatted=False, return_prompt_formatted=False,
request_id_prefix=args.request_id_prefix,
) )
else: else:
assert tokenizer.chat_template or tokenizer.default_chat_template, ( assert tokenizer.chat_template or tokenizer.default_chat_template, (
@ -690,6 +694,7 @@ def main(args: argparse.Namespace):
prefix_len=args.sonnet_prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
return_prompt_formatted=True, return_prompt_formatted=True,
request_id_prefix=args.request_id_prefix,
) )
elif args.dataset_name == "hf": elif args.dataset_name == "hf":
@ -751,6 +756,7 @@ def main(args: argparse.Namespace):
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.hf_output_len, output_len=args.hf_output_len,
request_id_prefix=args.request_id_prefix,
) )
else: else:
@ -762,10 +768,15 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=args.num_prompts, num_requests=args.num_prompts,
output_len=args.sharegpt_output_len, output_len=args.sharegpt_output_len,
request_id_prefix=args.request_id_prefix,
), ),
"burstgpt": lambda: BurstGPTDataset( "burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path random_seed=args.seed, dataset_path=args.dataset_path
).sample(tokenizer=tokenizer, num_requests=args.num_prompts), ).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
request_id_prefix=args.request_id_prefix,
),
"random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=args.num_prompts, num_requests=args.num_prompts,
@ -773,6 +784,7 @@ def main(args: argparse.Namespace):
input_len=args.random_input_len, input_len=args.random_input_len,
output_len=args.random_output_len, output_len=args.random_output_len,
range_ratio=args.random_range_ratio, range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
), ),
} }
@ -1118,6 +1130,13 @@ def create_argument_parser():
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve", "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
) )
parser.add_argument(
"--request-id-prefix",
type=str,
required=False,
default="benchmark-serving",
help="Specify the prefix of request id.",
)
# group for dataset specific arguments # group for dataset specific arguments
custom_group = parser.add_argument_group("custom dataset options") custom_group = parser.add_argument_group("custom dataset options")

View File

@ -18,6 +18,7 @@ import logging
import random import random
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from functools import cache from functools import cache
from io import BytesIO from io import BytesIO
@ -76,6 +77,7 @@ class SampleRequest:
Union[MultiModalDataDict, dict, list[dict]] Union[MultiModalDataDict, dict, list[dict]]
] = None ] = None
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
request_id: Optional[str] = None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ -183,7 +185,8 @@ class BenchmarkDataset(ABC):
@abstractmethod @abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase, def sample(self, tokenizer: PreTrainedTokenizerBase,
num_requests: int) -> list[SampleRequest]: num_requests: int,
request_id_prefix: str = "") -> list[SampleRequest]:
""" """
Abstract method to generate sample requests from the dataset. Abstract method to generate sample requests from the dataset.
@ -194,6 +197,8 @@ class BenchmarkDataset(ABC):
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text. for processing the dataset's text.
num_requests (int): The number of sample requests to generate. num_requests (int): The number of sample requests to generate.
request_id_prefix (str): The prefix of request_id.
Returns: Returns:
list[SampleRequest]: A list of sample requests generated from the list[SampleRequest]: A list of sample requests generated from the
@ -201,8 +206,12 @@ class BenchmarkDataset(ABC):
""" """
raise NotImplementedError("sample must be implemented in subclasses.") raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(self, requests: list[SampleRequest], def maybe_oversample_requests(
num_requests: int) -> None: self,
requests: list[SampleRequest],
num_requests: int,
request_id_prefix: str = "",
) -> None:
""" """
Oversamples the list of requests if its size is less than the desired Oversamples the list of requests if its size is less than the desired
number. number.
@ -211,11 +220,17 @@ class BenchmarkDataset(ABC):
requests (List[SampleRequest]): The current list of sampled requests (List[SampleRequest]): The current list of sampled
requests. requests.
num_requests (int): The target number of requests. num_requests (int): The target number of requests.
request_id_prefix (str): The prefix of the request ids.
""" """
if len(requests) < num_requests: if len(requests) < num_requests:
random.seed(self.random_seed) random.seed(self.random_seed)
additional = random.choices(requests, additional = deepcopy(
k=num_requests - len(requests)) random.choices(requests, k=num_requests - len(requests))
)
for i in range(len(additional)):
req = additional[i]
req.request_id = request_id_prefix + str(len(requests) + i)
requests.extend(additional) requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.", logger.info("Oversampled requests to reach %d total samples.",
num_requests) num_requests)
@ -334,6 +349,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio: float = DEFAULT_RANGE_RATIO, range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN, input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
# Enforce range_ratio < 1 # Enforce range_ratio < 1
@ -391,6 +407,7 @@ class RandomDataset(BenchmarkDataset):
prompt=prompt, prompt=prompt,
prompt_len=total_input_len, prompt_len=total_input_len,
expected_output_len=int(output_lens[i]), expected_output_len=int(output_lens[i]),
request_id=request_id_prefix + str(i),
)) ))
return requests return requests
@ -432,9 +449,11 @@ class ShareGPTDataset(BenchmarkDataset):
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
samples: list = [] samples: list = []
ind = 0
for entry in self.data: for entry in self.data:
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
@ -470,8 +489,10 @@ class ShareGPTDataset(BenchmarkDataset):
expected_output_len=new_output_len, expected_output_len=new_output_len,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
)) ))
self.maybe_oversample_requests(samples, num_requests) ind += 1
self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples return samples
@ -647,6 +668,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.custom_output_len, output_len=args.custom_output_len,
skip_chat_template=args.custom_skip_chat_template, skip_chat_template=args.custom_skip_chat_template,
request_id_prefix=args.request_id_prefix,
) )
elif args.dataset_name == "sonnet": elif args.dataset_name == "sonnet":
@ -660,6 +682,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
prefix_len=args.sonnet_prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
return_prompt_formatted=False, return_prompt_formatted=False,
request_id_prefix=args.request_id_prefix,
) )
else: else:
assert tokenizer.chat_template or tokenizer.default_chat_template, ( assert tokenizer.chat_template or tokenizer.default_chat_template, (
@ -671,6 +694,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
prefix_len=args.sonnet_prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
return_prompt_formatted=True, return_prompt_formatted=True,
request_id_prefix=args.request_id_prefix,
) )
elif args.dataset_name == "hf": elif args.dataset_name == "hf":
@ -730,6 +754,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.hf_output_len, output_len=args.hf_output_len,
request_id_prefix=args.request_id_prefix,
) )
else: else:
@ -741,11 +766,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=args.num_prompts, num_requests=args.num_prompts,
output_len=args.sharegpt_output_len, output_len=args.sharegpt_output_len,
request_id_prefix=args.request_id_prefix,
), ),
"burstgpt": "burstgpt":
lambda: BurstGPTDataset(random_seed=args.seed, lambda: BurstGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path). dataset_path=args.dataset_path).
sample(tokenizer=tokenizer, num_requests=args.num_prompts), sample(tokenizer=tokenizer, num_requests=args.num_prompts,
request_id_prefix=args.request_id_prefix,),
"random": "random":
lambda: RandomDataset(random_seed=args.seed, lambda: RandomDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample( dataset_path=args.dataset_path).sample(
@ -755,6 +782,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
input_len=args.random_input_len, input_len=args.random_input_len,
output_len=args.random_output_len, output_len=args.random_output_len,
range_ratio=args.random_range_ratio, range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
), ),
"prefix_repetition": "prefix_repetition":
lambda: PrefixRepetitionRandomDataset( lambda: PrefixRepetitionRandomDataset(
@ -766,6 +794,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
suffix_len=args.prefix_repetition_suffix_len, suffix_len=args.prefix_repetition_suffix_len,
num_prefixes=args.prefix_repetition_num_prefixes, num_prefixes=args.prefix_repetition_num_prefixes,
output_len=args.prefix_repetition_output_len, output_len=args.prefix_repetition_output_len,
request_id_prefix=args.request_id_prefix,
), ),
} }
@ -839,10 +868,11 @@ class CustomDataset(BenchmarkDataset):
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
skip_chat_template: bool = False, skip_chat_template: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item["prompt"] prompt = item["prompt"]
@ -864,8 +894,10 @@ class CustomDataset(BenchmarkDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
@ -909,6 +941,7 @@ class SonnetDataset(BenchmarkDataset):
input_len: int = DEFAULT_INPUT_LEN, input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
return_prompt_formatted: bool = False, return_prompt_formatted: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
# Calculate average token length for a poem line. # Calculate average token length for a poem line.
@ -934,6 +967,7 @@ class SonnetDataset(BenchmarkDataset):
prefix_lines = self.data[:num_prefix_lines] prefix_lines = self.data[:num_prefix_lines]
samples = [] samples = []
ind = 0
while len(samples) < num_requests: while len(samples) < num_requests:
extra_lines = random.choices(self.data, extra_lines = random.choices(self.data,
k=num_input_lines - num_prefix_lines) k=num_input_lines - num_prefix_lines)
@ -949,7 +983,9 @@ class SonnetDataset(BenchmarkDataset):
if return_prompt_formatted else prompt, if return_prompt_formatted else prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(ind),
)) ))
ind += 1
return samples return samples
@ -1000,6 +1036,7 @@ class BurstGPTDataset(BenchmarkDataset):
num_requests: int, num_requests: int,
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
samples = [] samples = []
@ -1020,6 +1057,7 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len, prompt_len=input_len,
expected_output_len=output_len, expected_output_len=output_len,
lora_request=lora_req, lora_request=lora_req,
request_id=request_id_prefix + str(i),
)) ))
return samples return samples
@ -1075,11 +1113,13 @@ class ConversationDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs) -> list: **kwargs) -> list:
# Filter examples with at least 2 conversations # Filter examples with at least 2 conversations
filtered_data = self.data.filter( filtered_data = self.data.filter(
lambda x: len(x["conversations"]) >= 2) lambda x: len(x["conversations"]) >= 2)
sampled_requests = [] sampled_requests = []
ind = 0
dynamic_output = output_len is None dynamic_output = output_len is None
for item in filtered_data: for item in filtered_data:
@ -1111,8 +1151,11 @@ class ConversationDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
@ -1141,12 +1184,13 @@ class VisionArenaDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = (output_len output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN) if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
@ -1168,8 +1212,10 @@ class VisionArenaDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
@ -1198,11 +1244,12 @@ class InstructCoderDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs) -> list: **kwargs) -> list:
output_len = (output_len output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN) if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = f"{item['input']}\n\n{item['instruction']} Just output \ prompt = f"{item['input']}\n\n{item['instruction']} Just output \
@ -1224,8 +1271,10 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
@ -1255,13 +1304,14 @@ class MTBenchDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = (output_len output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN) if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item["turns"][0] prompt = item["turns"][0]
@ -1282,8 +1332,10 @@ class MTBenchDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
@ -1305,8 +1357,10 @@ class AIMODataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs) -> list: **kwargs) -> list:
sampled_requests = [] sampled_requests = []
ind = 0
dynamic_output = output_len is None dynamic_output = output_len is None
for item in self.data: for item in self.data:
@ -1331,8 +1385,12 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=None, multi_modal_data=None,
request_id=request_id_prefix + str(ind),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
@ -1403,13 +1461,14 @@ class NextEditPredictionDataset(HuggingFaceDataset):
} }
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
request_id_prefix: str = "",
**kwargs): **kwargs):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
self.dataset_path) self.dataset_path)
if formatting_prompt_func is None: if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}") raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = [] samples = []
for sample in self.data: for i, sample in enumerate(self.data):
sample = formatting_prompt_func(sample) sample = formatting_prompt_func(sample)
samples.append( samples.append(
SampleRequest( SampleRequest(
@ -1417,10 +1476,11 @@ class NextEditPredictionDataset(HuggingFaceDataset):
prompt_len=len(tokenizer(sample["prompt"]).input_ids), prompt_len=len(tokenizer(sample["prompt"]).input_ids),
expected_output_len=len( expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids), tokenizer(sample["expected_output"]).input_ids),
request_id=request_id_prefix + str(i),
)) ))
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
self.maybe_oversample_requests(samples, num_requests) self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples return samples
@ -1470,6 +1530,7 @@ class ASRDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = (output_len output_len = (output_len
@ -1477,6 +1538,7 @@ class ASRDataset(HuggingFaceDataset):
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = [] sampled_requests = []
ind = 0
skipped = 0 skipped = 0
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
@ -1496,7 +1558,9 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
)) ))
ind += 1
if skipped: if skipped:
logger.warning( logger.warning(
"%d samples discarded from dataset due to" "%d samples discarded from dataset due to"
@ -1504,7 +1568,8 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.", " what Whisper supports.",
skipped, skipped,
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
@ -1541,11 +1606,13 @@ class MLPerfDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
# Force dynamic output length based on reference completion. # Force dynamic output length based on reference completion.
dynamic_output = output_len is None dynamic_output = output_len is None
sampled_requests: list[SampleRequest] = [] sampled_requests: list[SampleRequest] = []
ind = 0
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
@ -1580,10 +1647,13 @@ class MLPerfDataset(HuggingFaceDataset):
prompt=prompt_formatted, prompt=prompt_formatted,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=expected_output_len, expected_output_len=expected_output_len,
request_id=request_id_prefix + str(ind),
) )
) )
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
@ -1616,6 +1686,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
suffix_len: int = DEFAULT_SUFFIX_LEN, suffix_len: int = DEFAULT_SUFFIX_LEN,
num_prefixes: int = DEFAULT_NUM_PREFIXES, num_prefixes: int = DEFAULT_NUM_PREFIXES,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
vocab_size = tokenizer.vocab_size vocab_size = tokenizer.vocab_size

View File

@ -31,6 +31,7 @@ class RequestFuncInput:
multi_modal_content: Optional[dict | list[dict]] = None multi_modal_content: Optional[dict | list[dict]] = None
ignore_eos: bool = False ignore_eos: bool = False
language: Optional[str] = None language: Optional[str] = None
request_id: Optional[str] = None
@dataclass @dataclass
@ -87,6 +88,8 @@ async def async_request_openai_completions(
headers = { headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
@ -210,6 +213,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
@ -311,6 +316,8 @@ async def async_request_openai_audio(
headers = { headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
# Send audio file # Send audio file
def to_bytes(y, sr): def to_bytes(y, sr):

View File

@ -478,11 +478,12 @@ async def benchmark(
"timestamp": timestamp "timestamp": timestamp
}) })
last_int_rps = current_int_rps last_int_rps = current_int_rps
prompt, prompt_len, output_len, mm_content = ( prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt, request.prompt,
request.prompt_len, request.prompt_len,
request.expected_output_len, request.expected_output_len,
request.multi_modal_data, request.multi_modal_data,
request.request_id,
) )
req_model_id, req_model_name = model_id, model_name req_model_id, req_model_name = model_id, model_name
if lora_modules: if lora_modules:
@ -498,7 +499,8 @@ async def benchmark(
logprobs=logprobs, logprobs=logprobs,
multi_modal_content=mm_content, multi_modal_content=mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=extra_body) extra_body=extra_body,
request_id=request_id,)
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
limited_request_func(request_func_input=request_func_input, limited_request_func(request_func_input=request_func_input,
@ -865,6 +867,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve", "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
) )
parser.add_argument(
"--request-id-prefix",
type=str,
required=False,
default="benchmark-serving",
help="Specify the prefix of request id.",
)
sampling_group = parser.add_argument_group("sampling parameters") sampling_group = parser.add_argument_group("sampling parameters")
sampling_group.add_argument( sampling_group.add_argument(