[Frontend] run-batch supports V1 (#21541)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-07-25 11:05:55 +08:00 committed by GitHub
parent fe56180c7f
commit 34ddcf9ff4
5 changed files with 56 additions and 25 deletions

View File

@@ -167,7 +167,8 @@ async def run_vllm_async(
     from vllm import SamplingParams
 
     async with build_async_engine_client_from_engine_args(
-            engine_args, disable_frontend_multiprocessing
+            engine_args,
+            disable_frontend_multiprocessing=disable_frontend_multiprocessing,
     ) as llm:
         model_config = await llm.get_model_config()
         assert all(

View File

@@ -295,8 +295,6 @@ async def test_metrics_exist(server: RemoteOpenAIServer,
 def test_metrics_exist_run_batch(use_v1: bool):
-    if use_v1:
-        pytest.skip("Skipping test on vllm V1")
     input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}"""  # noqa: E501
 
     base_url = "0.0.0.0"

@@ -323,7 +321,8 @@ def test_metrics_exist_run_batch(use_v1: bool):
             base_url,
             "--port",
             port,
-        ], )
+        ],
+        env={"VLLM_USE_V1": "1" if use_v1 else "0"})
 
     def is_server_up(url):
         try:
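Rather than skipping under V1, the batch-metrics test now pins the engine for its subprocess through the VLLM_USE_V1 environment variable. A minimal sketch of that pattern, assuming the standard `python -m vllm.entrypoints.openai.run_batch` entry point and the same embedding model as the test; the helper below is illustrative, not the test's own launcher:

# Illustrative launcher, not the test helper: run the batch endpoint in a
# subprocess and select the engine through the environment, as the test now does.
import os
import subprocess
import sys


def launch_run_batch(input_file: str, output_file: str, use_v1: bool):
    env = os.environ.copy()
    env["VLLM_USE_V1"] = "1" if use_v1 else "0"
    return subprocess.Popen(
        [
            sys.executable, "-m", "vllm.entrypoints.openai.run_batch",
            "-i", input_file,
            "-o", output_file,
            "--model", "intfloat/multilingual-e5-small",
        ],
        env=env,
    )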

View File

@@ -148,7 +148,9 @@ async def run_vllm_async(
     from vllm import SamplingParams
 
     async with build_async_engine_client_from_engine_args(
-            engine_args, disable_frontend_multiprocessing) as llm:
+            engine_args,
+            disable_frontend_multiprocessing=disable_frontend_multiprocessing,
+    ) as llm:
         model_config = await llm.get_model_config()
         assert all(
             model_config.max_model_len >= (request.prompt_len +
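Both benchmark hunks make the same change: disable_frontend_multiprocessing is now passed by keyword, matching the keyword-only parameters added to build_async_engine_client_from_engine_args later in this commit (a positional second argument would no longer be accepted). A hedged sketch of the resulting call shape; the model name is a placeholder:

# Sketch of the benchmark-side call pattern after this change, not the benchmark
# script itself.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)


async def run_vllm_async_example(disable_frontend_multiprocessing: bool = False):
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    async with build_async_engine_client_from_engine_args(
            engine_args,
            disable_frontend_multiprocessing=disable_frontend_multiprocessing,
    ) as llm:
        # The client exposes the same async inspection API as before.
        model_config = await llm.get_model_config()
        print(model_config.max_model_len)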

View File

@@ -149,6 +149,9 @@ async def lifespan(app: FastAPI):
 @asynccontextmanager
 async def build_async_engine_client(
     args: Namespace,
+    *,
+    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
+    disable_frontend_multiprocessing: Optional[bool] = None,
     client_config: Optional[dict[str, Any]] = None,
 ) -> AsyncIterator[EngineClient]:

@@ -156,15 +159,24 @@ async def build_async_engine_client(
     # Ensures everything is shutdown and cleaned up on error/exit
     engine_args = AsyncEngineArgs.from_cli_args(args)
 
+    if disable_frontend_multiprocessing is None:
+        disable_frontend_multiprocessing = bool(
+            args.disable_frontend_multiprocessing)
+
     async with build_async_engine_client_from_engine_args(
-            engine_args, args.disable_frontend_multiprocessing,
-            client_config) as engine:
+            engine_args,
+            usage_context=usage_context,
+            disable_frontend_multiprocessing=disable_frontend_multiprocessing,
+            client_config=client_config,
+    ) as engine:
         yield engine
 
 
 @asynccontextmanager
 async def build_async_engine_client_from_engine_args(
     engine_args: AsyncEngineArgs,
+    *,
+    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
     disable_frontend_multiprocessing: bool = False,
     client_config: Optional[dict[str, Any]] = None,
 ) -> AsyncIterator[EngineClient]:

@@ -177,7 +189,6 @@ async def build_async_engine_client_from_engine_args(
     """
 
     # Create the EngineConfig (determines if we can use V1).
-    usage_context = UsageContext.OPENAI_API_SERVER
     vllm_config = engine_args.create_engine_config(usage_context=usage_context)
 
     # V1 AsyncLLM.

@@ -1811,7 +1822,10 @@ async def run_server_worker(listen_address,
     if log_config is not None:
         uvicorn_kwargs['log_config'] = log_config
 
-    async with build_async_engine_client(args, client_config) as engine_client:
+    async with build_async_engine_client(
+            args,
+            client_config=client_config,
+    ) as engine_client:
         maybe_register_tokenizer_info_endpoint(args)
         app = build_app(args)
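Both new parameters of build_async_engine_client are keyword-only and default to the old behavior: usage_context stays OPENAI_API_SERVER, and a None disable_frontend_multiprocessing defers to the CLI flag, which is why run_server_worker only needs to pass client_config by name. A minimal sketch of an external caller, assuming an argparse Namespace produced by vLLM's own CLI parser:

# Hedged sketch of calling build_async_engine_client with the new keyword-only
# parameters. `args` must carry the usual vLLM CLI fields, including
# disable_frontend_multiprocessing, for the None fallback to work.
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.usage.usage_lib import UsageContext


async def report_max_model_len(args) -> None:
    async with build_async_engine_client(
            args,
            usage_context=UsageContext.OPENAI_API_SERVER,  # the default
            disable_frontend_multiprocessing=None,  # defer to the CLI flag
    ) as engine_client:
        model_config = await engine_client.get_model_config()
        print(model_config.max_model_len)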

View File

@@ -3,6 +3,7 @@
 
 import asyncio
 import tempfile
+from argparse import Namespace
 from collections.abc import Awaitable
 from http import HTTPStatus
 from io import StringIO

@@ -13,10 +14,12 @@ import torch
 from prometheus_client import start_http_server
 from tqdm import tqdm
 
+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf: disable
+from vllm.entrypoints.openai.api_server import build_async_engine_client
 from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                               BatchRequestOutput,
                                               BatchResponseData,

@@ -310,36 +313,37 @@ async def run_request(serving_engine_func: Callable,
     return batch_output
 
 
-async def main(args):
+async def run_batch(
+    engine_client: EngineClient,
+    vllm_config: VllmConfig,
+    args: Namespace,
+) -> None:
     if args.served_model_name is not None:
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]
 
-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(
-        engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
-
-    model_config = await engine.get_model_config()
-    base_model_paths = [
-        BaseModelPath(name=name, model_path=args.model)
-        for name in served_model_names
-    ]
-
     if args.disable_log_requests:
         request_logger = None
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)
 
+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]
+
+    model_config = vllm_config.model_config
+
     # Create the openai serving objects.
     openai_serving_models = OpenAIServingModels(
-        engine_client=engine,
+        engine_client=engine_client,
         model_config=model_config,
         base_model_paths=base_model_paths,
         lora_modules=None,
     )
     openai_serving_chat = OpenAIServingChat(
-        engine,
+        engine_client,
         model_config,
         openai_serving_models,
         args.response_role,

@@ -349,7 +353,7 @@ async def main(args):
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if "generate" in model_config.supported_tasks else None
     openai_serving_embedding = OpenAIServingEmbedding(
-        engine,
+        engine_client,
         model_config,
         openai_serving_models,
         request_logger=request_logger,

@@ -362,7 +366,7 @@ async def main(args):
                 "num_labels", 0) == 1)
 
     openai_serving_scores = ServingScores(
-        engine,
+        engine_client,
         model_config,
         openai_serving_models,
         request_logger=request_logger,

@@ -457,6 +461,17 @@ async def main(args):
     await write_file(args.output_file, responses, args.output_tmp_dir)
 
 
+async def main(args: Namespace):
+    async with build_async_engine_client(
+            args,
+            usage_context=UsageContext.OPENAI_BATCH_RUNNER,
+            disable_frontend_multiprocessing=False,
+    ) as engine_client:
+        vllm_config = await engine_client.get_vllm_config()
+
+        await run_batch(engine_client, vllm_config, args)
+
+
 if __name__ == "__main__":
     args = parse_args()
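Splitting main() into main() plus run_batch() means the batch logic now runs against any EngineClient and a VllmConfig obtained from the shared frontend path, so the V1 AsyncLLM is selected by the same rules as the API server instead of a batch-runner-private AsyncLLMEngine. For reference, a hedged sketch of what the new flow amounts to as a standalone script (parse_args() still reads the CLI from sys.argv):

# Sketch only: mirrors the new main() rather than adding behavior.
import asyncio

from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.run_batch import parse_args, run_batch
from vllm.usage.usage_lib import UsageContext


async def main() -> None:
    args = parse_args()
    async with build_async_engine_client(
            args,
            usage_context=UsageContext.OPENAI_BATCH_RUNNER,
            disable_frontend_multiprocessing=False,
    ) as engine_client:
        # The model config now comes from the shared engine client, which may
        # be backed by either the V0 or the V1 engine.
        vllm_config = await engine_client.get_vllm_config()
        await run_batch(engine_client, vllm_config, args)


if __name__ == "__main__":
    asyncio.run(main())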