diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 14461121fece..c0a7f1d58250 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -167,7 +167,8 @@ async def run_vllm_async(
     from vllm import SamplingParams
 
     async with build_async_engine_client_from_engine_args(
-        engine_args, disable_frontend_multiprocessing
+        engine_args,
+        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
     ) as llm:
         model_config = await llm.get_model_config()
         assert all(
diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py
index 2d7b845736b8..9107d089834b 100644
--- a/tests/entrypoints/openai/test_metrics.py
+++ b/tests/entrypoints/openai/test_metrics.py
@@ -295,8 +295,6 @@ async def test_metrics_exist(server: RemoteOpenAIServer,
 
 
 def test_metrics_exist_run_batch(use_v1: bool):
-    if use_v1:
-        pytest.skip("Skipping test on vllm V1")
     input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}"""  # noqa: E501
 
     base_url = "0.0.0.0"
@@ -323,7 +321,8 @@ def test_metrics_exist_run_batch(use_v1: bool):
             base_url,
             "--port",
             port,
-        ], )
+        ],
+        env={"VLLM_USE_V1": "1" if use_v1 else "0"})
 
         def is_server_up(url):
             try:
diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py
index af2ca9657128..0fe042e2736d 100644
--- a/vllm/benchmarks/throughput.py
+++ b/vllm/benchmarks/throughput.py
@@ -148,7 +148,9 @@ async def run_vllm_async(
     from vllm import SamplingParams
 
     async with build_async_engine_client_from_engine_args(
-            engine_args, disable_frontend_multiprocessing) as llm:
+            engine_args,
+            disable_frontend_multiprocessing=disable_frontend_multiprocessing,
+    ) as llm:
         model_config = await llm.get_model_config()
         assert all(
             model_config.max_model_len >= (request.prompt_len +
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index ba257990d4a4..8540d25d4e94 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -149,6 +149,9 @@ async def lifespan(app: FastAPI):
 @asynccontextmanager
 async def build_async_engine_client(
     args: Namespace,
+    *,
+    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
+    disable_frontend_multiprocessing: Optional[bool] = None,
     client_config: Optional[dict[str, Any]] = None,
 ) -> AsyncIterator[EngineClient]:
 
@@ -156,15 +159,24 @@ async def build_async_engine_client(
     # Ensures everything is shutdown and cleaned up on error/exit
     engine_args = AsyncEngineArgs.from_cli_args(args)
 
+    if disable_frontend_multiprocessing is None:
+        disable_frontend_multiprocessing = bool(
+            args.disable_frontend_multiprocessing)
+
     async with build_async_engine_client_from_engine_args(
-            engine_args, args.disable_frontend_multiprocessing,
-            client_config) as engine:
+            engine_args,
+            usage_context=usage_context,
+            disable_frontend_multiprocessing=disable_frontend_multiprocessing,
+            client_config=client_config,
+    ) as engine:
         yield engine
 
 
 @asynccontextmanager
 async def build_async_engine_client_from_engine_args(
     engine_args: AsyncEngineArgs,
+    *,
+    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
     disable_frontend_multiprocessing: bool = False,
     client_config: Optional[dict[str, Any]] = None,
 ) -> AsyncIterator[EngineClient]:
@@ -177,7 +189,6 @@ async def build_async_engine_client_from_engine_args(
     """
 
     # Create the EngineConfig (determines if we can use V1).
-    usage_context = UsageContext.OPENAI_API_SERVER
     vllm_config = engine_args.create_engine_config(usage_context=usage_context)
 
     # V1 AsyncLLM.
@@ -1811,7 +1822,10 @@ async def run_server_worker(listen_address,
     if log_config is not None:
         uvicorn_kwargs['log_config'] = log_config
 
-    async with build_async_engine_client(args, client_config) as engine_client:
+    async with build_async_engine_client(
+            args,
+            client_config=client_config,
+    ) as engine_client:
         maybe_register_tokenizer_info_endpoint(args)
         app = build_app(args)
 
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index ef5bf6f9a812..577055092327 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -3,6 +3,7 @@
 
 import asyncio
 import tempfile
+from argparse import Namespace
 from collections.abc import Awaitable
 from http import HTTPStatus
 from io import StringIO
@@ -13,10 +14,12 @@ import torch
 from prometheus_client import start_http_server
 from tqdm import tqdm
 
+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf: disable
+from vllm.entrypoints.openai.api_server import build_async_engine_client
 from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                               BatchRequestOutput,
                                               BatchResponseData,
@@ -310,36 +313,37 @@ async def run_request(serving_engine_func: Callable,
     return batch_output
 
 
-async def main(args):
+async def run_batch(
+    engine_client: EngineClient,
+    vllm_config: VllmConfig,
+    args: Namespace,
+) -> None:
     if args.served_model_name is not None:
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]
 
-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(
-        engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
-
-    model_config = await engine.get_model_config()
-    base_model_paths = [
-        BaseModelPath(name=name, model_path=args.model)
-        for name in served_model_names
-    ]
-
     if args.disable_log_requests:
         request_logger = None
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)
 
+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]
+
+    model_config = vllm_config.model_config
+
     # Create the openai serving objects.
     openai_serving_models = OpenAIServingModels(
-        engine_client=engine,
+        engine_client=engine_client,
         model_config=model_config,
         base_model_paths=base_model_paths,
         lora_modules=None,
     )
     openai_serving_chat = OpenAIServingChat(
-        engine,
+        engine_client,
         model_config,
         openai_serving_models,
         args.response_role,
@@ -349,7 +353,7 @@ async def main(args):
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if "generate" in model_config.supported_tasks else None
     openai_serving_embedding = OpenAIServingEmbedding(
-        engine,
+        engine_client,
         model_config,
         openai_serving_models,
         request_logger=request_logger,
@@ -362,7 +366,7 @@ async def main(args):
                                             "num_labels", 0) == 1)
 
     openai_serving_scores = ServingScores(
-        engine,
+        engine_client,
         model_config,
         openai_serving_models,
         request_logger=request_logger,
@@ -457,6 +461,17 @@ async def main(args):
     await write_file(args.output_file, responses, args.output_tmp_dir)
 
 
+async def main(args: Namespace):
+    async with build_async_engine_client(
+            args,
+            usage_context=UsageContext.OPENAI_BATCH_RUNNER,
+            disable_frontend_multiprocessing=False,
+    ) as engine_client:
+        vllm_config = await engine_client.get_vllm_config()
+
+        await run_batch(engine_client, vllm_config, args)
+
+
 if __name__ == "__main__":
     args = parse_args()
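
Usage sketch (not part of the patch): after this change the batch runner shares the API server's engine bootstrap, so run_batch() can also be driven from an existing async program. A minimal sketch follows, assuming the CLI namespace comes from the existing parse_args() helper in run_batch.py; the _run() wrapper name is illustrative only.

# Illustrative sketch, not part of the diff above. Every vLLM import used
# here appears in the patched modules; _run() is a hypothetical wrapper and
# parse_args() is run_batch.py's existing CLI parser.
import asyncio

from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.run_batch import parse_args, run_batch
from vllm.usage.usage_lib import UsageContext


async def _run(args) -> None:
    # build_async_engine_client() now takes the usage context and the
    # frontend-multiprocessing switch as keyword-only arguments, so the batch
    # runner reuses the same engine construction path as the API server.
    async with build_async_engine_client(
            args,
            usage_context=UsageContext.OPENAI_BATCH_RUNNER,
            disable_frontend_multiprocessing=False,
    ) as engine_client:
        vllm_config = await engine_client.get_vllm_config()
        await run_batch(engine_client, vllm_config, args)


if __name__ == "__main__":
    # Accepts the same flags as `python -m vllm.entrypoints.openai.run_batch`.
    asyncio.run(_run(parse_args()))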