[Bugfix][Frontend] Fix Issues Under High Load With zeromq Frontend (#7394)
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
parent d3c002eadc
commit f7e3b0c5aa
@@ -86,6 +86,7 @@ steps:
  - vllm/
  commands:
  - pip install -e ./plugins/vllm_add_dummy_model
  - pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
  - pytest -v -s entrypoints/llm
  - pytest -v -s entrypoints/openai
tests/entrypoints/openai/test_accuracy.py (new file, 55 lines)
@@ -0,0 +1,55 @@
"""
This file tests the accuracy of the vLLM server via LMEval.
It uses local-completions, which interacts with vLLM
through the OpenAI API with N concurrent connections.
This simulates real-world usage of the API and makes
sure that the zmq frontend mp RPC message passing and
AsyncLLMEngine are working correctly.
"""

import lm_eval
import pytest

from ...utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
NUM_CONCURRENT = 500
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58


@pytest.fixture(scope="module")
def server():
    args = [
        "--max-model-len", "4096", "--enable-chunked-prefill",
        "--disable-log-requests", "--enforce-eager"
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.fixture(scope="module")
def server_data(server):
    return {
        "url": f"{server.url_for('v1')}/completions",
    }


def test_lm_eval_accuracy(server_data):
    model_args = (f"model={MODEL_NAME},"
                  f"base_url={server_data['url']},"
                  f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")

    results = lm_eval.simple_evaluate(
        model="local-completions",
        model_args=model_args,
        tasks=TASK,
    )

    measured_value = results["results"][TASK][FILTER]
    assert (measured_value - RTOL < EXPECTED_VALUE
            and measured_value + RTOL > EXPECTED_VALUE
            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
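For reference, the lm_eval.simple_evaluate(...) call in the test above corresponds roughly to the following standalone invocation. This is a sketch only, not part of this commit; the localhost URL is an assumption, and flag spellings should be checked against the installed lm-eval version:

lm_eval --model local-completions \
    --tasks gsm8k \
    --model_args model=Qwen/Qwen2-1.5B-Instruct,base_url=http://localhost:8000/v1/completions,num_concurrent=500,tokenized_requests=False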
@@ -766,6 +766,11 @@ class AsyncLLMEngine:
    def errored(self) -> bool:
        return self._errored_with is not None

    @property
    def limit_concurrency(self) -> Optional[int]:
        """Maximum number of concurrently running requests."""
        return None

    def set_errored(self, exc: Exception) -> None:
        self._errored_with = exc
@@ -29,6 +29,10 @@ class AsyncEngineClient(Protocol):
    def errored(self) -> bool:
        ...

    @property
    def limit_concurrency(self) -> Optional[int]:
        """Maximum number of concurrently running requests."""

    def generate(
        self,
        inputs: PromptInputs,
@@ -27,6 +27,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,

        logger.info("Route: %s, Methods: %s", path, ', '.join(methods))

    # Set concurrency limits in uvicorn if running in multiprocessing mode
    # since zmq has a maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
    if engine.limit_concurrency is not None:
        logger.info(
            "Launching Uvicorn with --limit_concurrency %s. To avoid this "
            "limit at the expense of performance run with "
            "--disable-frontend-multiprocessing", engine.limit_concurrency)
        uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency

    config = uvicorn.Config(app, **uvicorn_kwargs)
    server = uvicorn.Server(config)
    _add_shutdown_handlers(app, server, engine)
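The limit_concurrency value surfaced above is handed to uvicorn, which rejects work beyond that many concurrent connections with HTTP 503 rather than queueing it. A minimal sketch of that mechanism in isolation (illustrative only, not vLLM code; the app, port, and limit are made up):

import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/ping")
async def ping():
    return {"ok": True}


if __name__ == "__main__":
    # Once this many connections/tasks are in flight, uvicorn answers new
    # connections with HTTP 503 instead of accepting more work.
    uvicorn.run(app, host="127.0.0.1", port=8000, limit_concurrency=32766)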
@@ -135,6 +135,12 @@ async def build_async_engine_client(
        logger.info("Multiprocessing frontend to use %s for RPC Path.",
                    rpc_path)

        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
        rpc_client = AsyncEngineRPCClient(rpc_path)
        async_engine_client = rpc_client  # type: ignore

        # Start RPCServer in separate process (holds the AsyncLLMEngine).
        context = multiprocessing.get_context("spawn")
        # the current process might have CUDA context,
@@ -145,11 +151,6 @@ async def build_async_engine_client(
        rpc_server_process.start()
        logger.info("Started engine process with PID %d",
                    rpc_server_process.pid)
        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
        rpc_client = AsyncEngineRPCClient(rpc_path)
        async_engine_client = rpc_client  # type: ignore

        try:
            while True:
@@ -7,8 +7,18 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams

# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR = "SUCCESS"
VLLM_RPC_HEALTHY_STR = "HEALTHY"

# Timeouts.
VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000
VLLM_RPC_HEALTH_TIMEOUT_MS = 10000

# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000

# HWM is set to Infinity.
VLLM_RPC_ZMQ_HWM = 0


@dataclass
@@ -34,7 +44,7 @@ class RPCUtilityRequest(Enum):
    GET_SCHEDULER_CONFIG = 5
    GET_LORA_CONFIG = 6
    DO_LOG_STATS = 7
    CHECK_HEALTH = 8
    IS_SERVER_HEALTHY = 8
    IS_TRACING_ENABLED = 9
@@ -1,5 +1,7 @@
import asyncio
from contextlib import contextmanager
from typing import Any, AsyncGenerator, Mapping, Optional
from uuid import uuid4

import cloudpickle
import zmq
@@ -7,32 +9,140 @@ import zmq.asyncio

from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
# yapf: disable
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
                                         VLLM_RPC_HEALTHY_STR,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         VLLM_RPC_HEALTH_TIMEOUT_MS,
                                         VLLM_RPC_SERVER_START_TIMEOUT_MS,
                                         VLLM_RPC_SOCKET_LIMIT_CUTOFF,
                                         VLLM_RPC_SUCCESS_STR,
                                         VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
                                         RPCGenerateRequest, RPCUtilityRequest)
# yapf: enable
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

# Time to wait before checking if the server process is alive.
SERVER_START_TIMEOUT_MS = 1000
logger = init_logger(__name__)

# Path used for inprocess proxy.
INPROC_PROXY_PATH = f"inproc://{uuid4()}"


class AsyncEngineRPCClient:
    """
    RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.

    The overall design mirrors the Asynchronous Client Server Pattern
    https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern

    On startup, the RPCClient:
    - makes a DEALER socket (to_rpc_server) that connects to the RPCServer
      via ipc, which uses unix sockets under the hood
      (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
    - makes a ROUTER socket (from_api_server) that binds to a random
      inproc address, which uses memory under the hood
      (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
    - runs a proxy in a background asyncio task between
      from_api_server (ROUTER, inproc) and to_rpc_server (DEALER, ipc)

    Each request handled by the asyncio api_server calls generate():
    - makes a DEALER socket that connects to from_api_server via inproc
    - sends an RPCGenerateRequest to the inproc socket
    - the background proxy forwards the request from inproc -> ipc
    - the RPCServer responds to the request one token at a time over ipc
    - the background proxy forwards the response from ipc -> inproc

    The connection looks like this:
    DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER

    Message routing is performed via identities that are managed by the
    ROUTER socket. A ROUTER socket tracks every connection it has and
    tells the caller about them by sticking the connection identity in
    front of each message received. When we send a message via a ROUTER,
    we first send an identity frame.
    See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
    for more details on connection identities.

    This proxy design enables us to use a single unix socket, which
    improves performance by avoiding syscalls (~5%) and avoids resource
    limits such as ulimit, which defaults to 1024 on Ubuntu.

    Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
    which is required to avoid dropping messages under high load.
    This is generally not advisable. However, since we are in control
    of both sides of the connection, failure on either side is
    catastrophic to the overall system health, and memory profiling
    suggests limited memory overhead relative to asyncio, we will
    proceed for now.

    See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
    for more details on high water marks.
    """

    def __init__(self, rpc_path: str):
        self.context = zmq.asyncio.Context()
        self.rpc_path = rpc_path

        # Maximum number of sockets that can be opened (typically 65536).
        # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
        socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
        if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
            raise ValueError(
                f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
                "the number of concurrent requests vLLM can process. Launch "
                "vLLM with --disable-frontend-multiprocessing and open a "
                "GitHub issue so we can investigate.")

        # We only have 1 ipc connection that uses unix sockets, so
        # it is safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will
        # not run into ulimit issues).
        self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)

        # IPC connection to RPC Server (uses unix sockets).
        self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
        self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
        self.to_rpc_server.bind(rpc_path)

        # In-process proxy to RPC Server (uses memory-based messaging).
        self.from_api_server = self.context.socket(zmq.constants.ROUTER)
        self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
        self.from_api_server.bind(INPROC_PROXY_PATH)

        # Asyncio background task for the proxy.
        self.proxy_task = asyncio.create_task(
            self.run_proxy(self.from_api_server, self.to_rpc_server))

        # Since we open 1 inproc socket per request, we have a hard cap on
        # the number of requests that can run in vLLM with frontend
        # multiprocessing. This value is used by uvicorn to launch
        # with --limit-concurrency to return 503 when the server is
        # overloaded. We need 2 sockets per request - 2:
        # 1 for generate(), 1 for abort(), do_log_stats(), check_health()
        self.limit_concurrency = socket_limit // 2 - 2

    async def run_proxy(self, socket_from, socket_to):
        """Background task that runs a proxy"""
        poller = zmq.asyncio.Poller()
        poller.register(socket_from, zmq.constants.POLLIN)
        poller.register(socket_to, zmq.constants.POLLIN)
        while True:
            events = await poller.poll()
            events = dict(events)
            if socket_from in events:
                identity, msg = await socket_from.recv_multipart()
                await socket_to.send_multipart([identity, msg])
            if socket_to in events:
                identity, msg = await socket_to.recv_multipart()
                await socket_from.send_multipart([identity, msg])

    async def setup(self):
        """Setup the client before it starts sending server requests."""

        # Wait until server is ready.
        await self.wait_for_server()
        await self._wait_for_server_rpc()
        self._errored = False

        # Get the configs.
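The docstring and __init__ above lean on one ZeroMQ idea: a ROUTER socket prefixes every inbound message with an identity frame, and sending a multipart message that starts with that identity routes the reply back to the originating DEALER. The following self-contained sketch of that identity-routing pattern is illustrative only (not vLLM code; the inproc address and names are invented), with the proxy collapsed to a simple echo:

import asyncio

import zmq
import zmq.asyncio

INPROC_ADDR = "inproc://identity-routing-demo"


async def router_echo(ctx: zmq.asyncio.Context) -> None:
    # ROUTER receives [identity, payload]; replying with the same identity
    # frame sends the answer back to that specific DEALER. This stands in
    # for the real proxy's forward-over-ipc-and-back round trip.
    router = ctx.socket(zmq.ROUTER)
    router.set_hwm(0)  # unbounded queue, as with VLLM_RPC_ZMQ_HWM
    router.bind(INPROC_ADDR)
    try:
        while True:
            identity, payload = await router.recv_multipart()
            await router.send_multipart([identity, b"reply to " + payload])
    finally:
        router.close(linger=0)


async def one_request(ctx: zmq.asyncio.Context, name: str) -> None:
    # One short-lived DEALER per request, mirroring to_proxy_socket().
    dealer = ctx.socket(zmq.DEALER)
    dealer.set_hwm(0)
    dealer.connect(INPROC_ADDR)
    try:
        await dealer.send_multipart([name.encode()])
        reply, = await dealer.recv_multipart()
        print(name, "->", reply.decode())
    finally:
        dealer.close(linger=0)


async def main() -> None:
    ctx = zmq.asyncio.Context()
    proxy_task = asyncio.create_task(router_echo(ctx))
    await asyncio.gather(one_request(ctx, "req-1"), one_request(ctx, "req-2"))
    proxy_task.cancel()
    try:
        await proxy_task
    except asyncio.CancelledError:
        pass
    ctx.destroy(linger=0)


asyncio.run(main())

With the typical SOCKET_LIMIT of 65536, the cap computed in __init__ works out to 65536 // 2 - 2 = 32766 concurrent requests, which is the value serve_http later passes to uvicorn as limit_concurrency.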
@@ -51,29 +161,23 @@ class AsyncEngineRPCClient:

    def close(self):
        """Destroy the ZeroMQ Context."""
        # Close all sockets associated with this context and
        # then terminate the context.
        self.from_api_server.close()
        self.to_rpc_server.close()
        self.context.destroy()

    @contextmanager
    def socket(self):
        # Ensure client sockets are always closed after use

        # Connect to RPC socket for Request-Reply pattern,
    def to_proxy_socket(self):
        # Connect to the RPCServer via the proxy.
        # Note that we use DEALER to enable asynchronous communication
        # to enable streaming.
        socket = self.context.socket(zmq.constants.DEALER)
        socket.set_hwm(VLLM_RPC_ZMQ_HWM)
        try:
            socket.connect(self.rpc_path)
            socket.connect(INPROC_PROXY_PATH)
            yield socket
        finally:
            # linger == 0 means discard unsent messages
            # when the socket is closed. This is necessary
            # because otherwise self.context.destroy() will
            # wait for 30 seconds until unsent messages are
            # received, which is impossible if the server
            # crashed. In the absence of a server crash we
            # always expect a response before closing the
            # socket anyway.
            # Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
            socket.close(linger=0)

    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
@@ -81,10 +185,9 @@ class AsyncEngineRPCClient:
                                         error_message: str) -> Any:
        """Send an RPC request that is expecting data back."""

        with self.socket() as socket:

        with self.to_proxy_socket() as socket:
            # Ping RPCServer with a request.
            await socket.send(cloudpickle.dumps(request))
            await socket.send_multipart([cloudpickle.dumps(request)])

            # Await the data from the Server.
            data = cloudpickle.loads(await socket.recv())
@@ -93,31 +196,48 @@ class AsyncEngineRPCClient:
            # LoRAConfig can be None.
            if expected_type == LoRAConfig and data is None:
                pass
            elif isinstance(data, Exception):
                logger.error(error_message)
                raise data
            else:
                raise ValueError(error_message)

            return data

    async def _send_one_way_rpc_request(self,
                                        request: RPC_REQUEST_TYPE,
                                        error_message: str,
                                        timeout: Optional[int] = None):
    async def _send_one_way_rpc_request(
            self,
            request: RPC_REQUEST_TYPE,
            error_message: str,
            timeout: Optional[int] = None,
            socket: Optional[zmq.asyncio.Socket] = None):
        """Send one-way RPC request to trigger an action."""
        with self.socket() as socket:
            # Ping RPC Server with request.
            await socket.send(cloudpickle.dumps(request))

            # Await acknowledgement from RPCServer.
        async def do_rpc_call(socket: zmq.asyncio.Socket,
                              request: RPC_REQUEST_TYPE,
                              timeout=None):

            await socket.send_multipart([cloudpickle.dumps(request)])

            if timeout is not None and await socket.poll(timeout=timeout) == 0:
                raise TimeoutError(f"server didn't reply within {timeout} ms")
                raise TimeoutError(f"Server didn't reply within {timeout} ms")

            response = cloudpickle.loads(await socket.recv())
            return cloudpickle.loads(await socket.recv())

        # Make a new socket connection.
        if socket is None:
            with self.to_proxy_socket() as socket:
                response = await do_rpc_call(socket, request, timeout)

        # Use existing socket connection.
        else:
            response = await do_rpc_call(socket, request, timeout)

        if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
            if isinstance(response, Exception):
                logger.error(error_message)
                raise response
            raise ValueError(error_message)

        return response

    async def get_tokenizer(self, lora_request: LoRARequest):
        return await self.tokenizer.get_lora_tokenizer_async(lora_request)
@@ -130,13 +250,13 @@ class AsyncEngineRPCClient:
    async def is_tracing_enabled(self) -> bool:
        return self.tracing_flag

    async def wait_for_server(self):
    async def _wait_for_server_rpc(self):
        """Wait for the RPCServer to start up."""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_READY,
            error_message="Unable to start RPC Server.",
            timeout=SERVER_START_TIMEOUT_MS)
            error_message="Unable to start RPC Server",
            timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS)

    async def _get_model_config_rpc(self) -> ModelConfig:
        """Get the ModelConfig object from the RPC Server"""
@@ -184,8 +304,7 @@ class AsyncEngineRPCClient:
        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.IS_TRACING_ENABLED,
            expected_type=bool,
            error_message="Could not get is_tracing_enabled flag from RPC "
            "Server")
            error_message="Could not get is_tracing_enabled from RPC Server")

    async def abort(self, request_id: str):
        """Send an ABORT_REQUEST signal to the RPC Server"""
@@ -226,8 +345,7 @@ class AsyncEngineRPCClient:

        finished = False
        try:
            with self.socket() as socket:

            with self.to_proxy_socket() as socket:
                # Send RPCGenerateRequest to the RPCServer.
                await socket.send_multipart([
                    cloudpickle.dumps(
@@ -246,43 +364,37 @@ class AsyncEngineRPCClient:
                    request_output = cloudpickle.loads(message)

                    if isinstance(request_output, Exception):
                        # On exception, check if the server is still healthy.
                        # Use this to set the sync `is_running` and `errored`
                        # properties.
                        try:
                            await self.check_health()
                        except Exception:
                            self._errored = True
                        # On exception, check if the server is still healthy
                        # possibly setting the `errored` property.
                        if not self._errored:
                            try:
                                await self.check_health(socket=socket)
                            except Exception as e:
                                self._errored = True
                                logger.exception(repr(e))

                        # NB: do before raising here so that the flag is set
                        # by the time the caller receives this exception
                        raise request_output

                    finished = request_output.finished
                    yield request_output

        finally:
            if not finished:
            # Request was canceled by the client.
            if not finished and not self._errored:
                await self.abort(request_id)

    async def check_health(self) -> None:
    async def check_health(self,
                           socket: Optional[zmq.asyncio.Socket] = None
                           ) -> None:
        """Raise if unhealthy"""

        with self.socket() as socket:

            # Ping RPCServer with CHECK_HEALTH request.
            await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
                              )

            # Await the reply from the server.
            # TODO: do we need an internal timeout here?
            # Or do we expect the external probe to timeout and let this chill?
            health_message = cloudpickle.loads(await socket.recv())

            if isinstance(health_message, Exception):
                raise health_message

            if health_message != VLLM_RPC_HEALTHY_STR:
                raise ValueError("Expected healthy response from backend but got "
                                 "f{health_message}")
        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_HEALTHY,
            error_message="Got Unhealthy response from RPC Server",
            timeout=VLLM_RPC_HEALTH_TIMEOUT_MS,
            socket=socket)

    async def encode(self, *args,
                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
@@ -1,6 +1,6 @@
import asyncio
import signal
from typing import Any, Coroutine
from typing import Any, Coroutine, Union

import cloudpickle
import uvloop
@@ -9,14 +9,19 @@ import zmq.asyncio
from typing_extensions import Never

from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
                                         VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)

CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
                    SchedulerConfig, LoRAConfig]


class AsyncEngineRPCServer:
@@ -29,9 +34,10 @@ class AsyncEngineRPCServer:
        # Initialize context.
        self.context = zmq.asyncio.Context()

        # Init socket for readiness state.
        self.socket = self.context.socket(zmq.constants.ROUTER)
        self.socket.bind(rpc_path)
        # Init socket.
        self.socket = self.context.socket(zmq.constants.DEALER)
        self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
        self.socket.connect(rpc_path)

    def cleanup(self):
        """Cleanup all resources."""
@@ -41,39 +47,27 @@ class AsyncEngineRPCServer:
        # Clear the engine reference so that it can be GC'ed.
        del self.engine

    async def get_model_config(self, identity):
        """Send the ModelConfig"""
        model_config = await self.engine.get_model_config()
    async def get_config(self, identity, request):
        try:
            config: CONFIG_TYPE
            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                config = await self.engine.get_model_config()
            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
                config = await self.engine.get_decoding_config()
            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
                config = await self.engine.get_lora_config()
            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
                config = await self.engine.get_scheduler_config()
            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
                config = await self.engine.get_parallel_config()
            else:
                raise ValueError("Unknown Config Request: %s", request)

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(model_config)])
            await self.socket.send_multipart(
                [identity, cloudpickle.dumps(config)])

    async def get_decoding_config(self, identity):
        """Send the DecodingConfig"""
        decoding_config = await self.engine.get_decoding_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(decoding_config)])

    async def get_lora_config(self, identity):
        lora_config = await self.engine.get_lora_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(lora_config)])

    async def get_scheduler_config(self, identity):
        """Send the SchedulerConfig"""
        parallel_config = await self.engine.get_scheduler_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(parallel_config)])

    async def get_parallel_config(self, identity):
        """Send the ParallelConfig"""
        parallel_config = await self.engine.get_parallel_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(parallel_config)])
        except Exception as e:
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    async def is_tracing_enabled(self, identity):
        """Send the is_tracing_enabled flag"""
@@ -86,31 +80,23 @@ class AsyncEngineRPCServer:
        """Log stats and confirm success."""
        await self.engine.do_log_stats()

        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
        ])
        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])

    async def is_server_ready(self, identity):
        """Notify the client that we are ready."""
        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
        ])
        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success."""
        try:
            # Abort the request in the llm engine.
            await self.engine.abort(request.request_id)
        except Exception:
            logger.warning("Failed to abort request %s", request.request_id)
        finally:
            # Send confirmation to the client.
            await self.socket.send_multipart([
                identity,
                cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
            ])
            result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
        except Exception as e:
            result = e
        await self.socket.send_multipart([identity, cloudpickle.dumps(result)])

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        try:
@@ -127,14 +113,14 @@ class AsyncEngineRPCServer:
                    [identity, cloudpickle.dumps(request_output)])

        except Exception as e:
            ### Notify client of all failures
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    async def check_health(self, identity):
        try:
            await self.engine.check_health()
            await self.socket.send_multipart(
                [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
                [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])

        except Exception as e:
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
@@ -151,21 +137,19 @@ class AsyncEngineRPCServer:
            return self.abort(identity, request)

        elif isinstance(request, RPCUtilityRequest):
            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                return self.get_model_config(identity)
            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
                return self.get_parallel_config(identity)
            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
                return self.get_decoding_config(identity)
            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
                return self.get_scheduler_config(identity)
            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
                return self.get_lora_config(identity)
            if request in [
                    RPCUtilityRequest.GET_MODEL_CONFIG,
                    RPCUtilityRequest.GET_PARALLEL_CONFIG,
                    RPCUtilityRequest.GET_DECODING_CONFIG,
                    RPCUtilityRequest.GET_SCHEDULER_CONFIG,
                    RPCUtilityRequest.GET_LORA_CONFIG
            ]:
                return self.get_config(identity, request)
            elif request == RPCUtilityRequest.DO_LOG_STATS:
                return self.do_log_stats(identity)
            elif request == RPCUtilityRequest.IS_SERVER_READY:
                return self.is_server_ready(identity)
            elif request == RPCUtilityRequest.CHECK_HEALTH:
            elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
                return self.check_health(identity)
            elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                return self.is_tracing_enabled(identity)