[ BugFix ] Move zmq frontend to IPC instead of TCP (#7222)

This commit is contained in:
Robert Shaw 2024-08-07 12:24:56 -04:00 committed by GitHub
parent 0f7052bc7e
commit 564985729a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 29 additions and 22 deletions

View File

@ -43,7 +43,7 @@ from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization) OpenAIServingTokenization)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_port from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
@ -106,16 +106,20 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Otherwise, use the multiprocessing AsyncLLMEngine. # Otherwise, use the multiprocessing AsyncLLMEngine.
else: else:
# Select random path for IPC.
rpc_path = get_open_zmq_ipc_path()
logger.info("Multiprocessing frontend to use %s for RPC Path.",
rpc_path)
# Start RPCServer in separate process (holds the AsyncLLMEngine). # Start RPCServer in separate process (holds the AsyncLLMEngine).
port = get_open_port(envs.VLLM_RPC_PORT)
rpc_server_process = Process(target=run_rpc_server, rpc_server_process = Process(target=run_rpc_server,
args=(engine_args, args=(engine_args,
UsageContext.OPENAI_API_SERVER, UsageContext.OPENAI_API_SERVER,
port)) rpc_path))
rpc_server_process.start() rpc_server_process.start()
# Build RPCClient, which conforms to AsyncEngineClient Protocol. # Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(port) async_engine_client = AsyncEngineRPCClient(rpc_path)
await async_engine_client.setup() await async_engine_client.setup()
try: try:

View File

@ -21,9 +21,9 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
class AsyncEngineRPCClient: class AsyncEngineRPCClient:
def __init__(self, port: int): def __init__(self, rpc_path: str):
self.context = zmq.asyncio.Context() self.context = zmq.asyncio.Context()
self.path = f"tcp://localhost:{port}" self.rpc_path = rpc_path
async def setup(self): async def setup(self):
"""Setup the client before it starts sending server requests.""" """Setup the client before it starts sending server requests."""
@ -58,7 +58,7 @@ class AsyncEngineRPCClient:
# to enable streaming. # to enable streaming.
socket = self.context.socket(zmq.constants.DEALER) socket = self.context.socket(zmq.constants.DEALER)
try: try:
socket.connect(self.path) socket.connect(self.rpc_path)
yield socket yield socket
finally: finally:
socket.close() socket.close()

View File

@ -20,7 +20,7 @@ logger = init_logger(__name__)
class AsyncEngineRPCServer: class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs, def __init__(self, async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int): usage_context: UsageContext, rpc_path: str):
# Initialize engine first. # Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args, self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
usage_context) usage_context)
@ -30,9 +30,7 @@ class AsyncEngineRPCServer:
# Init socket for readiness state. # Init socket for readiness state.
self.socket = self.context.socket(zmq.constants.ROUTER) self.socket = self.context.socket(zmq.constants.ROUTER)
# Note numeric form of localhost should be used for zmq bind(), self.socket.bind(rpc_path)
# see https://stackoverflow.com/a/8958414
self.socket.bind(f"tcp://127.0.0.1:{port}")
def cleanup(self): def cleanup(self):
"""Cleanup all resources.""" """Cleanup all resources."""
@ -213,6 +211,6 @@ async def run_server(server: AsyncEngineRPCServer):
def run_rpc_server(async_engine_args: AsyncEngineArgs, def run_rpc_server(async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int): usage_context: UsageContext, rpc_path: str):
server = AsyncEngineRPCServer(async_engine_args, usage_context, port) server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
asyncio.run(run_server(server)) asyncio.run(run_server(server))

View File

@ -1,10 +1,11 @@
import os import os
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
VLLM_HOST_IP: str = "" VLLM_HOST_IP: str = ""
VLLM_PORT: Optional[int] = None VLLM_PORT: Optional[int] = None
VLLM_RPC_PORT: int = 5570 VLLM_RPC_BASE_PATH: str = tempfile.gettempdir()
VLLM_USE_MODELSCOPE: bool = False VLLM_USE_MODELSCOPE: bool = False
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_INSTANCE_ID: Optional[str] = None VLLM_INSTANCE_ID: Optional[str] = None
@ -142,10 +143,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: int(os.getenv('VLLM_PORT', '0')) lambda: int(os.getenv('VLLM_PORT', '0'))
if 'VLLM_PORT' in os.environ else None, if 'VLLM_PORT' in os.environ else None,
# used when the frontend api server is running in multi-processing mode, # path used for ipc when the frontend api server is running in
# to communicate with the backend engine process over ZMQ. # multi-processing mode to communicate with the backend engine process.
'VLLM_RPC_PORT': 'VLLM_RPC_BASE_PATH':
lambda: int(os.getenv('VLLM_RPC_PORT', '5570')), lambda: os.getenv('VLLM_RPC_BASE_PATH', tempfile.gettempdir()),
# If true, will load models from ModelScope instead of Hugging Face Hub. # If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers # note that the value is true or false, not numbers

View File

@ -19,6 +19,7 @@ from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
Union, overload) Union, overload)
from uuid import uuid4
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@ -484,10 +485,13 @@ def get_distributed_init_method(ip: str, port: int) -> str:
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
def get_open_port(port: Optional[int] = None) -> int: def get_open_zmq_ipc_path() -> str:
if port is None: base_rpc_path = envs.VLLM_RPC_BASE_PATH
# Default behavior here is to return a port for multi-gpu communication return f"ipc://{base_rpc_path}/{uuid4()}"
port = envs.VLLM_PORT
def get_open_port() -> int:
port = envs.VLLM_PORT
if port is not None: if port is not None:
while True: while True:
try: try: