[Frontend] run-batch supports V1 (#21541)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent fe56180c7f
commit 34ddcf9ff4
@@ -167,7 +167,8 @@ async def run_vllm_async(
     from vllm import SamplingParams

     async with build_async_engine_client_from_engine_args(
-            engine_args, disable_frontend_multiprocessing
+            engine_args,
+            disable_frontend_multiprocessing=disable_frontend_multiprocessing,
     ) as llm:
         model_config = await llm.get_model_config()
         assert all(
@@ -295,8 +295,6 @@ async def test_metrics_exist(server: RemoteOpenAIServer,


 def test_metrics_exist_run_batch(use_v1: bool):
-    if use_v1:
-        pytest.skip("Skipping test on vllm V1")
     input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}"""  # noqa: E501

     base_url = "0.0.0.0"
@@ -323,7 +321,8 @@ def test_metrics_exist_run_batch(use_v1: bool):
         base_url,
         "--port",
         port,
-    ], )
+    ],
+    env={"VLLM_USE_V1": "1" if use_v1 else "0"})

 def is_server_up(url):
     try:
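Note: the updated test now exercises the batch runner on both engines by setting the VLLM_USE_V1 environment variable instead of skipping V1. Below is a minimal standalone sketch of the same toggle, not the test's exact harness; the input/output file names are illustrative, the model name is the one used in the test's input batch, and the -i/-o/--model flags are the run_batch CLI options.

import os
import subprocess
import sys

use_v1 = True  # flip to False to exercise the V0 engine instead

# Select the engine for the subprocess via the same env var the test uses.
env = dict(os.environ, VLLM_USE_V1="1" if use_v1 else "0")

proc = subprocess.Popen(
    [
        sys.executable, "-m", "vllm.entrypoints.openai.run_batch",
        "-i", "input.jsonl",      # illustrative input batch (JSONL)
        "-o", "results.jsonl",    # illustrative output path
        "--model", "intfloat/multilingual-e5-small",
    ],
    env=env,
)
proc.wait()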
@@ -148,7 +148,9 @@ async def run_vllm_async(
     from vllm import SamplingParams

     async with build_async_engine_client_from_engine_args(
-            engine_args, disable_frontend_multiprocessing) as llm:
+            engine_args,
+            disable_frontend_multiprocessing=disable_frontend_multiprocessing,
+    ) as llm:
         model_config = await llm.get_model_config()
         assert all(
             model_config.max_model_len >= (request.prompt_len +
@@ -149,6 +149,9 @@ async def lifespan(app: FastAPI):
 @asynccontextmanager
 async def build_async_engine_client(
     args: Namespace,
+    *,
+    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
+    disable_frontend_multiprocessing: Optional[bool] = None,
     client_config: Optional[dict[str, Any]] = None,
 ) -> AsyncIterator[EngineClient]:

@@ -156,15 +159,24 @@ async def build_async_engine_client(
     # Ensures everything is shutdown and cleaned up on error/exit
     engine_args = AsyncEngineArgs.from_cli_args(args)

+    if disable_frontend_multiprocessing is None:
+        disable_frontend_multiprocessing = bool(
+            args.disable_frontend_multiprocessing)
+
     async with build_async_engine_client_from_engine_args(
-            engine_args, args.disable_frontend_multiprocessing,
-            client_config) as engine:
+            engine_args,
+            usage_context=usage_context,
+            disable_frontend_multiprocessing=disable_frontend_multiprocessing,
+            client_config=client_config,
+    ) as engine:
         yield engine


 @asynccontextmanager
 async def build_async_engine_client_from_engine_args(
     engine_args: AsyncEngineArgs,
+    *,
+    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
     disable_frontend_multiprocessing: bool = False,
     client_config: Optional[dict[str, Any]] = None,
 ) -> AsyncIterator[EngineClient]:
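Note: both engine-client builders now take keyword-only usage_context and disable_frontend_multiprocessing arguments, so callers other than the API server (such as the batch runner) can pick their usage context. Below is a minimal sketch, not part of this commit, of entering the builder with the new signature; the model name is illustrative, and get_vllm_config() is the EngineClient accessor used later in this diff.

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
from vllm.usage.usage_lib import UsageContext


async def demo() -> None:
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # illustrative model
    async with build_async_engine_client_from_engine_args(
            engine_args,
            usage_context=UsageContext.OPENAI_BATCH_RUNNER,
            disable_frontend_multiprocessing=False,
    ) as engine_client:
        # The same EngineClient protocol is returned for the V0 and V1 paths.
        vllm_config = await engine_client.get_vllm_config()
        print(vllm_config.model_config.max_model_len)


asyncio.run(demo())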
@@ -177,7 +189,6 @@ async def build_async_engine_client_from_engine_args(
     """

     # Create the EngineConfig (determines if we can use V1).
-    usage_context = UsageContext.OPENAI_API_SERVER
     vllm_config = engine_args.create_engine_config(usage_context=usage_context)

     # V1 AsyncLLM.
@@ -1811,7 +1822,10 @@ async def run_server_worker(listen_address,
     if log_config is not None:
         uvicorn_kwargs['log_config'] = log_config

-    async with build_async_engine_client(args, client_config) as engine_client:
+    async with build_async_engine_client(
+            args,
+            client_config=client_config,
+    ) as engine_client:
         maybe_register_tokenizer_info_endpoint(args)
         app = build_app(args)

@@ -3,6 +3,7 @@

 import asyncio
 import tempfile
+from argparse import Namespace
 from collections.abc import Awaitable
 from http import HTTPStatus
 from io import StringIO
@@ -13,10 +14,12 @@ import torch
 from prometheus_client import start_http_server
 from tqdm import tqdm

+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf: disable
+from vllm.entrypoints.openai.api_server import build_async_engine_client
 from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                               BatchRequestOutput,
                                               BatchResponseData,
@@ -310,36 +313,37 @@ async def run_request(serving_engine_func: Callable,
     return batch_output


-async def main(args):
+async def run_batch(
+    engine_client: EngineClient,
+    vllm_config: VllmConfig,
+    args: Namespace,
+) -> None:
     if args.served_model_name is not None:
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]

-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(
-        engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
-
-    model_config = await engine.get_model_config()
-    base_model_paths = [
-        BaseModelPath(name=name, model_path=args.model)
-        for name in served_model_names
-    ]
-
     if args.disable_log_requests:
         request_logger = None
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]
+
+    model_config = vllm_config.model_config
+
     # Create the openai serving objects.
     openai_serving_models = OpenAIServingModels(
-        engine_client=engine,
+        engine_client=engine_client,
         model_config=model_config,
         base_model_paths=base_model_paths,
         lora_modules=None,
     )
     openai_serving_chat = OpenAIServingChat(
-        engine,
+        engine_client,
         model_config,
         openai_serving_models,
         args.response_role,
@@ -349,7 +353,7 @@ async def main(args):
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if "generate" in model_config.supported_tasks else None
     openai_serving_embedding = OpenAIServingEmbedding(
-        engine,
+        engine_client,
         model_config,
         openai_serving_models,
         request_logger=request_logger,
@@ -362,7 +366,7 @@ async def main(args):
                             "num_labels", 0) == 1)

     openai_serving_scores = ServingScores(
-        engine,
+        engine_client,
         model_config,
         openai_serving_models,
         request_logger=request_logger,
@@ -457,6 +461,17 @@ async def main(args):
     await write_file(args.output_file, responses, args.output_tmp_dir)


+async def main(args: Namespace):
+    async with build_async_engine_client(
+            args,
+            usage_context=UsageContext.OPENAI_BATCH_RUNNER,
+            disable_frontend_multiprocessing=False,
+    ) as engine_client:
+        vllm_config = await engine_client.get_vllm_config()
+
+        await run_batch(engine_client, vllm_config, args)
+
+
 if __name__ == "__main__":
     args = parse_args()

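Note: main() is now a thin wrapper, so the batch logic in run_batch() can be driven with an externally managed engine client. The sketch below is equivalent to the added main(), written as a standalone script; it assumes parse_args() remains importable from the run_batch module, with the -i/-o/--model flags supplied on the command line.

import asyncio

from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.run_batch import parse_args, run_batch
from vllm.usage.usage_lib import UsageContext


async def drive_batch() -> None:
    # Reuse the module's own CLI parser so `args` carries every field run_batch() reads.
    args = parse_args()
    async with build_async_engine_client(
            args,
            usage_context=UsageContext.OPENAI_BATCH_RUNNER,
            disable_frontend_multiprocessing=False,
    ) as engine_client:
        vllm_config = await engine_client.get_vllm_config()
        await run_batch(engine_client, vllm_config, args)


if __name__ == "__main__":
    asyncio.run(drive_batch())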