Signed-off-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw 2025-07-16 20:29:24 +00:00
parent 723263fa23
commit 6cd8dec23f

View File

@ -10,6 +10,7 @@ from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from importlib import metadata
from typing import TYPE_CHECKING, Any, Optional
import msgspec
@ -41,17 +42,23 @@ Transfer = tuple[int, float] # (xfer_handle, start_time)
EngineId = str
ReqId = str
GET_META_MSG = b"get_meta_msg"
NIXL_NUM_WORKERS = 32
NIXL_NUM_WORKERS = 4
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper, nixl_agent_config
NIXL_VERSION = metadata.version("nixl")
NIXL_major, NIXL_minor, NIXL_patch = map(int, NIXL_VERSION.split("."))
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
NIXL_VERSION = None
NIXL_NUM_WORKERS
class NixlAgentMetadata(
msgspec.Struct,
@ -352,9 +359,8 @@ class NixlConnectorWorker:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
if NixlWrapper is None:
logger.error("NIXL is not available")
raise RuntimeError("NIXL is not available")
logger.info("Initializing NIXL wrapper")
logger.info("Initializing NIXL worker %s", engine_id)
raise RuntimeError("NIXL is not available.")
logger.info("Initializing NIXL v%s: worker %s", NIXL_VERSION, engine_id)
# Config.
self.vllm_config = vllm_config
@ -362,10 +368,10 @@ class NixlConnectorWorker:
# Agent.
import os
if os.getenv("VLLM_USE_NIXL_WORKERS", "0") == "1":
config = nixl_agent_config(num_threads=NIXL_NUM_WORKERS)
else:
config = None
NIXL_NUM_WORKERS = int(os.getenv("VLLM_NIXL_NUM_WORKERS", "1"))
logger.info(f"Using NIXL_NUM_WORKERS={NIXL_NUM_WORKERS} for NIXL agent.")
config = nixl_agent_config(enable_prog_thread=False, num_threads=NIXL_NUM_WORKERS)
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
@ -454,7 +460,8 @@ class NixlConnectorWorker:
def __del__(self):
"""Cleanup background threads on destruction."""
self._handshake_initiation_executor.shutdown(wait=False)
if t_ := getattr(self, "_handshake_initiation_executor", None):
t_.shutdown(wait=False)
if self._nixl_handshake_listener_t:
self._nixl_handshake_listener_t.join(timeout=0)