mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-08 22:46:31 +08:00
Make engine core client handshake timeout configurable (#27444)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
This commit is contained in:
parent
969bbc7c61
commit
1ab5213531
@ -2,12 +2,14 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -24,7 +26,11 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient
|
||||
from vllm.v1.engine.core_client import (
|
||||
AsyncMPClient,
|
||||
EngineCoreClient,
|
||||
SyncMPClient,
|
||||
)
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
@ -60,6 +66,91 @@ def make_request(
|
||||
)
|
||||
|
||||
|
||||
def _reload_envs_module():
    """Reload ``vllm.envs`` so it is re-evaluated against ``os.environ``.

    If the module's ``__getattr__`` hook exposes a ``cache_clear`` callable
    (e.g. because it has been wrapped by a cache), that cache is cleared
    before reloading.

    Returns:
        The module object returned by ``importlib.reload``.
    """
    import vllm.envs as envs_mod

    # Look the hook up in two steps; either lookup may yield None.
    module_getattr = getattr(envs_mod, "__getattr__", None)
    clear_cache = getattr(module_getattr, "cache_clear", None)
    if clear_cache is not None:
        clear_cache()

    return importlib.reload(envs_mod)
|
||||
|
||||
|
||||
def _reload_core_client_module():
    """Import ``vllm.v1.engine.core_client``, reload it, and return the result.

    Reloading re-executes the module's top-level code, including its
    ``from vllm.envs import ...`` statements.
    """
    return importlib.reload(importlib.import_module("vllm.v1.engine.core_client"))
|
||||
|
||||
|
||||
def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch):
    """MPClient must poll for engine readiness with the env-configured timeout.

    Sets ``VLLM_ENGINE_READY_TIMEOUT_S``, reloads the env and core-client
    modules so the new value is picked up, builds an ``MPClient`` on top of
    stub sockets, and asserts that the single readiness ``poll`` call received
    the configured value converted to milliseconds.

    All patches are scoped with ``monkeypatch.context()`` and both modules are
    reloaded again afterwards, so the patched timeout does not stay bound in
    the reloaded modules for tests that run later in the same process.
    """
    timeout_value = 654
    poll_timeouts: list[int] = []

    class ShadowSocket:
        """Stand-in for the object returned by ``zmq.Socket.shadow``."""

        def poll(self, timeout: int) -> int:
            # Capture the timeout value for each poll call
            poll_timeouts.append(timeout)
            return 1

        def recv_multipart(self):
            # Presumably the identity frame of engine 0 plus a ready payload;
            # verify against the handshake code in MPClient.
            return (b"\x00\x00", b"ready")

    class DummySocket:
        """No-op replacement for sockets created via ``make_zmq_socket``."""

        def send_multipart(self, _msg, *, copy: bool = False, track: bool = False):
            if track:
                return SimpleNamespace(done=True)

        def recv_multipart(self, *, copy: bool = False):
            return (b"", b"")

        def close(self, *, linger: int = 0):
            pass

        def bind(self, _address):
            pass

        def connect(self, _address):
            pass

        def setsockopt(self, *_args, **_kwargs):
            pass

    try:
        with monkeypatch.context() as scoped:
            scoped.setenv("VLLM_ENGINE_READY_TIMEOUT_S", str(timeout_value))

            # Ensure that the environment variable is loaded if caching is enabled
            _reload_envs_module()
            core_client_mod = _reload_core_client_module()

            scoped.setattr(
                core_client_mod.zmq.Socket, "shadow", lambda *_: ShadowSocket()
            )
            scoped.setattr(
                core_client_mod, "make_zmq_socket", lambda *_, **__: DummySocket()
            )

            parallel_config = SimpleNamespace(
                data_parallel_size=1,
                data_parallel_rank=0,
                data_parallel_size_local=1,
                data_parallel_rank_local=None,
                data_parallel_hybrid_lb=False,
                data_parallel_external_lb=False,
            )
            vllm_config = SimpleNamespace(parallel_config=parallel_config)

            client = core_client_mod.MPClient(
                asyncio_mode=False,
                vllm_config=vllm_config,
                executor_class=object,
                log_stats=False,
                client_addresses={
                    "input_address": "inproc://input",
                    "output_address": "inproc://output",
                },
            )
            try:
                # timeout_value is in seconds, but poll receives milliseconds
                assert poll_timeouts == [timeout_value * 1000]
            finally:
                client.shutdown()
    finally:
        # The context manager has restored the environment by now; reload once
        # more so the modules no longer carry the patched timeout.
        _reload_envs_module()
        _reload_core_client_module()
|
||||
|
||||
|
||||
def loop_until_done(client: EngineCoreClient, outputs: dict):
|
||||
while True:
|
||||
engine_core_outputs = client.get_output().outputs
|
||||
|
||||
@ -24,6 +24,7 @@ if TYPE_CHECKING:
|
||||
LOCAL_RANK: int = 0
|
||||
CUDA_VISIBLE_DEVICES: str | None = None
|
||||
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
|
||||
VLLM_ENGINE_READY_TIMEOUT_S: int = 600
|
||||
VLLM_API_KEY: str | None = None
|
||||
VLLM_DEBUG_LOG_API_SERVER_RESPONSE: bool = False
|
||||
S3_ACCESS_KEY_ID: str | None = None
|
||||
@ -604,6 +605,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": lambda: int(
|
||||
os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")
|
||||
),
|
||||
# Timeout in seconds for waiting for engine cores to become ready
|
||||
# during startup. Default is 600 seconds (10 minutes).
|
||||
"VLLM_ENGINE_READY_TIMEOUT_S": lambda: int(
|
||||
os.environ.get("VLLM_ENGINE_READY_TIMEOUT_S", "600")
|
||||
),
|
||||
# API key for vLLM API server
|
||||
"VLLM_API_KEY": lambda: os.environ.get("VLLM_API_KEY", None),
|
||||
# Whether to log responses from API Server for debugging
|
||||
|
||||
@ -20,6 +20,7 @@ import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.envs import VLLM_ENGINE_READY_TIMEOUT_S
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.tasks import SupportedTask
|
||||
@ -528,7 +529,9 @@ class MPClient(EngineCoreClient):
|
||||
identities = set(self.core_engines)
|
||||
sync_input_socket = zmq.Socket.shadow(self.input_socket)
|
||||
while identities:
|
||||
if not sync_input_socket.poll(timeout=600_000):
|
||||
if not sync_input_socket.poll(
|
||||
timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000 # convert to ms
|
||||
):
|
||||
raise TimeoutError(
|
||||
"Timed out waiting for engines to send"
|
||||
"initial message on input socket."
|
||||
@ -1340,7 +1343,9 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
# Wait for ready messages from new engines on the input socket
|
||||
sync_input_socket = zmq.Socket.shadow(self.input_socket)
|
||||
while new_engine_identities:
|
||||
if not sync_input_socket.poll(timeout=600_000):
|
||||
if not sync_input_socket.poll(
|
||||
timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000 # convert to ms
|
||||
):
|
||||
raise TimeoutError(
|
||||
"Timed out waiting for new engines to send initial "
|
||||
"message on input socket."
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user