[Frontend] Add progress reporting to run_batch.py (#8060)

Co-authored-by: Adam Lugowski <adam.lugowski@parasail.io>
This commit is contained in:
Adam Lugowski 2024-09-09 11:16:37 -07:00 committed by GitHub
parent 08287ef675
commit 58fcc8545a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,9 +1,11 @@
import asyncio import asyncio
from io import StringIO from io import StringIO
from typing import Awaitable, Callable, List from typing import Awaitable, Callable, List, Optional
import aiohttp import aiohttp
import torch
from prometheus_client import start_http_server from prometheus_client import start_http_server
from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
@ -78,6 +80,38 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
class BatchProgressTracker:
def __init__(self):
self._total = 0
self._pbar: Optional[tqdm] = None
def submitted(self):
self._total += 1
def completed(self):
if self._pbar:
self._pbar.update()
def pbar(self) -> tqdm:
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
self._pbar = tqdm(total=self._total,
unit="req",
desc="Running batch",
mininterval=5,
disable=not enable_tqdm,
bar_format=_BAR_FORMAT)
return self._pbar
async def read_file(path_or_url: str) -> str: async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"): if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \ async with aiohttp.ClientSession() as session, \
@ -102,7 +136,8 @@ async def write_file(path_or_url: str, data: str) -> None:
async def run_request(serving_engine_func: Callable, async def run_request(serving_engine_func: Callable,
request: BatchRequestInput) -> BatchRequestOutput: request: BatchRequestInput,
tracker: BatchProgressTracker) -> BatchRequestOutput:
response = await serving_engine_func(request.body) response = await serving_engine_func(request.body)
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)): if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
@ -125,6 +160,7 @@ async def run_request(serving_engine_func: Callable,
else: else:
raise ValueError("Request must not be sent in stream mode") raise ValueError("Request must not be sent in stream mode")
tracker.completed()
return batch_output return batch_output
@ -164,6 +200,9 @@ async def main(args):
request_logger=request_logger, request_logger=request_logger,
) )
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
# Submit all requests in the file to the engine "concurrently". # Submit all requests in the file to the engine "concurrently".
response_futures: List[Awaitable[BatchRequestOutput]] = [] response_futures: List[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"): for request_json in (await read_file(args.input_file)).strip().split("\n"):
@ -178,16 +217,19 @@ async def main(args):
if request.url == "/v1/chat/completions": if request.url == "/v1/chat/completions":
response_futures.append( response_futures.append(
run_request(openai_serving_chat.create_chat_completion, run_request(openai_serving_chat.create_chat_completion,
request)) request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings": elif request.url == "/v1/embeddings":
response_futures.append( response_futures.append(
run_request(openai_serving_embedding.create_embedding, run_request(openai_serving_embedding.create_embedding, request,
request)) tracker))
tracker.submitted()
else: else:
raise ValueError("Only /v1/chat/completions and /v1/embeddings are" raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
"supported in the batch endpoint.") "supported in the batch endpoint.")
responses = await asyncio.gather(*response_futures) with tracker.pbar():
responses = await asyncio.gather(*response_futures)
output_buffer = StringIO() output_buffer = StringIO()
for response in responses: for response in responses: