mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 06:09:10 +08:00
[Frontend] Add progress reporting to run_batch.py (#8060)
Co-authored-by: Adam Lugowski <adam.lugowski@parasail.io>
This commit is contained in:
parent
08287ef675
commit
58fcc8545a
@ -1,9 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import Awaitable, Callable, List
|
from typing import Awaitable, Callable, List, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import torch
|
||||||
from prometheus_client import start_http_server
|
from prometheus_client import start_http_server
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
@ -78,6 +80,38 @@ def parse_args():
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
# explicitly use pure text format, with a newline at the end
|
||||||
|
# this makes it impossible to see the animation in the progress bar
|
||||||
|
# but avoids interfering with ray or multiprocessing, which wraps
|
||||||
|
# each line of output with some prefix.
|
||||||
|
# Template handed to tqdm as ``bar_format`` by BatchProgressTracker.pbar().
# It ends with a newline rather than relying on in-place redraws.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProgressTracker:
    """Count batch requests and report completions on a tqdm progress bar.

    Intended call order: ``submitted()`` once per queued request, then
    ``pbar()`` to create the bar with the accumulated total, then
    ``completed()`` each time a request finishes.
    """

    def __init__(self):
        # Number of requests registered through submitted().
        self._total = 0
        # Set by pbar(); stays None until the bar has been created.
        self._pbar: Optional[tqdm] = None

    def submitted(self):
        """Register one more request in the expected total."""
        self._total += 1

    def completed(self):
        """Advance the bar by one step; do nothing if no bar exists yet."""
        # Compare against None explicitly: truth-testing a tqdm instance
        # goes through tqdm.__bool__, which reflects the bar's total (or
        # iterable) rather than whether a bar has been created.
        if self._pbar is not None:
            self._pbar.update()

    def pbar(self) -> tqdm:
        """Create, store and return a bar sized to the submitted total.

        The bar is enabled only when torch.distributed is not initialized
        or when this process has rank 0; otherwise tqdm is created with
        ``disable=True``.
        """
        enable_tqdm = (not torch.distributed.is_initialized()
                       or torch.distributed.get_rank() == 0)
        self._pbar = tqdm(total=self._total,
                          unit="req",
                          desc="Running batch",
                          mininterval=5,
                          disable=not enable_tqdm,
                          bar_format=_BAR_FORMAT)
        return self._pbar
|
||||||
|
|
||||||
|
|
||||||
async def read_file(path_or_url: str) -> str:
|
async def read_file(path_or_url: str) -> str:
|
||||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||||
async with aiohttp.ClientSession() as session, \
|
async with aiohttp.ClientSession() as session, \
|
||||||
@ -102,7 +136,8 @@ async def write_file(path_or_url: str, data: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def run_request(serving_engine_func: Callable,
|
async def run_request(serving_engine_func: Callable,
|
||||||
request: BatchRequestInput) -> BatchRequestOutput:
|
request: BatchRequestInput,
|
||||||
|
tracker: BatchProgressTracker) -> BatchRequestOutput:
|
||||||
response = await serving_engine_func(request.body)
|
response = await serving_engine_func(request.body)
|
||||||
|
|
||||||
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
|
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
|
||||||
@ -125,6 +160,7 @@ async def run_request(serving_engine_func: Callable,
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Request must not be sent in stream mode")
|
raise ValueError("Request must not be sent in stream mode")
|
||||||
|
|
||||||
|
tracker.completed()
|
||||||
return batch_output
|
return batch_output
|
||||||
|
|
||||||
|
|
||||||
@ -164,6 +200,9 @@ async def main(args):
|
|||||||
request_logger=request_logger,
|
request_logger=request_logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tracker = BatchProgressTracker()
|
||||||
|
logger.info("Reading batch from %s...", args.input_file)
|
||||||
|
|
||||||
# Submit all requests in the file to the engine "concurrently".
|
# Submit all requests in the file to the engine "concurrently".
|
||||||
response_futures: List[Awaitable[BatchRequestOutput]] = []
|
response_futures: List[Awaitable[BatchRequestOutput]] = []
|
||||||
for request_json in (await read_file(args.input_file)).strip().split("\n"):
|
for request_json in (await read_file(args.input_file)).strip().split("\n"):
|
||||||
@ -178,16 +217,19 @@ async def main(args):
|
|||||||
if request.url == "/v1/chat/completions":
|
if request.url == "/v1/chat/completions":
|
||||||
response_futures.append(
|
response_futures.append(
|
||||||
run_request(openai_serving_chat.create_chat_completion,
|
run_request(openai_serving_chat.create_chat_completion,
|
||||||
request))
|
request, tracker))
|
||||||
|
tracker.submitted()
|
||||||
elif request.url == "/v1/embeddings":
|
elif request.url == "/v1/embeddings":
|
||||||
response_futures.append(
|
response_futures.append(
|
||||||
run_request(openai_serving_embedding.create_embedding,
|
run_request(openai_serving_embedding.create_embedding, request,
|
||||||
request))
|
tracker))
|
||||||
|
tracker.submitted()
|
||||||
else:
|
else:
|
||||||
raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
|
raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
|
||||||
"supported in the batch endpoint.")
|
"supported in the batch endpoint.")
|
||||||
|
|
||||||
responses = await asyncio.gather(*response_futures)
|
with tracker.pbar():
|
||||||
|
responses = await asyncio.gather(*response_futures)
|
||||||
|
|
||||||
output_buffer = StringIO()
|
output_buffer = StringIO()
|
||||||
for response in responses:
|
for response in responses:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user