From bd6a3406525318357068cc2854f76dbf7dade3aa Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 1 Dec 2025 09:07:49 +0000 Subject: [PATCH] format Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 50 +++++++++---------- vllm/envs.py | 17 +++---- 2 files changed, 29 insertions(+), 38 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 781a76baab2bf..c4ce7965c8b7a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -3,7 +3,6 @@ import contextlib import logging import math -import os import queue import threading import time @@ -19,8 +18,8 @@ import msgspec import numpy as np import torch import zmq + from vllm import envs -from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -201,7 +200,7 @@ class TransferError(MoRIIOError): def get_moriio_mode() -> MoRIIOMode: - read_mode=envs.VLLM_MORIIO_CONNECTOR_READ_MODE + read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE logger.debug("MoRIIO Connector read_mode: %s", read_mode) if read_mode: return MoRIIOMode.READ @@ -575,9 +574,9 @@ class MoRIIOWrapper: def set_backend_type(self, backend_type): assert self.moriio_engine is not None, "MoRIIO engine must be set first" - qp_per_transfer=envs.VLLM_MORIIO_QP_PER_TRANSFER - post_batch_size=envs.VLLM_MORIIO_POST_BATCH_SIZE - num_worker_threads=envs.VLLM_MORIIO_NUM_WORKERS + qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER + post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE + num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS poll_mode = PollCqMode.POLLING rdma_cfg = RdmaBackendConfig( qp_per_transfer, @@ -726,7 +725,7 @@ class MoRIIOWrapper: return except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): - logger.debug("Failed to decode msgpack message, will try as string. Error: %s") + logger.debug("Failed to decode msgpack message, will try as string") pass try: @@ -803,15 +802,15 @@ class MoRIIOWrapper: done_write_cache = set(self.done_write_cache_req_ids) self.done_write_cache_req_ids = [] return done_write_cache - + def shutdown(self): logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets") for path, sock in self.paths.items(): try: sock.close(linger=0) - logger.debug(f"Closed ZMQ socket for path: {path}") + logger.debug("Closed ZMQ socket for path: %s", path) except Exception as e: - logger.warning(f"Error closing ZMQ socket for path {path}: {e}") + logger.warning("Error closing ZMQ socket for path %s: %s", path, e) self.paths.clear() @@ -1292,6 +1291,15 @@ class MoRIIOConnectorScheduler: return meta + def shutdown(self): + for path, sock in self.paths.items(): + try: + sock.close(linger=0) + logger.debug("Closed ZMQ socket for path: %s", path) + except Exception as e: + logger.warning("Error closing ZMQ socket for path %s: %s", path, e) + self.paths.clear() + def request_finished( self, request: "Request", @@ -1524,8 +1532,8 @@ class MoRIIOConnectorWorker: self.block_size, use_mla=self.use_mla, ) - - #TODO: consider the integration of flashinfer or other backends. + + # TODO: consider the integration of flashinfer or other backends. self.backend_name = backend.get_name() logger.debug("Detected attention backend %s", self.backend_name) @@ -1654,17 +1662,9 @@ class MoRIIOConnectorWorker: index += 1 def shutdown(self): - if hasattr(self, 'moriio_wrapper') and self.moriio_wrapper: + if hasattr(self, "moriio_wrapper") and self.moriio_wrapper: self.moriio_wrapper.shutdown() - - for path, sock in self.paths.items(): - try: - sock.close(linger=0) - logger.debug(f"Closed ZMQ socket for path: {path}") - except Exception as e: - logger.warning(f"Error closing ZMQ socket for path {path}: {e}") - self.paths.clear() - + if hasattr(self, "_handshake_initiation_executor"): self._handshake_initiation_executor.shutdown(wait=False) @@ -1902,11 +1902,7 @@ class MoRIIOConnectorWorker: caches_data = [] for cache_or_caches in kv_caches.values(): - cache_list = ( - [cache_or_caches] - if use_mla - else cache_or_caches - ) + cache_list = [cache_or_caches] if use_mla else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len diff --git a/vllm/envs.py b/vllm/envs.py index f3b212dbe59cb..68bf5346fae7b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -193,10 +193,10 @@ if TYPE_CHECKING: VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 - VLLM_MORIIO_CONNECTOR_READ_MODE:bool=False - VLLM_MORIIO_QP_PER_TRANSFER:int=1 - VLLM_MORIIO_POST_BATCH_SIZE:int=-1 - VLLM_MORIIO_NUM_WORKERS:int=1 + VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False + VLLM_MORIIO_QP_PER_TRANSFER: int = 1 + VLLM_MORIIO_POST_BATCH_SIZE: int = -1 + VLLM_MORIIO_NUM_WORKERS: int = 1 VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False @@ -1349,23 +1349,18 @@ environment_variables: dict[str, Callable[[], Any]] = { ), # Controls the read mode for the Mori-IO connector "VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: ( - os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() - in ("true", "1") + os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1") ), # Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector "VLLM_MORIIO_QP_PER_TRANSFER": lambda: int( os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1") ), - # Controls the post-processing batch size for the Mori-IO connector "VLLM_MORIIO_POST_BATCH_SIZE": lambda: int( os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1") ), - # Controls the number of workers for Mori operations for the Mori-IO connector - "VLLM_MORIIO_NUM_WORKERS": lambda: int( - os.getenv("VLLM_MORIIO_NUM_WORKERS", "1") - ), + "VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")), # Controls whether or not to use cudnn prefill "VLLM_USE_CUDNN_PREFILL": lambda: bool( int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))