From 10f535c0863ffcd23ae5838979ea4e0dd82c5886 Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Thu, 21 Aug 2025 10:22:18 -0700 Subject: [PATCH] [Bugfix] Fix port conflict by obtaining a list of open ports upfront (#21894) Signed-off-by: Ming Yang --- vllm/config/parallel.py | 24 ++++++++++++++++++------ vllm/utils/__init__.py | 20 ++++++++++++++------ vllm/v1/engine/utils.py | 4 +++- 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 2b716a77066ac..f667cac2fe02a 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -15,7 +15,7 @@ import vllm.envs as envs from vllm.config.utils import config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless, get_open_port +from vllm.utils import cuda_device_count_stateless, get_open_ports_list if TYPE_CHECKING: from ray.runtime_env import RuntimeEnv @@ -171,6 +171,11 @@ class ParallelConfig: rank: int = 0 """Global rank in distributed setup.""" + _data_parallel_master_port_list: list[int] = field(default_factory=list) + """List of open port auto-queried for data parallel messaging. + Set to be private as it's not intended to be configured by users. + """ + @property def world_size_across_dp(self) -> int: """world_size_across_dp is TPxPPxDP, it is the size of the world @@ -183,11 +188,15 @@ class ParallelConfig: processes that is related to data parallelism, e.g. both in the worker and in the engine, which can live in different processes. To avoid port conflicts, we - increment the port number each time we need to initialize a - new process group related to data parallelism. + pop a new port from the prepared port list each time we need to + initialize a new process group related to data parallelism. """ - answer = self.data_parallel_master_port - self.data_parallel_master_port += 1 + if self._data_parallel_master_port_list: + answer = self._data_parallel_master_port_list.pop() + else: + answer = self.data_parallel_master_port + self.data_parallel_master_port += 1 + return answer def stateless_init_dp_group(self) -> ProcessGroup: @@ -313,7 +322,10 @@ class ParallelConfig: if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. - self.data_parallel_master_port = get_open_port() + if not self._data_parallel_master_port_list: + self._data_parallel_master_port_list = get_open_ports_list(5) + self.data_parallel_master_port = \ + self._data_parallel_master_port_list.pop() if not (0 <= self.data_parallel_rank < self.data_parallel_size): raise ValueError( diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 5cb9f97ae0b08..1eefb32eaa90b 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -516,8 +516,8 @@ def random_uuid() -> str: class AsyncMicrobatchTokenizer: """Asynchronous tokenizer with micro-batching. - Pulls pending encode/decode requests from a queue and batches them - up to reduce overhead. A single-thread ThreadPoolExecutor is used + Pulls pending encode/decode requests from a queue and batches them + up to reduce overhead. A single-thread ThreadPoolExecutor is used so the event loop stays responsive. """ @@ -664,18 +664,18 @@ class AsyncMicrobatchTokenizer: def _queue_key(self, op: str, kwargs: dict) -> tuple: """ Return a normalized key describing operation + kwargs. - + - `add_special_tokens`: {True/False} - `truncation`: {True/False} - - If `truncation` is False (`max_length` is None), + - If `truncation` is False (`max_length` is None), returns a key for a can_batch queue. - If `truncation` is True and `max_length` is None or equals `tokenizer.model_max_length`, returns a key for a can_batch queue. - Otherwise, returns a key for a cannot_batch queue. - + Examples: - Decode: ("decode",) - - Encode typical: + - Encode typical: ("encode", add_special_tokens, bool_truncation, max_length_label) - Fallback: ("encode", "other") """ @@ -940,6 +940,14 @@ def get_open_port() -> int: return _get_open_port() +def get_open_ports_list(count: int = 5) -> list[int]: + """Get a list of open ports.""" + ports = set() + while len(ports) < count: + ports.add(get_open_port()) + return list(ports) + + def _get_open_port() -> int: port = envs.VLLM_PORT if port is not None: diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 770aa7d9dcc8a..62f229e286931 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -71,7 +71,7 @@ class EngineHandshakeMetadata: connect to. """ addresses: EngineZmqAddresses - parallel_config: dict[str, Union[int, str]] + parallel_config: dict[str, Union[int, str, list[int]]] class CoreEngineProcManager: @@ -798,6 +798,8 @@ def wait_for_engine_startup( parallel_config.data_parallel_master_ip, "data_parallel_master_port": parallel_config.data_parallel_master_port, + "_data_parallel_master_port_list": + parallel_config._data_parallel_master_port_list, "data_parallel_size": parallel_config.data_parallel_size, }))