mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 17:37:08 +08:00
format
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
6fbeee78d1
commit
bd6a340652
@ -3,7 +3,6 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@ -19,8 +18,8 @@ import msgspec
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
@ -201,7 +200,7 @@ class TransferError(MoRIIOError):
|
|||||||
|
|
||||||
|
|
||||||
def get_moriio_mode() -> MoRIIOMode:
|
def get_moriio_mode() -> MoRIIOMode:
|
||||||
read_mode=envs.VLLM_MORIIO_CONNECTOR_READ_MODE
|
read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
|
||||||
logger.debug("MoRIIO Connector read_mode: %s", read_mode)
|
logger.debug("MoRIIO Connector read_mode: %s", read_mode)
|
||||||
if read_mode:
|
if read_mode:
|
||||||
return MoRIIOMode.READ
|
return MoRIIOMode.READ
|
||||||
@ -575,9 +574,9 @@ class MoRIIOWrapper:
|
|||||||
|
|
||||||
def set_backend_type(self, backend_type):
|
def set_backend_type(self, backend_type):
|
||||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||||
qp_per_transfer=envs.VLLM_MORIIO_QP_PER_TRANSFER
|
qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER
|
||||||
post_batch_size=envs.VLLM_MORIIO_POST_BATCH_SIZE
|
post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE
|
||||||
num_worker_threads=envs.VLLM_MORIIO_NUM_WORKERS
|
num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS
|
||||||
poll_mode = PollCqMode.POLLING
|
poll_mode = PollCqMode.POLLING
|
||||||
rdma_cfg = RdmaBackendConfig(
|
rdma_cfg = RdmaBackendConfig(
|
||||||
qp_per_transfer,
|
qp_per_transfer,
|
||||||
@ -726,7 +725,7 @@ class MoRIIOWrapper:
|
|||||||
|
|
||||||
return
|
return
|
||||||
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
|
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
|
||||||
logger.debug("Failed to decode msgpack message, will try as string. Error: %s")
|
logger.debug("Failed to decode msgpack message, will try as string")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -809,9 +808,9 @@ class MoRIIOWrapper:
|
|||||||
for path, sock in self.paths.items():
|
for path, sock in self.paths.items():
|
||||||
try:
|
try:
|
||||||
sock.close(linger=0)
|
sock.close(linger=0)
|
||||||
logger.debug(f"Closed ZMQ socket for path: {path}")
|
logger.debug("Closed ZMQ socket for path: %s", path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error closing ZMQ socket for path {path}: {e}")
|
logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
|
||||||
self.paths.clear()
|
self.paths.clear()
|
||||||
|
|
||||||
|
|
||||||
@ -1292,6 +1291,15 @@ class MoRIIOConnectorScheduler:
|
|||||||
|
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
for path, sock in self.paths.items():
|
||||||
|
try:
|
||||||
|
sock.close(linger=0)
|
||||||
|
logger.debug("Closed ZMQ socket for path: %s", path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
|
||||||
|
self.paths.clear()
|
||||||
|
|
||||||
def request_finished(
|
def request_finished(
|
||||||
self,
|
self,
|
||||||
request: "Request",
|
request: "Request",
|
||||||
@ -1525,7 +1533,7 @@ class MoRIIOConnectorWorker:
|
|||||||
use_mla=self.use_mla,
|
use_mla=self.use_mla,
|
||||||
)
|
)
|
||||||
|
|
||||||
#TODO: consider the integration of flashinfer or other backends.
|
# TODO: consider the integration of flashinfer or other backends.
|
||||||
self.backend_name = backend.get_name()
|
self.backend_name = backend.get_name()
|
||||||
logger.debug("Detected attention backend %s", self.backend_name)
|
logger.debug("Detected attention backend %s", self.backend_name)
|
||||||
|
|
||||||
@ -1654,17 +1662,9 @@ class MoRIIOConnectorWorker:
|
|||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
if hasattr(self, 'moriio_wrapper') and self.moriio_wrapper:
|
if hasattr(self, "moriio_wrapper") and self.moriio_wrapper:
|
||||||
self.moriio_wrapper.shutdown()
|
self.moriio_wrapper.shutdown()
|
||||||
|
|
||||||
for path, sock in self.paths.items():
|
|
||||||
try:
|
|
||||||
sock.close(linger=0)
|
|
||||||
logger.debug(f"Closed ZMQ socket for path: {path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error closing ZMQ socket for path {path}: {e}")
|
|
||||||
self.paths.clear()
|
|
||||||
|
|
||||||
if hasattr(self, "_handshake_initiation_executor"):
|
if hasattr(self, "_handshake_initiation_executor"):
|
||||||
self._handshake_initiation_executor.shutdown(wait=False)
|
self._handshake_initiation_executor.shutdown(wait=False)
|
||||||
|
|
||||||
@ -1902,11 +1902,7 @@ class MoRIIOConnectorWorker:
|
|||||||
caches_data = []
|
caches_data = []
|
||||||
|
|
||||||
for cache_or_caches in kv_caches.values():
|
for cache_or_caches in kv_caches.values():
|
||||||
cache_list = (
|
cache_list = [cache_or_caches] if use_mla else cache_or_caches
|
||||||
[cache_or_caches]
|
|
||||||
if use_mla
|
|
||||||
else cache_or_caches
|
|
||||||
)
|
|
||||||
for cache in cache_list:
|
for cache in cache_list:
|
||||||
base_addr = cache.data_ptr()
|
base_addr = cache.data_ptr()
|
||||||
region_len = self.num_blocks * self.block_len
|
region_len = self.num_blocks * self.block_len
|
||||||
|
|||||||
17
vllm/envs.py
17
vllm/envs.py
@ -193,10 +193,10 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
|
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
|
||||||
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
|
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
|
||||||
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
|
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
|
||||||
VLLM_MORIIO_CONNECTOR_READ_MODE:bool=False
|
VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False
|
||||||
VLLM_MORIIO_QP_PER_TRANSFER:int=1
|
VLLM_MORIIO_QP_PER_TRANSFER: int = 1
|
||||||
VLLM_MORIIO_POST_BATCH_SIZE:int=-1
|
VLLM_MORIIO_POST_BATCH_SIZE: int = -1
|
||||||
VLLM_MORIIO_NUM_WORKERS:int=1
|
VLLM_MORIIO_NUM_WORKERS: int = 1
|
||||||
VLLM_USE_CUDNN_PREFILL: bool = False
|
VLLM_USE_CUDNN_PREFILL: bool = False
|
||||||
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
|
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
|
||||||
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
|
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
|
||||||
@ -1349,23 +1349,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
),
|
),
|
||||||
# Controls the read mode for the Mori-IO connector
|
# Controls the read mode for the Mori-IO connector
|
||||||
"VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: (
|
"VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: (
|
||||||
os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower()
|
os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1")
|
||||||
in ("true", "1")
|
|
||||||
),
|
),
|
||||||
# Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector
|
# Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector
|
||||||
"VLLM_MORIIO_QP_PER_TRANSFER": lambda: int(
|
"VLLM_MORIIO_QP_PER_TRANSFER": lambda: int(
|
||||||
os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1")
|
os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1")
|
||||||
),
|
),
|
||||||
|
|
||||||
# Controls the post-processing batch size for the Mori-IO connector
|
# Controls the post-processing batch size for the Mori-IO connector
|
||||||
"VLLM_MORIIO_POST_BATCH_SIZE": lambda: int(
|
"VLLM_MORIIO_POST_BATCH_SIZE": lambda: int(
|
||||||
os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1")
|
os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1")
|
||||||
),
|
),
|
||||||
|
|
||||||
# Controls the number of workers for Mori operations for the Mori-IO connector
|
# Controls the number of workers for Mori operations for the Mori-IO connector
|
||||||
"VLLM_MORIIO_NUM_WORKERS": lambda: int(
|
"VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")),
|
||||||
os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")
|
|
||||||
),
|
|
||||||
# Controls whether or not to use cudnn prefill
|
# Controls whether or not to use cudnn prefill
|
||||||
"VLLM_USE_CUDNN_PREFILL": lambda: bool(
|
"VLLM_USE_CUDNN_PREFILL": lambda: bool(
|
||||||
int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))
|
int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user