diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 770560a5e549e..8c840fd2ac7e0 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -2,12 +2,14 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import asyncio
+import importlib
 import os
 import signal
 import time
 import uuid
 from dataclasses import dataclass
 from threading import Thread
+from types import SimpleNamespace
 from typing import Any
 from unittest.mock import MagicMock
 
@@ -24,7 +26,11 @@ from vllm.usage.usage_lib import UsageContext
 from vllm.utils.torch_utils import set_default_torch_num_threads
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
-from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient
+from vllm.v1.engine.core_client import (
+    AsyncMPClient,
+    EngineCoreClient,
+    SyncMPClient,
+)
 from vllm.v1.engine.utils import CoreEngineProcManager
 from vllm.v1.executor.abstract import Executor
 
@@ -60,6 +66,91 @@ def make_request(
     )
 
 
+def _reload_envs_module():
+    import vllm.envs as envs_mod
+
+    cache_clear = getattr(getattr(envs_mod, "__getattr__", None), "cache_clear", None)
+    if cache_clear is not None:
+        cache_clear()
+    return importlib.reload(envs_mod)
+
+
+def _reload_core_client_module():
+    module = importlib.import_module("vllm.v1.engine.core_client")
+    return importlib.reload(module)
+
+
+def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch):
+    timeout_value = 654
+    monkeypatch.setenv("VLLM_ENGINE_READY_TIMEOUT_S", str(timeout_value))
+
+    # Ensure that the environment variable is loaded if caching is enabled
+    _reload_envs_module()
+    core_client_mod = _reload_core_client_module()
+
+    poll_timeouts: list[int] = []
+
+    class ShadowSocket:
+        def poll(self, timeout: int) -> int:
+            # Capture the timeout value for each poll call
+            poll_timeouts.append(timeout)
+            return 1
+
+        def recv_multipart(self):
+            return (b"\x00\x00", b"ready")
+
+    class DummySocket:
+        def send_multipart(self, _msg, *, copy: bool = False, track: bool = False):
+            if track:
+                return SimpleNamespace(done=True)
+
+        def recv_multipart(self, *, copy: bool = False):
+            return (b"", b"")
+
+        def close(self, *, linger: int = 0):
+            pass
+
+        def bind(self, _address):
+            pass
+
+        def connect(self, _address):
+            pass
+
+        def setsockopt(self, *_args, **_kwargs):
+            pass
+
+    monkeypatch.setattr(core_client_mod.zmq.Socket, "shadow", lambda *_: ShadowSocket())
+    monkeypatch.setattr(
+        core_client_mod, "make_zmq_socket", lambda *_, **__: DummySocket()
+    )
+
+    parallel_config = SimpleNamespace(
+        data_parallel_size=1,
+        data_parallel_rank=0,
+        data_parallel_size_local=1,
+        data_parallel_rank_local=None,
+        data_parallel_hybrid_lb=False,
+        data_parallel_external_lb=False,
+    )
+    vllm_config = SimpleNamespace(parallel_config=parallel_config)
+
+    client = core_client_mod.MPClient(
+        asyncio_mode=False,
+        vllm_config=vllm_config,
+        executor_class=object,
+        log_stats=False,
+        client_addresses={
+            "input_address": "inproc://input",
+            "output_address": "inproc://output",
+        },
+    )
+    try:
+        # timeout_value is in seconds, but poll receives milliseconds
+        assert poll_timeouts == [timeout_value * 1000]
+    finally:
+        client.shutdown()
+
+
 def loop_until_done(client: EngineCoreClient, outputs: dict):
     while True:
         engine_core_outputs = client.get_output().outputs
diff --git a/vllm/envs.py b/vllm/envs.py
index f6db42e9124d6..1d4128d74b95c 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -24,6 +24,7 @@ if TYPE_CHECKING:
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: str | None = None
     VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
+    VLLM_ENGINE_READY_TIMEOUT_S: int = 600
     VLLM_API_KEY: str | None = None
     VLLM_DEBUG_LOG_API_SERVER_RESPONSE: bool = False
     S3_ACCESS_KEY_ID: str | None = None
@@ -604,6 +605,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ENGINE_ITERATION_TIMEOUT_S": lambda: int(
         os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")
     ),
+    # Timeout in seconds for waiting for engine cores to become ready
+    # during startup. Default is 600 seconds (10 minutes).
+    "VLLM_ENGINE_READY_TIMEOUT_S": lambda: int(
+        os.environ.get("VLLM_ENGINE_READY_TIMEOUT_S", "600")
+    ),
     # API key for vLLM API server
     "VLLM_API_KEY": lambda: os.environ.get("VLLM_API_KEY", None),
     # Whether to log responses from API Server for debugging
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 807db8275fbf5..cacbc805e84f8 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -20,6 +20,7 @@ import zmq
 import zmq.asyncio
 
 from vllm.config import VllmConfig
+from vllm.envs import VLLM_ENGINE_READY_TIMEOUT_S
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
@@ -528,7 +529,9 @@ class MPClient(EngineCoreClient):
             identities = set(self.core_engines)
             sync_input_socket = zmq.Socket.shadow(self.input_socket)
             while identities:
-                if not sync_input_socket.poll(timeout=600_000):
+                if not sync_input_socket.poll(
+                    timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000  # convert to ms
+                ):
                     raise TimeoutError(
                         "Timed out waiting for engines to send"
                         "initial message on input socket."
@@ -1340,7 +1343,9 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
         # Wait for ready messages from new engines on the input socket
         sync_input_socket = zmq.Socket.shadow(self.input_socket)
         while new_engine_identities:
-            if not sync_input_socket.poll(timeout=600_000):
+            if not sync_input_socket.poll(
+                timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000  # convert to ms
+            ):
                 raise TimeoutError(
                     "Timed out waiting for new engines to send initial "
                     "message on input socket."