Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-01 09:07:49 +00:00
parent 6fbeee78d1
commit bd6a340652
2 changed files with 29 additions and 38 deletions

View File

@ -3,7 +3,6 @@
import contextlib
import logging
import math
import os
import queue
import threading
import time
@ -19,8 +18,8 @@ import msgspec
import numpy as np
import torch
import zmq
from vllm import envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@ -201,7 +200,7 @@ class TransferError(MoRIIOError):
def get_moriio_mode() -> MoRIIOMode:
read_mode=envs.VLLM_MORIIO_CONNECTOR_READ_MODE
read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
logger.debug("MoRIIO Connector read_mode: %s", read_mode)
if read_mode:
return MoRIIOMode.READ
@ -575,9 +574,9 @@ class MoRIIOWrapper:
def set_backend_type(self, backend_type):
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
qp_per_transfer=envs.VLLM_MORIIO_QP_PER_TRANSFER
post_batch_size=envs.VLLM_MORIIO_POST_BATCH_SIZE
num_worker_threads=envs.VLLM_MORIIO_NUM_WORKERS
qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER
post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE
num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS
poll_mode = PollCqMode.POLLING
rdma_cfg = RdmaBackendConfig(
qp_per_transfer,
@ -726,7 +725,7 @@ class MoRIIOWrapper:
return
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
logger.debug("Failed to decode msgpack message, will try as string. Error: %s")
logger.debug("Failed to decode msgpack message, will try as string")
pass
try:
@ -803,15 +802,15 @@ class MoRIIOWrapper:
done_write_cache = set(self.done_write_cache_req_ids)
self.done_write_cache_req_ids = []
return done_write_cache
def shutdown(self):
logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets")
for path, sock in self.paths.items():
try:
sock.close(linger=0)
logger.debug(f"Closed ZMQ socket for path: {path}")
logger.debug("Closed ZMQ socket for path: %s", path)
except Exception as e:
logger.warning(f"Error closing ZMQ socket for path {path}: {e}")
logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
self.paths.clear()
@ -1292,6 +1291,15 @@ class MoRIIOConnectorScheduler:
return meta
def shutdown(self):
for path, sock in self.paths.items():
try:
sock.close(linger=0)
logger.debug("Closed ZMQ socket for path: %s", path)
except Exception as e:
logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
self.paths.clear()
def request_finished(
self,
request: "Request",
@ -1524,8 +1532,8 @@ class MoRIIOConnectorWorker:
self.block_size,
use_mla=self.use_mla,
)
#TODO: consider the integration of flashinfer or other backends.
# TODO: consider the integration of flashinfer or other backends.
self.backend_name = backend.get_name()
logger.debug("Detected attention backend %s", self.backend_name)
@ -1654,17 +1662,9 @@ class MoRIIOConnectorWorker:
index += 1
def shutdown(self):
if hasattr(self, 'moriio_wrapper') and self.moriio_wrapper:
if hasattr(self, "moriio_wrapper") and self.moriio_wrapper:
self.moriio_wrapper.shutdown()
for path, sock in self.paths.items():
try:
sock.close(linger=0)
logger.debug(f"Closed ZMQ socket for path: {path}")
except Exception as e:
logger.warning(f"Error closing ZMQ socket for path {path}: {e}")
self.paths.clear()
if hasattr(self, "_handshake_initiation_executor"):
self._handshake_initiation_executor.shutdown(wait=False)
@ -1902,11 +1902,7 @@ class MoRIIOConnectorWorker:
caches_data = []
for cache_or_caches in kv_caches.values():
cache_list = (
[cache_or_caches]
if use_mla
else cache_or_caches
)
cache_list = [cache_or_caches] if use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len

View File

@ -193,10 +193,10 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_MORIIO_CONNECTOR_READ_MODE:bool=False
VLLM_MORIIO_QP_PER_TRANSFER:int=1
VLLM_MORIIO_POST_BATCH_SIZE:int=-1
VLLM_MORIIO_NUM_WORKERS:int=1
VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False
VLLM_MORIIO_QP_PER_TRANSFER: int = 1
VLLM_MORIIO_POST_BATCH_SIZE: int = -1
VLLM_MORIIO_NUM_WORKERS: int = 1
VLLM_USE_CUDNN_PREFILL: bool = False
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
@ -1349,23 +1349,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Controls the read mode for the Mori-IO connector
"VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: (
os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower()
in ("true", "1")
os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1")
),
# Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector
"VLLM_MORIIO_QP_PER_TRANSFER": lambda: int(
os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1")
),
# Controls the post-processing batch size for the Mori-IO connector
"VLLM_MORIIO_POST_BATCH_SIZE": lambda: int(
os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1")
),
# Controls the number of workers for Mori operations for the Mori-IO connector
"VLLM_MORIIO_NUM_WORKERS": lambda: int(
os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")
),
"VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")),
# Controls whether or not to use cudnn prefill
"VLLM_USE_CUDNN_PREFILL": lambda: bool(
int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))