diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 5a49b8aa1244b..12aeb4e68b61b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -10,6 +10,7 @@ from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass +from importlib import metadata from typing import TYPE_CHECKING, Any, Optional import msgspec @@ -41,17 +42,23 @@ Transfer = tuple[int, float] # (xfer_handle, start_time) EngineId = str ReqId = str GET_META_MSG = b"get_meta_msg" -NIXL_NUM_WORKERS = 32 +NIXL_NUM_WORKERS = 4 logger = init_logger(__name__) # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used try: from nixl._api import nixl_agent as NixlWrapper, nixl_agent_config + NIXL_VERSION = metadata.version("nixl") + NIXL_major, NIXL_minor, NIXL_patch = map(int, NIXL_VERSION.split(".")) + + except ImportError: logger.warning("NIXL is not available") NixlWrapper = None + NIXL_VERSION = None +NIXL_NUM_WORKERS class NixlAgentMetadata( msgspec.Struct, @@ -352,9 +359,8 @@ class NixlConnectorWorker: def __init__(self, vllm_config: VllmConfig, engine_id: str): if NixlWrapper is None: logger.error("NIXL is not available") - raise RuntimeError("NIXL is not available") - logger.info("Initializing NIXL wrapper") - logger.info("Initializing NIXL worker %s", engine_id) + raise RuntimeError("NIXL is not available.") + logger.info("Initializing NIXL v%s: worker %s", NIXL_VERSION, engine_id) # Config. self.vllm_config = vllm_config @@ -362,10 +368,10 @@ class NixlConnectorWorker: # Agent. import os - if os.getenv("VLLM_USE_NIXL_WORKERS", "0") == "1": - config = nixl_agent_config(num_threads=NIXL_NUM_WORKERS) - else: - config = None + NIXL_NUM_WORKERS = int(os.getenv("VLLM_NIXL_NUM_WORKERS", "1")) + logger.info(f"Using NIXL_NUM_WORKERS={NIXL_NUM_WORKERS} for NIXL agent.") + + config = nixl_agent_config(enable_prog_thread=False, num_threads=NIXL_NUM_WORKERS) self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) @@ -454,7 +460,8 @@ class NixlConnectorWorker: def __del__(self): """Cleanup background threads on destruction.""" - self._handshake_initiation_executor.shutdown(wait=False) + if t_ := getattr(self, "_handshake_initiation_executor", None): + t_.shutdown(wait=False) if self._nixl_handshake_listener_t: self._nixl_handshake_listener_t.join(timeout=0)