Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-01 09:07:49 +00:00
parent 6fbeee78d1
commit bd6a340652
2 changed files with 29 additions and 38 deletions

View File

@ -3,7 +3,6 @@
import contextlib import contextlib
import logging import logging
import math import math
import os
import queue import queue
import threading import threading
import time import time
@ -19,8 +18,8 @@ import msgspec
import numpy as np import numpy as np
import torch import torch
import zmq import zmq
from vllm import envs from vllm import envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@ -201,7 +200,7 @@ class TransferError(MoRIIOError):
def get_moriio_mode() -> MoRIIOMode: def get_moriio_mode() -> MoRIIOMode:
read_mode=envs.VLLM_MORIIO_CONNECTOR_READ_MODE read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
logger.debug("MoRIIO Connector read_mode: %s", read_mode) logger.debug("MoRIIO Connector read_mode: %s", read_mode)
if read_mode: if read_mode:
return MoRIIOMode.READ return MoRIIOMode.READ
@ -575,9 +574,9 @@ class MoRIIOWrapper:
def set_backend_type(self, backend_type): def set_backend_type(self, backend_type):
assert self.moriio_engine is not None, "MoRIIO engine must be set first" assert self.moriio_engine is not None, "MoRIIO engine must be set first"
qp_per_transfer=envs.VLLM_MORIIO_QP_PER_TRANSFER qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER
post_batch_size=envs.VLLM_MORIIO_POST_BATCH_SIZE post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE
num_worker_threads=envs.VLLM_MORIIO_NUM_WORKERS num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS
poll_mode = PollCqMode.POLLING poll_mode = PollCqMode.POLLING
rdma_cfg = RdmaBackendConfig( rdma_cfg = RdmaBackendConfig(
qp_per_transfer, qp_per_transfer,
@ -726,7 +725,7 @@ class MoRIIOWrapper:
return return
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
logger.debug("Failed to decode msgpack message, will try as string. Error: %s") logger.debug("Failed to decode msgpack message, will try as string")
pass pass
try: try:
@ -809,9 +808,9 @@ class MoRIIOWrapper:
for path, sock in self.paths.items(): for path, sock in self.paths.items():
try: try:
sock.close(linger=0) sock.close(linger=0)
logger.debug(f"Closed ZMQ socket for path: {path}") logger.debug("Closed ZMQ socket for path: %s", path)
except Exception as e: except Exception as e:
logger.warning(f"Error closing ZMQ socket for path {path}: {e}") logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
self.paths.clear() self.paths.clear()
@ -1292,6 +1291,15 @@ class MoRIIOConnectorScheduler:
return meta return meta
def shutdown(self):
for path, sock in self.paths.items():
try:
sock.close(linger=0)
logger.debug("Closed ZMQ socket for path: %s", path)
except Exception as e:
logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
self.paths.clear()
def request_finished( def request_finished(
self, self,
request: "Request", request: "Request",
@ -1525,7 +1533,7 @@ class MoRIIOConnectorWorker:
use_mla=self.use_mla, use_mla=self.use_mla,
) )
#TODO: consider the integration of flashinfer or other backends. # TODO: consider the integration of flashinfer or other backends.
self.backend_name = backend.get_name() self.backend_name = backend.get_name()
logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected attention backend %s", self.backend_name)
@ -1654,17 +1662,9 @@ class MoRIIOConnectorWorker:
index += 1 index += 1
def shutdown(self): def shutdown(self):
if hasattr(self, 'moriio_wrapper') and self.moriio_wrapper: if hasattr(self, "moriio_wrapper") and self.moriio_wrapper:
self.moriio_wrapper.shutdown() self.moriio_wrapper.shutdown()
for path, sock in self.paths.items():
try:
sock.close(linger=0)
logger.debug(f"Closed ZMQ socket for path: {path}")
except Exception as e:
logger.warning(f"Error closing ZMQ socket for path {path}: {e}")
self.paths.clear()
if hasattr(self, "_handshake_initiation_executor"): if hasattr(self, "_handshake_initiation_executor"):
self._handshake_initiation_executor.shutdown(wait=False) self._handshake_initiation_executor.shutdown(wait=False)
@ -1902,11 +1902,7 @@ class MoRIIOConnectorWorker:
caches_data = [] caches_data = []
for cache_or_caches in kv_caches.values(): for cache_or_caches in kv_caches.values():
cache_list = ( cache_list = [cache_or_caches] if use_mla else cache_or_caches
[cache_or_caches]
if use_mla
else cache_or_caches
)
for cache in cache_list: for cache in cache_list:
base_addr = cache.data_ptr() base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len region_len = self.num_blocks * self.block_len

View File

@ -193,10 +193,10 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_MORIIO_CONNECTOR_READ_MODE:bool=False VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False
VLLM_MORIIO_QP_PER_TRANSFER:int=1 VLLM_MORIIO_QP_PER_TRANSFER: int = 1
VLLM_MORIIO_POST_BATCH_SIZE:int=-1 VLLM_MORIIO_POST_BATCH_SIZE: int = -1
VLLM_MORIIO_NUM_WORKERS:int=1 VLLM_MORIIO_NUM_WORKERS: int = 1
VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_CUDNN_PREFILL: bool = False
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False
@ -1349,23 +1349,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
), ),
# Controls the read mode for the Mori-IO connector # Controls the read mode for the Mori-IO connector
"VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: ( "VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: (
os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1")
in ("true", "1")
), ),
# Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector # Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector
"VLLM_MORIIO_QP_PER_TRANSFER": lambda: int( "VLLM_MORIIO_QP_PER_TRANSFER": lambda: int(
os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1") os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1")
), ),
# Controls the post-processing batch size for the Mori-IO connector # Controls the post-processing batch size for the Mori-IO connector
"VLLM_MORIIO_POST_BATCH_SIZE": lambda: int( "VLLM_MORIIO_POST_BATCH_SIZE": lambda: int(
os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1") os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1")
), ),
# Controls the number of workers for Mori operations for the Mori-IO connector # Controls the number of workers for Mori operations for the Mori-IO connector
"VLLM_MORIIO_NUM_WORKERS": lambda: int( "VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")),
os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")
),
# Controls whether or not to use cudnn prefill # Controls whether or not to use cudnn prefill
"VLLM_USE_CUDNN_PREFILL": lambda: bool( "VLLM_USE_CUDNN_PREFILL": lambda: bool(
int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")) int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))