mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 09:17:03 +08:00
format
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
6fbeee78d1
commit
bd6a340652
@ -3,7 +3,6 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
@ -19,8 +18,8 @@ import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
@ -201,7 +200,7 @@ class TransferError(MoRIIOError):
|
||||
|
||||
|
||||
def get_moriio_mode() -> MoRIIOMode:
|
||||
read_mode=envs.VLLM_MORIIO_CONNECTOR_READ_MODE
|
||||
read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
|
||||
logger.debug("MoRIIO Connector read_mode: %s", read_mode)
|
||||
if read_mode:
|
||||
return MoRIIOMode.READ
|
||||
@ -575,9 +574,9 @@ class MoRIIOWrapper:
|
||||
|
||||
def set_backend_type(self, backend_type):
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
qp_per_transfer=envs.VLLM_MORIIO_QP_PER_TRANSFER
|
||||
post_batch_size=envs.VLLM_MORIIO_POST_BATCH_SIZE
|
||||
num_worker_threads=envs.VLLM_MORIIO_NUM_WORKERS
|
||||
qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER
|
||||
post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE
|
||||
num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS
|
||||
poll_mode = PollCqMode.POLLING
|
||||
rdma_cfg = RdmaBackendConfig(
|
||||
qp_per_transfer,
|
||||
@ -726,7 +725,7 @@ class MoRIIOWrapper:
|
||||
|
||||
return
|
||||
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
|
||||
logger.debug("Failed to decode msgpack message, will try as string. Error: %s")
|
||||
logger.debug("Failed to decode msgpack message, will try as string")
|
||||
pass
|
||||
|
||||
try:
|
||||
@ -803,15 +802,15 @@ class MoRIIOWrapper:
|
||||
done_write_cache = set(self.done_write_cache_req_ids)
|
||||
self.done_write_cache_req_ids = []
|
||||
return done_write_cache
|
||||
|
||||
|
||||
def shutdown(self):
|
||||
logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets")
|
||||
for path, sock in self.paths.items():
|
||||
try:
|
||||
sock.close(linger=0)
|
||||
logger.debug(f"Closed ZMQ socket for path: {path}")
|
||||
logger.debug("Closed ZMQ socket for path: %s", path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing ZMQ socket for path {path}: {e}")
|
||||
logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
|
||||
self.paths.clear()
|
||||
|
||||
|
||||
@ -1292,6 +1291,15 @@ class MoRIIOConnectorScheduler:
|
||||
|
||||
return meta
|
||||
|
||||
def shutdown(self):
|
||||
for path, sock in self.paths.items():
|
||||
try:
|
||||
sock.close(linger=0)
|
||||
logger.debug("Closed ZMQ socket for path: %s", path)
|
||||
except Exception as e:
|
||||
logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
|
||||
self.paths.clear()
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
@ -1524,8 +1532,8 @@ class MoRIIOConnectorWorker:
|
||||
self.block_size,
|
||||
use_mla=self.use_mla,
|
||||
)
|
||||
|
||||
#TODO: consider the integration of flashinfer or other backends.
|
||||
|
||||
# TODO: consider the integration of flashinfer or other backends.
|
||||
self.backend_name = backend.get_name()
|
||||
logger.debug("Detected attention backend %s", self.backend_name)
|
||||
|
||||
@ -1654,17 +1662,9 @@ class MoRIIOConnectorWorker:
|
||||
index += 1
|
||||
|
||||
def shutdown(self):
|
||||
if hasattr(self, 'moriio_wrapper') and self.moriio_wrapper:
|
||||
if hasattr(self, "moriio_wrapper") and self.moriio_wrapper:
|
||||
self.moriio_wrapper.shutdown()
|
||||
|
||||
for path, sock in self.paths.items():
|
||||
try:
|
||||
sock.close(linger=0)
|
||||
logger.debug(f"Closed ZMQ socket for path: {path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing ZMQ socket for path {path}: {e}")
|
||||
self.paths.clear()
|
||||
|
||||
|
||||
if hasattr(self, "_handshake_initiation_executor"):
|
||||
self._handshake_initiation_executor.shutdown(wait=False)
|
||||
|
||||
@ -1902,11 +1902,7 @@ class MoRIIOConnectorWorker:
|
||||
caches_data = []
|
||||
|
||||
for cache_or_caches in kv_caches.values():
|
||||
cache_list = (
|
||||
[cache_or_caches]
|
||||
if use_mla
|
||||
else cache_or_caches
|
||||
)
|
||||
cache_list = [cache_or_caches] if use_mla else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len
|
||||
|
||||
17
vllm/envs.py
17
vllm/envs.py
@ -193,10 +193,10 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
|
||||
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
|
||||
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
|
||||
VLLM_MORIIO_CONNECTOR_READ_MODE:bool=False
|
||||
VLLM_MORIIO_QP_PER_TRANSFER:int=1
|
||||
VLLM_MORIIO_POST_BATCH_SIZE:int=-1
|
||||
VLLM_MORIIO_NUM_WORKERS:int=1
|
||||
VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False
|
||||
VLLM_MORIIO_QP_PER_TRANSFER: int = 1
|
||||
VLLM_MORIIO_POST_BATCH_SIZE: int = -1
|
||||
VLLM_MORIIO_NUM_WORKERS: int = 1
|
||||
VLLM_USE_CUDNN_PREFILL: bool = False
|
||||
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
|
||||
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
|
||||
@ -1349,23 +1349,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
),
|
||||
# Controls the read mode for the Mori-IO connector
|
||||
"VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: (
|
||||
os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower()
|
||||
in ("true", "1")
|
||||
os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1")
|
||||
),
|
||||
# Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector
|
||||
"VLLM_MORIIO_QP_PER_TRANSFER": lambda: int(
|
||||
os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1")
|
||||
),
|
||||
|
||||
# Controls the post-processing batch size for the Mori-IO connector
|
||||
"VLLM_MORIIO_POST_BATCH_SIZE": lambda: int(
|
||||
os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1")
|
||||
),
|
||||
|
||||
# Controls the number of workers for Mori operations for the Mori-IO connector
|
||||
"VLLM_MORIIO_NUM_WORKERS": lambda: int(
|
||||
os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")
|
||||
),
|
||||
"VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")),
|
||||
# Controls whether or not to use cudnn prefill
|
||||
"VLLM_USE_CUDNN_PREFILL": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user